Skip to content

Commit 45cac3f

Browse files
authored
Fix labels stored in model config for token classification examples (#15482)
* Playing
* Properly set labels in model config for token classification example
* Port to run_ner_no_trainer
* Quality
1 parent c74f3d4 commit 45cac3f

File tree

2 files changed

+38
-22
lines changed

2 files changed

+38
-22
lines changed

examples/pytorch/token-classification/run_ner.py

+19-11
Original file line number | Diff line number | Diff line change
@@ -295,12 +295,15 @@ def get_label_list(labels):
295295
label_list.sort()
296296
return label_list
297297

298-
if isinstance(features[label_column_name].feature, ClassLabel):
298+
# If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
299+
# Otherwise, we have to get the list of labels manually.
300+
labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
301+
if labels_are_int:
299302
label_list = features[label_column_name].feature.names
300-
label_keys = list(range(len(label_list)))
303+
label_to_id = {i: i for i in range(len(label_list))}
301304
else:
302305
label_list = get_label_list(raw_datasets["train"][label_column_name])
303-
label_keys = label_list
306+
label_to_id = {l: i for i, l in enumerate(label_list)}
304307

305308
num_labels = len(label_list)
306309

@@ -354,21 +357,26 @@ def get_label_list(labels):
354357
"requirement"
355358
)
356359

360+
# Model has labels -> use them.
357361
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
358-
label_name_to_id = {k: v for k, v in model.config.label2id.items()}
359-
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
360-
label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
362+
if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
363+
# Reorganize `label_list` to match the ordering of the model.
364+
if labels_are_int:
365+
label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
366+
label_list = [model.config.id2label[i] for i in range(num_labels)]
367+
else:
368+
label_list = [model.config.id2label[i] for i in range(num_labels)]
369+
label_to_id = {l: i for i, l in enumerate(label_list)}
361370
else:
362371
logger.warning(
363372
"Your model seems to have been trained with labels, but they don't match the dataset: ",
364-
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
373+
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
365374
"\nIgnoring the model labels as a result.",
366375
)
367-
else:
368-
label_to_id = {k: i for i, k in enumerate(label_keys)}
369376

370-
model.config.label2id = label_to_id
371-
model.config.id2label = {i: l for l, i in label_to_id.items()}
377+
# Set the correspondences label/ID inside the model config
378+
model.config.label2id = {l: i for i, l in enumerate(label_list)}
379+
model.config.id2label = {i: l for i, l in enumerate(label_list)}
372380

373381
# Map that sends B-Xxx label to its I-Xxx counterpart
374382
b_to_i_label = []

examples/pytorch/token-classification/run_ner_no_trainer.py

+19-11
Original file line number | Diff line number | Diff line change
@@ -320,12 +320,15 @@ def get_label_list(labels):
320320
label_list.sort()
321321
return label_list
322322

323-
if isinstance(features[label_column_name].feature, ClassLabel):
323+
# If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
324+
# Otherwise, we have to get the list of labels manually.
325+
labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
326+
if labels_are_int:
324327
label_list = features[label_column_name].feature.names
325-
label_keys = list(range(len(label_list)))
328+
label_to_id = {i: i for i in range(len(label_list))}
326329
else:
327330
label_list = get_label_list(raw_datasets["train"][label_column_name])
328-
label_keys = label_list
331+
label_to_id = {l: i for i, l in enumerate(label_list)}
329332

330333
num_labels = len(label_list)
331334

@@ -365,21 +368,26 @@ def get_label_list(labels):
365368

366369
model.resize_token_embeddings(len(tokenizer))
367370

371+
# Model has labels -> use them.
368372
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
369-
label_name_to_id = {k: v for k, v in model.config.label2id.items()}
370-
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
371-
label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
373+
if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
374+
# Reorganize `label_list` to match the ordering of the model.
375+
if labels_are_int:
376+
label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
377+
label_list = [model.config.id2label[i] for i in range(num_labels)]
378+
else:
379+
label_list = [model.config.id2label[i] for i in range(num_labels)]
380+
label_to_id = {l: i for i, l in enumerate(label_list)}
372381
else:
373382
logger.warning(
374383
"Your model seems to have been trained with labels, but they don't match the dataset: ",
375-
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
384+
f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
376385
"\nIgnoring the model labels as a result.",
377386
)
378-
else:
379-
label_to_id = {k: i for i, k in enumerate(label_keys)}
380387

381-
model.config.label2id = label_to_id
382-
model.config.id2label = {i: l for l, i in label_to_id.items()}
388+
# Set the correspondences label/ID inside the model config
389+
model.config.label2id = {l: i for i, l in enumerate(label_list)}
390+
model.config.id2label = {i: l for i, l in enumerate(label_list)}
383391

384392
# Map that sends B-Xxx label to its I-Xxx counterpart
385393
b_to_i_label = []

0 commit comments

Comments (0)