@@ -92,6 +92,13 @@ def train(args, train_dataset, model, tokenizer):
     ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
+
+    # Check if saved optimizer or scheduler states exist
+    if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
+        # Load in optimizer and scheduler states
+        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
+        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
+
     if args.fp16:
         try:
             from apex import amp
@@ -120,13 +127,32 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("  Total optimization steps = %d", t_total)

     global_step = 0
+    epochs_trained = 0
+    steps_trained_in_current_epoch = 0
+    # Check if continuing training from a checkpoint
+    if os.path.exists(args.model_name_or_path):
+        # set global_step to global_step of last saved checkpoint from model path
+        global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
+        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+        logger.info("  Continuing training from epoch %d", epochs_trained)
+        logger.info("  Continuing training from global step %d", global_step)
+        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
+            # Skip past any already trained steps if resuming training
+            if steps_trained_in_current_epoch > 0:
+                steps_trained_in_current_epoch -= 1
+                continue
+
             model.train()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids':      batch[0],
@@ -177,9 +203,15 @@ def train(args, train_dataset, model, tokenizer):
                         os.makedirs(output_dir)
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)

+                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
+                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
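
Note: the resume bookkeeping added in the second hunk boils down to parsing the optimizer step count out of the checkpoint-<global_step> directory name and splitting it into whole epochs already trained plus leftover steps in the current epoch. A minimal standalone sketch of that arithmetic follows; the helper name resume_counters and the example numbers are illustrative and not part of this patch.

# Standalone sketch (not part of the patch) of the resume arithmetic used above.
def resume_counters(model_name_or_path, num_batches_per_epoch, gradient_accumulation_steps):
    # Parse global_step from a 'checkpoint-<global_step>' directory name,
    # tolerating a trailing slash, exactly as the patched train() does.
    global_step = int(model_name_or_path.split('-')[-1].split('/')[0])
    # Optimizer steps per epoch = dataloader batches per epoch // accumulation steps.
    steps_per_epoch = num_batches_per_epoch // gradient_accumulation_steps
    epochs_trained = global_step // steps_per_epoch
    steps_trained_in_current_epoch = global_step % steps_per_epoch
    return global_step, epochs_trained, steps_trained_in_current_epoch

# Example (made-up numbers): a run saved at optimizer step 1500, with 400 batches
# per epoch and gradient accumulation of 2, i.e. 200 optimizer steps per epoch.
print(resume_counters('output/checkpoint-1500', 400, 2))  # (1500, 7, 100)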