
Commit ba8b1f4

Add support for multiple models for one config in auto classes (#11150)
* Add support for multiple models for one config in auto classes
* Use get_values everywhere
* Prettier doc
1 parent 97ccf67 commit ba8b1f4

26 files changed (+188, -72 lines)
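The gist of the change: an auto-class mapping entry may now map one configuration class to a tuple of model classes instead of a single class. Which class is instantiated is driven by the `architectures` field of the config, with the first element of the tuple as the default. The snippet below is a simplified, self-contained sketch of that resolution rule (the stand-in classes and the `resolve` helper are illustrative only; the real logic is `_get_model_class` in auto_factory.py further down):

# Simplified sketch of the new resolution rule, not the library code.
class FunnelConfig:
    architectures = []

class FunnelModel: ...
class FunnelBaseModel: ...

# A config type may now map to a tuple of model classes.
MODEL_MAPPING = {FunnelConfig: (FunnelModel, FunnelBaseModel)}

def resolve(config, mapping):
    supported = mapping[type(config)]
    if not isinstance(supported, (list, tuple)):
        return supported
    by_name = {cls.__name__: cls for cls in supported}
    for arch in getattr(config, "architectures", []) or []:
        if arch in by_name:
            return by_name[arch]
    return supported[0]  # no match: fall back to the first entry

config = FunnelConfig()
assert resolve(config, MODEL_MAPPING) is FunnelModel       # default, as before
config.architectures = ["FunnelBaseModel"]
assert resolve(config, MODEL_MAPPING) is FunnelBaseModel   # selected via the config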

src/transformers/modeling_flax_utils.py (+1)

@@ -387,6 +387,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
         # get abs dir
         save_directory = os.path.abspath(save_directory)
         # save config as well
+        self.config.architectures = [self.__class__.__name__[4:]]
         self.config.save_pretrained(save_directory)
 
         # save model

src/transformers/modeling_tf_utils.py (+1)

@@ -1037,6 +1037,7 @@ def save_pretrained(self, save_directory, saved_model=False, version=1):
             logger.info(f"Saved model created in {saved_model_dir}")
 
         # Save configuration file
+        self.config.architectures = [self.__class__.__name__[2:]]
         self.config.save_pretrained(save_directory)
 
         # If we save using the predefined names, we can load using `from_pretrained`
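The slices in the two hunks above strip the framework prefix from the class name before it is written to `config.architectures`, so a saved checkpoint records the framework-agnostic architecture name; `_get_model_class` (added below) re-applies the `TF`/`Flax` prefix when matching against TF or Flax classes. A quick check of the slicing, assuming the usual class-naming convention:

# "Flax" is 4 characters and "TF" is 2, hence the [4:] and [2:] slices above.
assert "FlaxBertModel"[4:] == "BertModel"
assert "TFBertModel"[2:] == "BertModel"
# The PyTorch save path (modeling_utils.py) is not touched by this commit:
# PyTorch class names carry no framework prefix, so nothing needs stripping.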

src/transformers/models/auto/__init__.py (+2)

@@ -22,6 +22,7 @@
 
 
 _import_structure = {
+    "auto_factory": ["get_values"],
     "configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
     "feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
     "tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
@@ -104,6 +105,7 @@
 
 
 if TYPE_CHECKING:
+    from .auto_factory import get_values
     from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
     from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
     from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
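Both hunks are needed because the auto module uses the library's lazy-import scheme: names listed in `_import_structure` are only imported on first access at runtime, while the `TYPE_CHECKING` branch imports them eagerly for static analysis. Roughly, the pattern behaves like this simplified sketch (the library actually routes this through its lazy-module helper rather than a module-level `__getattr__`):

# Simplified stand-in for the lazy-import pattern used by the __init__ files.
import importlib
from typing import TYPE_CHECKING

_import_structure = {"auto_factory": ["get_values"]}

if TYPE_CHECKING:
    from .auto_factory import get_values  # seen by type checkers only
else:
    def __getattr__(name):  # PEP 562 module-level __getattr__
        for module, names in _import_structure.items():
            if name in names:
                return getattr(importlib.import_module(f".{module}", __name__), name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")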

src/transformers/models/auto/auto_factory.py (+35, -4)

@@ -328,6 +328,26 @@
 """
 
 
+def _get_model_class(config, model_mapping):
+    supported_models = model_mapping[type(config)]
+    if not isinstance(supported_models, (list, tuple)):
+        return supported_models
+
+    name_to_model = {model.__name__: model for model in supported_models}
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        if arch in name_to_model:
+            return name_to_model[arch]
+        elif f"TF{arch}" in name_to_model:
+            return name_to_model[f"TF{arch}"]
+        elif f"Flax{arch}" in name_to_model:
+            return name_to_model[f"Flax{arch}"]
+
+    # If no architecture is set in the config, or none of them matches the supported models, the first element of
+    # the tuple is the default.
+    return supported_models[0]
+
+
 class _BaseAutoModelClass:
     # Base class for auto models.
     _model_mapping = None
@@ -341,7 +361,8 @@ def __init__(self):
 
     def from_config(cls, config, **kwargs):
         if type(config) in cls._model_mapping.keys():
-            return cls._model_mapping[type(config)](config, **kwargs)
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class(config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -356,9 +377,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             )
 
         if type(config) in cls._model_mapping.keys():
-            return cls._model_mapping[type(config)].from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
+            model_class = _get_model_class(config, cls._model_mapping)
+            return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
         raise ValueError(
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
     from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
     new_class.from_pretrained = classmethod(from_pretrained)
     return new_class
+
+
+def get_values(model_mapping):
+    result = []
+    for model in model_mapping.values():
+        if isinstance(model, (list, tuple)):
+            result += list(model)
+        else:
+            result.append(model)
+
+    return result
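`get_values` exists because mapping values can now be tuples, so iterating `mapping.values()` no longer yields one model class per entry; the test updates further down all switch to it for that reason. A small usage sketch, assuming a PyTorch install of transformers at this commit and the Funnel entries wired up in modeling_auto.py below:

from transformers import FunnelBaseModel, FunnelConfig, FunnelModel
from transformers.models.auto import get_values  # exposed by this commit
from transformers.models.auto.modeling_auto import MODEL_MAPPING

# The Funnel entry is now a tuple...
assert MODEL_MAPPING[FunnelConfig] == (FunnelModel, FunnelBaseModel)
# ...so a membership test against .values() misses classes inside tuples,
assert FunnelModel not in MODEL_MAPPING.values()
# while get_values flattens tuples back into a flat list of model classes.
assert FunnelModel in get_values(MODEL_MAPPING)
assert FunnelBaseModel in get_values(MODEL_MAPPING)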

src/transformers/models/auto/configuration_auto.py (+14, -5)

@@ -247,29 +247,38 @@
 )
 
 
+def _get_class_name(model_class):
+    if isinstance(model_class, (list, tuple)):
+        return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
+    return f":class:`~transformers.{model_class.__name__}`"
+
+
 def _list_model_options(indent, config_to_class=None, use_model_types=True):
     if config_to_class is None and not use_model_types:
         raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
     if use_model_types:
         if config_to_class is None:
-            model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()}
+            model_type_to_name = {
+                model_type: f":class:`~transformers.{config.__name__}`"
+                for model_type, config in CONFIG_MAPPING.items()
+            }
         else:
             model_type_to_name = {
-                model_type: config_to_class[config].__name__
+                model_type: _get_class_name(config_to_class[config])
                 for model_type, config in CONFIG_MAPPING.items()
                 if config in config_to_class
             }
         lines = [
-            f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
+            f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
             for model_type in sorted(model_type_to_name.keys())
         ]
     else:
-        config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
+        config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
         config_to_model_name = {
             config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
         }
         lines = [
-            f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
+            f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
             for config_name in sorted(config_to_name.keys())
         ]
     return "\n".join(lines)

src/transformers/models/auto/modeling_auto.py (+2, -1)

@@ -124,6 +124,7 @@
 )
 from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
 from ..funnel.modeling_funnel import (
+    FunnelBaseModel,
     FunnelForMaskedLM,
     FunnelForMultipleChoice,
     FunnelForPreTraining,
@@ -377,7 +378,7 @@
         (CTRLConfig, CTRLModel),
         (ElectraConfig, ElectraModel),
         (ReformerConfig, ReformerModel),
-        (FunnelConfig, FunnelModel),
+        (FunnelConfig, (FunnelModel, FunnelBaseModel)),
         (LxmertConfig, LxmertModel),
         (BertGenerationConfig, BertGenerationEncoder),
         (DebertaConfig, DebertaModel),

src/transformers/models/auto/modeling_tf_auto.py (+2, -1)

@@ -91,6 +91,7 @@
     TFFlaubertWithLMHeadModel,
 )
 from ..funnel.modeling_tf_funnel import (
+    TFFunnelBaseModel,
     TFFunnelForMaskedLM,
     TFFunnelForMultipleChoice,
     TFFunnelForPreTraining,
@@ -242,7 +243,7 @@
         (XLMConfig, TFXLMModel),
         (CTRLConfig, TFCTRLModel),
         (ElectraConfig, TFElectraModel),
-        (FunnelConfig, TFFunnelModel),
+        (FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
         (DPRConfig, TFDPRQuestionEncoder),
         (MPNetConfig, TFMPNetModel),
         (BartConfig, TFBartModel),

tests/test_modeling_albert.py (+2, -1)

@@ -17,6 +17,7 @@
 import unittest
 
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -234,7 +235,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
         if return_labels:
-            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )

tests/test_modeling_auto.py (+28, -4)

@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import copy
+import tempfile
 import unittest
 
 from transformers import is_torch_available
@@ -46,6 +47,8 @@
         BertForSequenceClassification,
         BertForTokenClassification,
         BertModel,
+        FunnelBaseModel,
+        FunnelModel,
         GPT2Config,
         GPT2LMHeadModel,
         RobertaForMaskedLM,
@@ -218,6 +221,21 @@ def test_from_identifier_from_model_type(self):
         self.assertEqual(model.num_parameters(), 14410)
         self.assertEqual(model.num_parameters(only_trainable=True), 14410)
 
+    def test_from_pretrained_with_tuple_values(self):
+        # For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
+        model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
+        self.assertIsInstance(model, FunnelModel)
+
+        config = copy.deepcopy(model.config)
+        config.architectures = ["FunnelBaseModel"]
+        model = AutoModel.from_config(config)
+        self.assertIsInstance(model, FunnelBaseModel)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            model = AutoModel.from_pretrained(tmp_dir)
+            self.assertIsInstance(model, FunnelBaseModel)
+
     def test_parents_and_children_in_mappings(self):
         # Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
         # by the parents and will return the wrong configuration type when using auto models
@@ -242,6 +260,12 @@ def test_parents_and_children_in_mappings(self):
                     assert not issubclass(
                         child_config, parent_config
                     ), f"{child_config.__name__} is child of {parent_config.__name__}"
-                    assert not issubclass(
-                        child_model, parent_model
-                    ), f"{child_config.__name__} is child of {parent_config.__name__}"
+
+                    # Tuplify child_model and parent_model since some of them could be tuples.
+                    if not isinstance(child_model, (list, tuple)):
+                        child_model = (child_model,)
+                    if not isinstance(parent_model, (list, tuple)):
+                        parent_model = (parent_model,)
+
+                    for child, parent in [(a, b) for a in child_model for b in parent_model]:
+                        assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"

tests/test_modeling_bert.py (+2, -1)

@@ -17,6 +17,7 @@
 import unittest
 
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -444,7 +445,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
         if return_labels:
-            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )

tests/test_modeling_big_bird.py (+2, -1)

@@ -19,6 +19,7 @@
 
 from tests.test_modeling_common import floats_tensor
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
 from transformers.testing_utils import require_torch, slow, torch_device
 
@@ -458,7 +459,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
         if return_labels:
-            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )

tests/test_modeling_common.py (+14, -13)

@@ -24,6 +24,7 @@
 
 from transformers import is_torch_available
 from transformers.file_utils import WEIGHTS_NAME
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
 
 
@@ -79,7 +80,7 @@ class ModelTesterMixin:
 
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = copy.deepcopy(inputs_dict)
-        if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+        if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
             inputs_dict = {
                 k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                 if isinstance(v, torch.Tensor) and v.ndim > 1
@@ -88,28 +89,28 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
             }
 
         if return_labels:
-            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
                 inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
-            elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+            elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                 inputs_dict["start_positions"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
                 inputs_dict["end_positions"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
             elif model_class in [
-                *MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(),
-                *MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(),
-                *MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.values(),
+                *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
+                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
+                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
             ]:
                 inputs_dict["labels"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
             elif model_class in [
-                *MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
-                *MODEL_FOR_CAUSAL_LM_MAPPING.values(),
-                *MODEL_FOR_MASKED_LM_MAPPING.values(),
-                *MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
+                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
+                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
+                *get_values(MODEL_FOR_MASKED_LM_MAPPING),
+                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
             ]:
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
@@ -229,7 +230,7 @@ def test_training(self):
         config.return_dict = True
 
         for model_class in self.all_model_classes:
-            if model_class in MODEL_MAPPING.values():
+            if model_class in get_values(MODEL_MAPPING):
                 continue
             model = model_class(config)
             model.to(torch_device)
@@ -248,7 +249,7 @@ def test_training_gradient_checkpointing(self):
         config.return_dict = True
 
         for model_class in self.all_model_classes:
-            if model_class in MODEL_MAPPING.values():
+            if model_class in get_values(MODEL_MAPPING):
                 continue
             model = model_class(config)
             model.to(torch_device)
@@ -312,7 +313,7 @@ def test_attention_outputs(self):
             if "labels" in inputs_dict:
                 correct_outlen += 1  # loss is added to beginning
             # Question Answering model returns start_logits and end_logits
-            if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                 correct_outlen += 1  # start_logits and end_logits instead of only 1 output
             if "past_key_values" in outputs:
                 correct_outlen += 1  # past_key_values have been returned

tests/test_modeling_convbert.py (+2, -1)

@@ -19,6 +19,7 @@
 
 from tests.test_modeling_common import floats_tensor
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -352,7 +353,7 @@ def test_attention_outputs(self):
             if "labels" in inputs_dict:
                 correct_outlen += 1  # loss is added to beginning
             # Question Answering model returns start_logits and end_logits
-            if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                 correct_outlen += 1  # start_logits and end_logits instead of only 1 output
             if "past_key_values" in outputs:
                 correct_outlen += 1  # past_key_values have been returned

tests/test_modeling_electra.py (+2, -1)

@@ -17,6 +17,7 @@
 import unittest
 
 from transformers import is_torch_available
+from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
@@ -292,7 +293,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
 
         if return_labels:
-            if model_class in MODEL_FOR_PRETRAINING_MAPPING.values():
+            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )
