
Commit 33f36c8

Add a main_input_name attribute to all models (#14803)
* Add a main_input_name attribute to all models

* Fix tests

* Wtf Vs Code?

* Update src/transformers/models/imagegpt/modeling_imagegpt.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Style

* Fix copies

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 0940e9b commit 33f36c8

32 files changed (+61, -5 lines)

src/transformers/modeling_flax_utils.py (+3)

@@ -76,9 +76,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"
 
     def __init__(
         self,
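
For context, a minimal sketch of how the new attribute is meant to be used (the subclass below is hypothetical, not part of this commit; it only illustrates overriding the class-level default):

from transformers import FlaxPreTrainedModel

# The base class now defaults to "input_ids"; derived vision/speech classes
# override it, as the model files further down in this diff do.
print(FlaxPreTrainedModel.main_input_name)  # "input_ids"

class FlaxToyVisionModel(FlaxPreTrainedModel):  # hypothetical illustration
    main_input_name = "pixel_values"  # a vision model overrides the default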

src/transformers/modeling_tf_utils.py (+4)

@@ -653,9 +653,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"
+
     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_missing = None

src/transformers/modeling_utils.py (+7, -5)

@@ -17,7 +17,6 @@
 import inspect
 import os
 import re
-import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
@@ -376,11 +375,10 @@ def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> in
         Returns:
             :obj:`int`: The total number of tokens.
         """
-        token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
-        if token_inputs:
-            return sum([token_input.numel() for token_input in token_inputs])
+        if self.main_input_name in input_dict:
+            return input_dict[self.main_input_name].numel()
         else:
-            warnings.warn(
+            logger.warn(
                 "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
             )
             return 0
@@ -438,9 +436,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
         - **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"
+
     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_missing = None
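
As a usage sketch of the reworked `estimate_tokens` (the checkpoint name is only an example; `estimate_tokens` lives on `ModuleUtilsMixin`, which `PreTrainedModel` inherits):

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")  # example checkpoint
batch = {
    "input_ids": torch.ones(2, 16, dtype=torch.long),
    "decoder_input_ids": torch.ones(2, 4, dtype=torch.long),
}
# Only the tensor named by main_input_name ("input_ids") is counted: 2 * 16 = 32.
print(model.estimate_tokens(batch))
# The removed heuristic summed every key containing "input", so it would have
# reported 2 * 16 + 2 * 4 = 40 for this dict.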

src/transformers/models/beit/modeling_beit.py (+1)

@@ -523,6 +523,7 @@ class BeitPreTrainedModel(PreTrainedModel):
 
     config_class = BeitConfig
     base_model_prefix = "beit"
+    main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
 
     def _init_weights(self, module):

src/transformers/models/beit/modeling_flax_beit.py (+1)

@@ -590,6 +590,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
 
     config_class = BeitConfig
     base_model_prefix = "beit"
+    main_input_name = "pixel_values"
     module_class: nn.Module = None
 
     def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):

src/transformers/models/clip/modeling_clip.py (+1)

@@ -789,6 +789,7 @@ def forward(
 
 class CLIPVisionModel(CLIPPreTrainedModel):
     config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"
 
     def __init__(self, config: CLIPVisionConfig):
         super().__init__(config)

src/transformers/models/clip/modeling_flax_clip.py (+1)

@@ -653,6 +653,7 @@ def __call__(
 
 class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
     config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"
     module_class: nn.Module = None
 
     def __init__(

src/transformers/models/deit/modeling_deit.py (+1)

@@ -385,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
 
     config_class = DeiTConfig
     base_model_prefix = "deit"
+    main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
 
     def _init_weights(self, module):

src/transformers/models/detr/modeling_detr.py (+1)

@@ -784,6 +784,7 @@ def forward(self, hidden_states: torch.Tensor):
 class DetrPreTrainedModel(PreTrainedModel):
     config_class = DetrConfig
     base_model_prefix = "model"
+    main_input_name = "pixel_values"
 
     def _init_weights(self, module):
         std = self.config.init_std

src/transformers/models/hubert/modeling_hubert.py (+1)

@@ -776,6 +776,7 @@ class HubertPreTrainedModel(PreTrainedModel):
 
     config_class = HubertConfig
     base_model_prefix = "hubert"
+    main_input_name = "input_values"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

src/transformers/models/hubert/modeling_tf_hubert.py (+1)

@@ -1265,6 +1265,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
 
     config_class = HubertConfig
     base_model_prefix = "hubert"
+    main_input_name = "input_values"
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:

src/transformers/models/imagegpt/modeling_imagegpt.py (+1)

@@ -496,6 +496,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
     config_class = ImageGPTConfig
     load_tf_weights = load_tf_weights_in_imagegpt
     base_model_prefix = "transformer"
+    main_input_name = "input_ids"
     supports_gradient_checkpointing = True
 
     def __init__(self, *inputs, **kwargs):
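
A brief aside on why ImageGPT sets the default explicitly rather than `pixel_values`: it is a vision model, but its inputs are sequences of discrete color-cluster ids, so `input_ids` genuinely is its main input. A quick check (a sketch, nothing more):

from transformers import ImageGPTModel

# ImageGPT quantizes pixels into discrete color-cluster ids, so its main
# input is "input_ids" even though it is a vision model.
print(ImageGPTModel.main_input_name)  # "input_ids"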

src/transformers/models/perceiver/modeling_perceiver.py (+1)

@@ -619,6 +619,7 @@ class PerceiverPreTrainedModel(PreTrainedModel):
 
     config_class = PerceiverConfig
     base_model_prefix = "perceiver"
+    main_input_name = "inputs"
 
     def _init_weights(self, module):
         """Initialize the weights"""

src/transformers/models/segformer/modeling_segformer.py (+1)

@@ -406,6 +406,7 @@ class SegformerPreTrainedModel(PreTrainedModel):
 
     config_class = SegformerConfig
     base_model_prefix = "segformer"
+    main_input_name = "pixel_values"
 
     def _init_weights(self, module):
         """Initialize the weights"""

src/transformers/models/sew/modeling_sew.py (+1)

@@ -675,6 +675,7 @@ class SEWPreTrainedModel(PreTrainedModel):
 
     config_class = SEWConfig
     base_model_prefix = "sew"
+    main_input_name = "input_values"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

src/transformers/models/sew_d/modeling_sew_d.py (+1)

@@ -1201,6 +1201,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
 
     config_class = SEWDConfig
     base_model_prefix = "sew-d"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True

src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py (+1)

@@ -180,6 +180,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     """
     config_class = SpeechEncoderDecoderConfig
     base_model_prefix = "speech_encoder_decoder"
+    main_input_name = "input_values"
 
     def __init__(
         self,
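
For composite encoder-decoder models, the attribute follows the encoder's modality; a quick sketch (assuming the transformers version at this commit):

from transformers import SpeechEncoderDecoderModel, VisionEncoderDecoderModel

# Composite models report the encoder's input as the main input.
print(SpeechEncoderDecoderModel.main_input_name)  # "input_values"
print(VisionEncoderDecoderModel.main_input_name)  # "pixel_values"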

src/transformers/models/speech_to_text/modeling_speech_to_text.py (+1)

@@ -539,6 +539,7 @@ def forward(
 class Speech2TextPreTrainedModel(PreTrainedModel):
     config_class = Speech2TextConfig
     base_model_prefix = "model"
+    main_input_name = "input_features"
     supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
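
Note that Speech2Text takes log-mel `input_features` rather than the raw-waveform `input_values` of the wav2vec2-style models above. A sketch verifying the attribute against the forward signature, mirroring what the new tests below do:

import inspect
from transformers import Speech2TextModel

# Speech2Text consumes log-mel filter-bank features, not raw waveforms.
first_arg = list(inspect.signature(Speech2TextModel.forward).parameters)[1]
print(first_arg)                         # "input_features"
print(Speech2TextModel.main_input_name)  # "input_features"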

src/transformers/models/unispeech/modeling_unispeech.py (+1)

@@ -912,6 +912,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
 
     config_class = UniSpeechConfig
     base_model_prefix = "unispeech"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True

src/transformers/models/unispeech_sat/modeling_unispeech_sat.py (+1)

@@ -947,6 +947,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
 
     config_class = UniSpeechSatConfig
     base_model_prefix = "unispeech_sat"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True

src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py (+1)

@@ -283,6 +283,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
     """
     config_class = VisionEncoderDecoderConfig
     base_model_prefix = "vision_encoder_decoder"
+    main_input_name = "pixel_values"
     module_class = FlaxVisionEncoderDecoderModule
 
     def __init__(

src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py (+1)

@@ -160,6 +160,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
     """
     config_class = VisionEncoderDecoderConfig
     base_model_prefix = "vision_encoder_decoder"
+    main_input_name = "pixel_values"
 
     def __init__(
         self,

src/transformers/models/vit/modeling_flax_vit.py (+1)

@@ -406,6 +406,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
 
     config_class = ViTConfig
     base_model_prefix = "vit"
+    main_input_name = "pixel_values"
     module_class: nn.Module = None
 
     def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):

src/transformers/models/vit/modeling_tf_vit.py (+1)

@@ -555,6 +555,7 @@ class TFViTPreTrainedModel(TFPreTrainedModel):
 
     config_class = ViTConfig
     base_model_prefix = "vit"
+    main_input_name = "pixel_values"
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:

src/transformers/models/vit/modeling_vit.py (+1)

@@ -412,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel):
 
     config_class = ViTConfig
     base_model_prefix = "vit"
+    main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
 
     def _init_weights(self, module):

src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py (+1)

@@ -775,6 +775,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
 
     config_class = Wav2Vec2Config
     base_model_prefix: str = "wav2vec2"
+    main_input_name = "input_values"
     module_class: nn.Module = None
 
     def __init__(

src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py (+1)

@@ -1256,6 +1256,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
 
     config_class = Wav2Vec2Config
     base_model_prefix = "wav2vec2"
+    main_input_name = "input_values"
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:

src/transformers/models/wav2vec2/modeling_wav2vec2.py (+1)

@@ -1044,6 +1044,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
 
     config_class = Wav2Vec2Config
     base_model_prefix = "wav2vec2"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True

src/transformers/models/wavlm/modeling_wavlm.py (+1)

@@ -996,6 +996,7 @@ class WavLMPreTrainedModel(PreTrainedModel):
 
     config_class = WavLMConfig
     base_model_prefix = "wavlm"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True

tests/test_modeling_common.py (+7)

@@ -1315,6 +1315,13 @@ def test_model_common_attributes(self):
             x = model.get_output_embeddings()
             self.assertTrue(x is None or isinstance(x, nn.Linear))
 
+    def test_model_main_input_name(self):
+        for model_class in self.all_model_classes:
+            model_signature = inspect.signature(getattr(model_class, "forward"))
+            # The main input is the name of the argument after `self`
+            observed_main_input_name = list(model_signature.parameters.keys())[1]
+            self.assertEqual(model_class.main_input_name, observed_main_input_name)
+
     def test_correct_missing_keys(self):
         if not self.test_missing_keys:
             return
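
The test relies on `inspect.signature` preserving declaration order, so index 1 is the first parameter after `self`. A standalone sketch of the mechanics, with a hypothetical stand-in class:

import inspect

class ToyModel:  # hypothetical stand-in for a model class
    main_input_name = "input_ids"

    def forward(self, input_ids, attention_mask=None):
        return input_ids

params = list(inspect.signature(ToyModel.forward).parameters)
print(params)  # ['self', 'input_ids', 'attention_mask']
assert params[1] == ToyModel.main_input_name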

tests/test_modeling_flax_common.py (+7)

@@ -778,6 +778,13 @@ def test_save_load_in_bf16(self):
             for name, type_ in types.items():
                 self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
 
+    def test_model_main_input_name(self):
+        for model_class in self.all_model_classes:
+            model_signature = inspect.signature(getattr(model_class, "__call__"))
+            # The main input is the name of the argument after `self`
+            observed_main_input_name = list(model_signature.parameters.keys())[1]
+            self.assertEqual(model_class.main_input_name, observed_main_input_name)
+
     def test_headmasking(self):
         if not self.test_head_masking:
             return

tests/test_modeling_tf_common.py (+7)

@@ -1183,6 +1183,13 @@ def test_load_with_mismatched_shapes(self):
             else:
                 new_model_without_prefix(input_ids)
 
+    def test_model_main_input_name(self):
+        for model_class in self.all_model_classes:
+            model_signature = inspect.signature(getattr(model_class, "call"))
+            # The main input is the name of the argument after `self`
+            observed_main_input_name = list(model_signature.parameters.keys())[1]
+            self.assertEqual(model_class.main_input_name, observed_main_input_name)
+
     def _generate_random_bad_tokens(self, num_bad_tokens, model):
         # special tokens cannot be bad tokens
         special_tokens = []
