24
24
25
25
from transformers import is_torch_available
26
26
from transformers .file_utils import WEIGHTS_NAME
27
+ from transformers .models .auto import get_values
27
28
from transformers .testing_utils import require_torch , require_torch_multi_gpu , slow , torch_device
28
29
29
30
@@ -79,7 +80,7 @@ class ModelTesterMixin:
79
80
80
81
def _prepare_for_class (self , inputs_dict , model_class , return_labels = False ):
81
82
inputs_dict = copy .deepcopy (inputs_dict )
82
- if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING . values ( ):
83
+ if model_class in get_values ( MODEL_FOR_MULTIPLE_CHOICE_MAPPING ):
83
84
inputs_dict = {
84
85
k : v .unsqueeze (1 ).expand (- 1 , self .model_tester .num_choices , - 1 ).contiguous ()
85
86
if isinstance (v , torch .Tensor ) and v .ndim > 1
@@ -88,28 +89,28 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
88
89
}
89
90
90
91
if return_labels :
91
- if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING . values ( ):
92
+ if model_class in get_values ( MODEL_FOR_MULTIPLE_CHOICE_MAPPING ):
92
93
inputs_dict ["labels" ] = torch .ones (self .model_tester .batch_size , dtype = torch .long , device = torch_device )
93
- elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING . values ( ):
94
+ elif model_class in get_values ( MODEL_FOR_QUESTION_ANSWERING_MAPPING ):
94
95
inputs_dict ["start_positions" ] = torch .zeros (
95
96
self .model_tester .batch_size , dtype = torch .long , device = torch_device
96
97
)
97
98
inputs_dict ["end_positions" ] = torch .zeros (
98
99
self .model_tester .batch_size , dtype = torch .long , device = torch_device
99
100
)
100
101
elif model_class in [
101
- * MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING . values ( ),
102
- * MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING . values ( ),
103
- * MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING . values ( ),
102
+ * get_values ( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING ),
103
+ * get_values ( MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING ),
104
+ * get_values ( MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING ),
104
105
]:
105
106
inputs_dict ["labels" ] = torch .zeros (
106
107
self .model_tester .batch_size , dtype = torch .long , device = torch_device
107
108
)
108
109
elif model_class in [
109
- * MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING . values ( ),
110
- * MODEL_FOR_CAUSAL_LM_MAPPING . values ( ),
111
- * MODEL_FOR_MASKED_LM_MAPPING . values ( ),
112
- * MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING . values ( ),
110
+ * get_values ( MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING ),
111
+ * get_values ( MODEL_FOR_CAUSAL_LM_MAPPING ),
112
+ * get_values ( MODEL_FOR_MASKED_LM_MAPPING ),
113
+ * get_values ( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING ),
113
114
]:
114
115
inputs_dict ["labels" ] = torch .zeros (
115
116
(self .model_tester .batch_size , self .model_tester .seq_length ), dtype = torch .long , device = torch_device
@@ -229,7 +230,7 @@ def test_training(self):
229
230
config .return_dict = True
230
231
231
232
for model_class in self .all_model_classes :
232
- if model_class in MODEL_MAPPING . values ( ):
233
+ if model_class in get_values ( MODEL_MAPPING ):
233
234
continue
234
235
model = model_class (config )
235
236
model .to (torch_device )
@@ -248,7 +249,7 @@ def test_training_gradient_checkpointing(self):
248
249
config .return_dict = True
249
250
250
251
for model_class in self .all_model_classes :
251
- if model_class in MODEL_MAPPING . values ( ):
252
+ if model_class in get_values ( MODEL_MAPPING ):
252
253
continue
253
254
model = model_class (config )
254
255
model .to (torch_device )
@@ -312,7 +313,7 @@ def test_attention_outputs(self):
312
313
if "labels" in inputs_dict :
313
314
correct_outlen += 1 # loss is added to beginning
314
315
# Question Answering model returns start_logits and end_logits
315
- if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING . values ( ):
316
+ if model_class in get_values ( MODEL_FOR_QUESTION_ANSWERING_MAPPING ):
316
317
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
317
318
if "past_key_values" in outputs :
318
319
correct_outlen += 1 # past_key_values have been returned
0 commit comments