@@ -300,19 +300,25 @@ def test_run_image_classification_no_trainer(self):
300
300
tmp_dir = self .get_auto_remove_tmp_dir ()
301
301
testargs = f"""
302
302
{ self .examples_dir } /pytorch/image-classification/run_image_classification_no_trainer.py
303
- --dataset_name huggingface/image-classification-test-sample
303
+ --model_name_or_path google/vit-base-patch16-224-in21k
304
+ --dataset_name hf-internal-testing/cats_vs_dogs_sample
305
+ --learning_rate 1e-4
306
+ --per_device_train_batch_size 2
307
+ --per_device_eval_batch_size 1
308
+ --max_train_steps 2
309
+ --train_val_split 0.1
310
+ --seed 42
304
311
--output_dir { tmp_dir }
305
- --num_warmup_steps=8
306
- --learning_rate=3e-3
307
- --per_device_train_batch_size=2
308
- --per_device_eval_batch_size=1
309
- --checkpointing_steps epoch
310
312
--with_tracking
311
- --seed 42
313
+ --checkpointing_steps 1
312
314
""" .split ()
313
315
316
+ if is_cuda_and_apex_available ():
317
+ testargs .append ("--fp16" )
318
+
314
319
_ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
315
320
result = get_results (tmp_dir )
316
- self .assertGreaterEqual (result ["eval_accuracy" ], 0.50 )
317
- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
321
+ # The base model scores a 25%
322
+ self .assertGreaterEqual (result ["eval_accuracy" ], 0.625 )
323
+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "step_1" )))
318
324
self .assertTrue (os .path .exists (os .path .join (tmp_dir , "image_classification_no_trainer" )))
0 commit comments