Skip to content

Commit f13f1f8

Browse files
authored
Test checkpointing (#11682)
* Add test and see where CI is unhappy * Load with strict=False
1 parent d9b2862 commit f13f1f8

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/transformers/trainer.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1059,7 +1059,18 @@ def train(
10591059
# We load the model state dict on the CPU to avoid an OOM error.
10601060
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
10611061
# If the model is on the GPU, it still works!
1062-
self.model.load_state_dict(state_dict)
1062+
load_result = self.model.load_state_dict(state_dict, strict=False)
1063+
if len(load_result.missing_keys) != 0:
1064+
if load_result.missing_keys == self.model._keys_to_ignore_on_save:
1065+
self.model.tie_weights()
1066+
else:
1067+
logger.warn(
1068+
f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}."
1069+
)
1070+
if len(load_result.unexpected_keys) != 0:
1071+
logger.warn(
1072+
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
1073+
)
10631074

10641075
# If model was re-initialized, put it on the right device and update self.model_wrapped
10651076
if model_reloaded:

tests/test_modeling_common.py

+7
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,13 @@ def test_save_load__keys_to_ignore_on_save(self):
177177
for k in _keys_to_ignore_on_save:
178178
self.assertNotIn(k, state_dict_saved)
179179

180+
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
181+
load_result = model.load_state_dict(state_dict_saved, strict=False)
182+
self.assertTrue(
183+
len(load_result.missing_keys) == 0 or load_result.missing_keys == model._keys_to_ignore_on_save
184+
)
185+
self.assertTrue(len(load_result.unexpected_keys) == 0)
186+
180187
def _mock_init_weights(self, module):
181188
if hasattr(module, "weight") and module.weight is not None:
182189
module.weight.data.fill_(3)

0 commit comments

Comments
 (0)