Skip to content

Commit acb709d

Browse files
authored
Change no trainer image_classification test (#17635)
* Adjust test arguments and use a new example test
1 parent e70abda commit acb709d

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

examples/pytorch/test_accelerate_examples.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -300,19 +300,25 @@ def test_run_image_classification_no_trainer(self):
300300
tmp_dir = self.get_auto_remove_tmp_dir()
301301
testargs = f"""
302302
{self.examples_dir}/pytorch/image-classification/run_image_classification_no_trainer.py
303-
--dataset_name huggingface/image-classification-test-sample
303+
--model_name_or_path google/vit-base-patch16-224-in21k
304+
--dataset_name hf-internal-testing/cats_vs_dogs_sample
305+
--learning_rate 1e-4
306+
--per_device_train_batch_size 2
307+
--per_device_eval_batch_size 1
308+
--max_train_steps 2
309+
--train_val_split 0.1
310+
--seed 42
304311
--output_dir {tmp_dir}
305-
--num_warmup_steps=8
306-
--learning_rate=3e-3
307-
--per_device_train_batch_size=2
308-
--per_device_eval_batch_size=1
309-
--checkpointing_steps epoch
310312
--with_tracking
311-
--seed 42
313+
--checkpointing_steps 1
312314
""".split()
313315

316+
if is_cuda_and_apex_available():
317+
testargs.append("--fp16")
318+
314319
_ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
315320
result = get_results(tmp_dir)
316-
self.assertGreaterEqual(result["eval_accuracy"], 0.50)
317-
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
321+
# The base model scores a 25%
322+
self.assertGreaterEqual(result["eval_accuracy"], 0.625)
323+
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "step_1")))
318324
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))

0 commit comments

Comments
 (0)