
Commit 46ed56c

Switch metrics in run_ner to datasets (#9567)
* Switch metrics in run_ner to datasets
* Add flag to return all metrics
* Upstream (and rename) sortish_sampler
* Revert "Upstream (and rename) sortish_sampler" (this reverts commit e07d0dc)
1 parent 5e1bea4 commit 46ed56c
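
The change replaces direct calls to seqeval's metric functions with the metric of the same name loaded through datasets. As a rough, self-contained sketch of what that metric returns (the tag sequences below are made up for illustration, and the per-entity keys depend on the labels in your data):

from datasets import load_metric

# Toy IOB2 tag sequences, purely illustrative.
predictions = [["O", "B-PER", "I-PER", "O"], ["B-LOC", "O"]]
references = [["O", "B-PER", "I-PER", "O"], ["B-LOC", "O"]]

metric = load_metric("seqeval")
results = metric.compute(predictions=predictions, references=references)

# `results` is nested: one sub-dict per entity type (precision, recall,
# f1, number) plus flat overall_precision / overall_recall / overall_f1 /
# overall_accuracy entries.
print(results["overall_accuracy"], results["PER"]["f1"])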

File tree

2 files changed, +26 -9 lines changed


examples/test_examples.py (+1 -1)

@@ -184,7 +184,7 @@ def test_run_ner(self):
 
         with patch.object(sys, "argv", testargs):
             result = run_ner.main()
-            self.assertGreaterEqual(result["eval_accuracy_score"], 0.75)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
             self.assertGreaterEqual(result["eval_precision"], 0.75)
             self.assertLess(result["eval_loss"], 0.5)

examples/token-classification/run_ner.py (+25 -8)

@@ -25,8 +25,7 @@
 from typing import Optional
 
 import numpy as np
-from datasets import ClassLabel, load_dataset
-from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
+from datasets import ClassLabel, load_dataset, load_metric
 
 import transformers
 from transformers import (
@@ -124,6 +123,10 @@ class DataTrainingArguments:
             "one (in which case the other tokens will have a padding index)."
         },
     )
+    return_entity_level_metrics: bool = field(
+        default=False,
+        metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -323,6 +326,8 @@ def tokenize_and_align_labels(examples):
     data_collator = DataCollatorForTokenClassification(tokenizer)
 
     # Metrics
+    metric = load_metric("seqeval")
+
     def compute_metrics(p):
         predictions, labels = p
         predictions = np.argmax(predictions, axis=2)
@@ -337,12 +342,24 @@ def compute_metrics(p):
             for prediction, label in zip(predictions, labels)
         ]
 
-        return {
-            "accuracy_score": accuracy_score(true_labels, true_predictions),
-            "precision": precision_score(true_labels, true_predictions),
-            "recall": recall_score(true_labels, true_predictions),
-            "f1": f1_score(true_labels, true_predictions),
-        }
+        results = metric.compute(predictions=true_predictions, references=true_labels)
+        if data_args.return_entity_level_metrics:
+            # Unpack nested dictionaries
+            final_results = {}
+            for key, value in results.items():
+                if isinstance(value, dict):
+                    for n, v in value.items():
+                        final_results[f"{key}_{n}"] = v
+                else:
+                    final_results[key] = value
+            return final_results
+        else:
+            return {
+                "precision": results["overall_precision"],
+                "recall": results["overall_recall"],
+                "f1": results["overall_f1"],
+                "accuracy": results["overall_accuracy"],
+            }
 
     # Initialize our Trainer
     trainer = Trainer(
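
The new return_entity_level_metrics flag only changes how that nested dict is reported: per-entity sub-dicts are flattened into "<entity>_<metric>" keys instead of being reduced to the four overall values. A small standalone sketch of the same flattening, using a hypothetical results dict (under the Trainer these keys are additionally prefixed with eval_, which is why the test above now checks eval_accuracy):

# Hypothetical seqeval-style output, for illustration only.
results = {
    "PER": {"precision": 1.0, "recall": 0.5, "f1": 0.667, "number": 2},
    "overall_precision": 1.0,
    "overall_recall": 0.5,
    "overall_f1": 0.667,
    "overall_accuracy": 0.8,
}

# Same unpacking as the new branch in compute_metrics: nested per-entity
# dicts become flat "<entity>_<metric>" keys, scalar values pass through.
final_results = {}
for key, value in results.items():
    if isinstance(value, dict):
        for n, v in value.items():
            final_results[f"{key}_{n}"] = v
    else:
        final_results[key] = value

print(final_results)
# {'PER_precision': 1.0, 'PER_recall': 0.5, 'PER_f1': 0.667, 'PER_number': 2,
#  'overall_precision': 1.0, 'overall_recall': 0.5, 'overall_f1': 0.667,
#  'overall_accuracy': 0.8}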
