Skip to content

Commit b623ddc

Browse files
authored
Pass kwargs to configuration (#3147)
* Pass kwargs to configuration * Setter * test
1 parent 0001d05 commit b623ddc

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

src/transformers/configuration_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,18 @@ def __init__(self, **kwargs):
9898
logger.error("Can't set {} with value {} for {}".format(key, value, self))
9999
raise err
100100

101+
@property
102+
def num_labels(self):
103+
return self._num_labels
104+
105+
@num_labels.setter
106+
def num_labels(self, num_labels):
107+
self._num_labels = num_labels
108+
self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
109+
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
110+
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
111+
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
112+
101113
def save_pretrained(self, save_directory):
102114
"""
103115
Save a configuration object to the directory `save_directory`, so that it

tests/test_configuration_common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,18 @@ def create_and_test_config_from_and_save_pretrained(self):
5757

5858
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
5959

60+
def create_and_test_config_with_num_labels(self):
61+
config = self.config_class(**self.inputs_dict, num_labels=5)
62+
self.parent.assertEqual(len(config.id2label), 5)
63+
self.parent.assertEqual(len(config.label2id), 5)
64+
65+
config.num_labels = 3
66+
self.parent.assertEqual(len(config.id2label), 3)
67+
self.parent.assertEqual(len(config.label2id), 3)
68+
6069
def run_common_tests(self):
6170
self.create_and_test_config_common_properties()
6271
self.create_and_test_config_to_json_string()
6372
self.create_and_test_config_to_json_file()
6473
self.create_and_test_config_from_and_save_pretrained()
74+
self.create_and_test_config_with_num_labels()

0 commit comments

Comments
 (0)