@@ -320,12 +320,15 @@ def get_label_list(labels):
320
320
label_list .sort ()
321
321
return label_list
322
322
323
# A ClassLabel column already holds integer ids and carries its label names in
# `names`, so the dataset-side mapping is the identity. For any other column
# type the label list has to be collected from the training split.
label_feature = features[label_column_name].feature
labels_are_int = isinstance(label_feature, ClassLabel)
if not labels_are_int:
    label_list = get_label_list(raw_datasets["train"][label_column_name])
    # Labels are raw values here: map each one to its position in the list.
    label_to_id = {label: idx for idx, label in enumerate(label_list)}
else:
    label_list = label_feature.names
    # Labels are already ids: map each id to itself.
    label_to_id = {idx: idx for idx in range(len(label_list))}

num_labels = len(label_list)
331
334
@@ -365,21 +368,26 @@ def get_label_list(labels):
365
368
366
369
model.resize_token_embeddings(len(tokenizer))

# Model has labels -> use them.
# The check compares the model's label2id with the one of a fresh config built
# from `num_labels` alone; a difference is taken to mean the checkpoint ships
# its own label names.
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
    if sorted(model.config.label2id.keys()) == sorted(label_list):
        # Same label names on both sides: reorganize `label_list` to match the
        # ordering of the model.
        if labels_are_int:
            # Dataset labels are ids: send dataset id -> model id, using the
            # dataset ordering of `label_list` before it is overwritten below.
            label_to_id = {
                idx: int(model.config.label2id[label]) for idx, label in enumerate(label_list)
            }
            label_list = [model.config.id2label[idx] for idx in range(num_labels)]
        else:
            # Dataset labels are names: adopt the model ordering first, then
            # send each name to its position in that ordering.
            label_list = [model.config.id2label[idx] for idx in range(num_labels)]
            label_to_id = {label: idx for idx, label in enumerate(label_list)}
    else:
        # The pieces below form ONE implicitly concatenated message. Passing
        # them as separate positional arguments would make `logging` treat the
        # extra strings as %-format arguments for a message without
        # placeholders, and the warning would be reported as a logging error
        # instead of being emitted.
        logger.warning(
            "Your model seems to have been trained with labels, but they don't match the dataset: "
            f"model labels: {sorted(model.config.label2id.keys())} , dataset labels: {sorted(label_list)} ."
            "\n Ignoring the model labels as a result."
        )

# Set the correspondences label/ID inside the model config
# Both mappings are derived from the (possibly reordered) `label_list`, so they
# are always inverse of each other.
model.config.label2id = {label: idx for idx, label in enumerate(label_list)}
model.config.id2label = {idx: label for idx, label in enumerate(label_list)}
383
391
384
392
# Map that sends B-Xxx label to its I-Xxx counterpart
385
393
b_to_i_label = []
0 commit comments