Skip to content

Commit f588cf4

Browse files
authored
[Flax tests/FlaxBert] make from_pretrained test faster (#15561)
1 parent 7029240 commit f588cf4

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tests/test_modeling_flax_bert.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ def setUp(self):
141141

142142
@slow
143143
def test_model_from_pretrained(self):
144-
for model_class_name in self.all_model_classes:
145-
model = model_class_name.from_pretrained("bert-base-cased", from_pt=True)
146-
outputs = model(np.ones((1, 1)))
147-
self.assertIsNotNone(outputs)
144+
# Only check this for base model, not necessary for all model classes.
145+
# This will also help speed-up tests.
146+
model = FlaxBertModel.from_pretrained("bert-base-cased")
147+
outputs = model(np.ones((1, 1)))
148+
self.assertIsNotNone(outputs)

0 commit comments

Comments
 (0)