25
25
from typing import Optional
26
26
27
27
import numpy as np
28
- from datasets import ClassLabel , load_dataset
29
- from seqeval .metrics import accuracy_score , f1_score , precision_score , recall_score
28
+ from datasets import ClassLabel , load_dataset , load_metric
30
29
31
30
import transformers
32
31
from transformers import (
@@ -124,6 +123,10 @@ class DataTrainingArguments:
124
123
"one (in which case the other tokens will have a padding index)."
125
124
},
126
125
)
126
+ return_entity_level_metrics : bool = field (
127
+ default = False ,
128
+ metadata = {"help" : "Whether to return all the entity levels during evaluation or just the overall ones." },
129
+ )
127
130
128
131
def __post_init__ (self ):
129
132
if self .dataset_name is None and self .train_file is None and self .validation_file is None :
@@ -323,6 +326,8 @@ def tokenize_and_align_labels(examples):
323
326
data_collator = DataCollatorForTokenClassification (tokenizer )
324
327
325
328
# Metrics
329
+ metric = load_metric ("seqeval" )
330
+
326
331
def compute_metrics (p ):
327
332
predictions , labels = p
328
333
predictions = np .argmax (predictions , axis = 2 )
@@ -337,12 +342,24 @@ def compute_metrics(p):
337
342
for prediction , label in zip (predictions , labels )
338
343
]
339
344
340
- return {
341
- "accuracy_score" : accuracy_score (true_labels , true_predictions ),
342
- "precision" : precision_score (true_labels , true_predictions ),
343
- "recall" : recall_score (true_labels , true_predictions ),
344
- "f1" : f1_score (true_labels , true_predictions ),
345
- }
345
+ results = metric .compute (predictions = true_predictions , references = true_labels )
346
+ if data_args .return_entity_level_metrics :
347
+ # Unpack nested dictionaries
348
+ final_results = {}
349
+ for key , value in results .items ():
350
+ if isinstance (value , dict ):
351
+ for n , v in value .items ():
352
+ final_results [f"{ key } _{ n } " ] = v
353
+ else :
354
+ final_results [key ] = value
355
+ return final_results
356
+ else :
357
+ return {
358
+ "precision" : results ["overall_precision" ],
359
+ "recall" : results ["overall_recall" ],
360
+ "f1" : results ["overall_f1" ],
361
+ "accuracy" : results ["overall_accuracy" ],
362
+ }
346
363
347
364
# Initialize our Trainer
348
365
trainer = Trainer (
0 commit comments