Skip to content

Commit e364387

Browse files
author
Daniel King
committed
Add back confidences and adjust default params
1 parent 7389745 commit e364387

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

scripts/train.sh

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
export SEED=13270
3+
export SEED=15270
44
export PYTORCH_SEED=`expr $SEED / 10`
55
export NUMPY_SEED=`expr $PYTORCH_SEED / 10`
66

@@ -20,13 +20,13 @@ export WITH_CRF=false # CRF only works for the baseline
2020
# training params
2121
export cuda_device=0
2222
export BATCH_SIZE=4
23-
export LR=2e-5
24-
export TRAINING_DATA_INSTANCES=2000
25-
export NUM_EPOCHS=4
23+
export LR=5e-5
24+
export TRAINING_DATA_INSTANCES=1668
25+
export NUM_EPOCHS=2
2626

2727
# limit number of sentences per example, and number of words per sentence. This is dataset dependent
2828
export MAX_SENT_PER_EXAMPLE=10
29-
export SENT_MAX_LEN=40
29+
export SENT_MAX_LEN=80
3030

3131
# this is for the evaluation of the summarization dataset
3232
export SCI_SUM=false

sequential_sentence_classification/dataset_reader.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,8 @@ def read_one_example(self, json_dict):
7575
else:
7676
labels = None
7777

78+
confidences = json_dict.get("confs", None)
79+
7880
additional_features = None
7981
if self.sci_sum:
8082
if self.sci_sum_fake_scores:
@@ -98,18 +100,19 @@ def read_one_example(self, json_dict):
98100
if len(sentences) == 0:
99101
return []
100102

101-
for sentences_loop, labels_loop, additional_features_loop in \
102-
self.enforce_max_sent_per_example(sentences, labels, additional_features):
103+
for sentences_loop, labels_loop, confidences_loop, additional_features_loop in \
104+
self.enforce_max_sent_per_example(sentences, labels, confidences, additional_features):
103105

104106
instance = self.text_to_instance(
105107
sentences=sentences_loop,
106108
labels=labels_loop,
109+
confidences=confidences_loop,
107110
additional_features=additional_features_loop,
108111
)
109112
instances.append(instance)
110113
return instances
111114

112-
def enforce_max_sent_per_example(self, sentences, labels=None, additional_features=None):
115+
def enforce_max_sent_per_example(self, sentences, labels=None, confidences=None, additional_features=None):
113116
"""
114117
Splits examples with len(sentences) > self.max_sent_per_example into multiple smaller examples
115118
with len(sentences) <= self.max_sent_per_example.
@@ -121,20 +124,24 @@ def enforce_max_sent_per_example(self, sentences, labels=None, additional_featur
121124
"""
122125
if labels is not None:
123126
assert len(sentences) == len(labels)
127+
if confidences is not None:
128+
assert len(sentences) == len(confidences)
124129
if additional_features is not None:
125130
assert len(sentences) == len(additional_features)
126131

127132
if len(sentences) > self.max_sent_per_example and self.max_sent_per_example > 0:
128133
i = len(sentences) // 2
129134
l1 = self.enforce_max_sent_per_example(
130135
sentences[:i], None if labels is None else labels[:i],
136+
None if confidences is None else confidences[:i],
131137
None if additional_features is None else additional_features[:i])
132138
l2 = self.enforce_max_sent_per_example(
133139
sentences[i:], None if labels is None else labels[i:],
140+
None if confidences is None else confidences[i:],
134141
None if additional_features is None else additional_features[i:])
135142
return l1 + l2
136143
else:
137-
return [(sentences, labels, additional_features)]
144+
return [(sentences, labels, confidences, additional_features)]
138145

139146
def is_bad_sentence(self, sentence: str):
140147
if len(sentence) > 10 and len(sentence) < 600:
@@ -171,10 +178,13 @@ def filter_bad_sci_sum_sentences(self, sentences, labels):
171178
def text_to_instance(self,
172179
sentences: List[str],
173180
labels: List[str] = None,
181+
confidences: List[float] = None,
174182
additional_features: List[float] = None,
175183
) -> Instance:
176184
if not self.predict:
177185
assert len(sentences) == len(labels)
186+
if confidences is not None:
187+
assert len(sentences) == len(confidences)
178188
if additional_features is not None:
179189
assert len(sentences) == len(additional_features)
180190

@@ -209,6 +219,8 @@ def text_to_instance(self,
209219
LabelField(str(label)+"_label") for label in labels
210220
])
211221

222+
if confidences is not None:
223+
fields['confidences'] = ArrayField(np.array(confidences))
212224
if additional_features is not None:
213225
fields["additional_features"] = ArrayField(np.array(additional_features))
214226

sequential_sentence_classification/model.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ def __init__(self, vocab: Vocabulary,
7373
def forward(self, # type: ignore
7474
sentences: torch.LongTensor,
7575
labels: torch.IntTensor = None,
76+
confidences: torch.Tensor = None,
7677
additional_features: torch.Tensor = None,
7778
) -> Dict[str, torch.Tensor]:
7879
# pylint: disable=arguments-differ
@@ -120,6 +121,9 @@ def forward(self, # type: ignore
120121

121122
labels = labels[labels_mask] # given batch_size x num_sentences_per_example return num_sentences_per_batch
122123
assert labels.dim() == 1
124+
if confidences is not None:
125+
confidences = confidences[labels_mask]
126+
assert confidences.dim() == 1
123127
if additional_features is not None:
124128
additional_features = additional_features[labels_mask]
125129
assert additional_features.dim() == 2
@@ -132,6 +136,13 @@ def forward(self, # type: ignore
132136
# We are ignoring this problem for now.
133137
# TODO: fix, at least for testing
134138

139+
# do the same for `confidences`
140+
if confidences is not None:
141+
num_confidences = confidences.shape[0]
142+
if num_confidences != num_sentences:
143+
assert num_confidences > num_sentences
144+
confidences = confidences[:num_sentences]
145+
135146
# and for `additional_features`
136147
if additional_features is not None:
137148
num_additional_features = additional_features.shape[0]
@@ -141,6 +152,8 @@ def forward(self, # type: ignore
141152

142153
# similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
143154
labels = labels.unsqueeze(dim=0)
155+
if confidences is not None:
156+
confidences = confidences.unsqueeze(dim=0)
144157
if additional_features is not None:
145158
additional_features = additional_features.unsqueeze(dim=0)
146159
else:
@@ -185,6 +198,8 @@ def forward(self, # type: ignore
185198

186199
if not self.with_crf:
187200
label_loss = self.loss(flattened_logits.squeeze(), flattened_gold)
201+
if confidences is not None:
202+
label_loss = label_loss * confidences.type_as(label_loss).view(-1)
188203
label_loss = label_loss.mean()
189204
flattened_probs = torch.softmax(flattened_logits, dim=-1)
190205
else:

0 commit comments

Comments (0)