File tree 2 files changed +19
-1
lines changed
2 files changed +19
-1
lines changed Original file line number Diff line number Diff line change @@ -1059,7 +1059,18 @@ def train(
1059
1059
# We load the model state dict on the CPU to avoid an OOM error.
1060
1060
state_dict = torch .load (os .path .join (resume_from_checkpoint , WEIGHTS_NAME ), map_location = "cpu" )
1061
1061
# If the model is on the GPU, it still works!
1062
- self .model .load_state_dict (state_dict )
1062
+ load_result = self .model .load_state_dict (state_dict , strict = False )
1063
+ if len (load_result .missing_keys ) != 0 :
1064
+ if load_result .missing_keys == self .model ._keys_to_ignore_on_save :
1065
+ self .model .tie_weights ()
1066
+ else :
1067
+ logger .warn (
1068
+ f"There were missing keys in the checkpoint model loaded: { load_result .missing_keys } ."
1069
+ )
1070
+ if len (load_result .unexpected_keys ) != 0 :
1071
+ logger .warn (
1072
+ f"There were unexpected keys in the checkpoint model loaded: { load_result .unexpected_keys } ."
1073
+ )
1063
1074
1064
1075
# If model was re-initialized, put it on the right device and update self.model_wrapped
1065
1076
if model_reloaded :
Original file line number Diff line number Diff line change @@ -177,6 +177,13 @@ def test_save_load__keys_to_ignore_on_save(self):
177
177
for k in _keys_to_ignore_on_save :
178
178
self .assertNotIn (k , state_dict_saved )
179
179
180
+ # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
181
+ load_result = model .load_state_dict (state_dict_saved , strict = False )
182
+ self .assertTrue (
183
+ len (load_result .missing_keys ) == 0 or load_result .missing_keys == model ._keys_to_ignore_on_save
184
+ )
185
+ self .assertTrue (len (load_result .unexpected_keys ) == 0 )
186
+
180
187
def _mock_init_weights (self , module ):
181
188
if hasattr (module , "weight" ) and module .weight is not None :
182
189
module .weight .data .fill_ (3 )
You can’t perform that action at this time.
0 commit comments