test_training_args.py
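"""Tests for `TrainingArguments`: the `output_dir` default and on-demand creation
of the output directory, and validation of `torch_empty_cache_steps`."""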
import os
import tempfile
import unittest

from transformers import TrainingArguments


class TestTrainingArguments(unittest.TestCase):
    def test_default_output_dir(self):
        """Test that output_dir defaults to 'trainer_output' when not specified."""
        args = TrainingArguments(output_dir=None)
        self.assertEqual(args.output_dir, "trainer_output")

    def test_custom_output_dir(self):
        """Test that output_dir is respected when specified."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(output_dir=tmp_dir)
            self.assertEqual(args.output_dir, tmp_dir)

    def test_output_dir_creation(self):
        """Test that output_dir is created only when needed."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = os.path.join(tmp_dir, "test_output")

            # Directory should not exist before creating args
            self.assertFalse(os.path.exists(output_dir))

            # Create args with save_strategy="no" - should not create directory
            args = TrainingArguments(
                output_dir=output_dir,
                do_train=True,
                save_strategy="no",
                report_to=None,
            )
            self.assertFalse(os.path.exists(output_dir))

            # Now set save_strategy="steps" - should create directory when needed
            args.save_strategy = "steps"
            args.save_steps = 1
            self.assertFalse(os.path.exists(output_dir))  # Still shouldn't exist
            # Directory should be created when actually needed (e.g. in Trainer)

    def test_torch_empty_cache_steps_requirements(self):
        """Test that torch_empty_cache_steps is a positive integer or None."""
        # None is acceptable (feature is disabled):
        args = TrainingArguments(torch_empty_cache_steps=None)
        self.assertIsNone(args.torch_empty_cache_steps)

        # non-int is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=1.0)
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps="none")

        # negative int is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=-1)

        # zero is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=0)

        # positive int is acceptable:
        args = TrainingArguments(torch_empty_cache_steps=1)
        self.assertEqual(args.torch_empty_cache_steps, 1)
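

# Optional entry point so the module can also be run directly with
# `python test_training_args.py`; `pytest` or `python -m unittest` work as well.
if __name__ == "__main__":
    unittest.main()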