
Commit 6aa9194

Update run_xnli to save optimizer and scheduler states, then resume training from a checkpoint
1 parent 89896fe commit 6aa9194
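
In short, this change writes optimizer.pt and scheduler.pt alongside each saved model checkpoint and reloads them when --model_name_or_path points at such a checkpoint directory, so resumed training keeps the Adam moment estimates and the position in the warmup/decay schedule. Below is a minimal, self-contained sketch of that save/restore pattern; the stand-in model, paths, and hyperparameter values are placeholders, not the script's actual ones.

import os

import torch
from torch import nn
from transformers import AdamW, get_linear_schedule_with_warmup

# Stand-in model; run_xnli.py uses a pretrained sequence-classification model instead.
model = nn.Linear(4, 2)
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=1000)

checkpoint_dir = "/tmp/xnli_out/checkpoint-500"  # hypothetical checkpoint directory
os.makedirs(checkpoint_dir, exist_ok=True)

# At save time: store optimizer and scheduler states next to the model checkpoint.
torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))

# On a later run: restore both states if they exist in the checkpoint directory.
if os.path.isfile(os.path.join(checkpoint_dir, "optimizer.pt")) and \
        os.path.isfile(os.path.join(checkpoint_dir, "scheduler.pt")):
    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
    scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))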

File tree

1 file changed: 33 additions, 1 deletion

examples/run_xnli.py

Lines changed: 33 additions & 1 deletion
@@ -92,6 +92,13 @@ def train(args, train_dataset, model, tokenizer):
         ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
+
+    # Check if saved optimizer or scheduler states exist
+    if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
+        # Load in optimizer and scheduler states
+        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
+        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
+
     if args.fp16:
         try:
             from apex import amp
@@ -120,13 +127,32 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("  Total optimization steps = %d", t_total)
 
     global_step = 0
+    epochs_trained = 0
+    steps_trained_in_current_epoch = 0
+    # Check if continuing training from a checkpoint
+    if os.path.exists(args.model_name_or_path):
+        # set global_step to global_step of last saved checkpoint from model path
+        global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
+        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+        logger.info("  Continuing training from epoch %d", epochs_trained)
+        logger.info("  Continuing training from global step %d", global_step)
+        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
+            # Skip past any already trained steps if resuming training
+            if steps_trained_in_current_epoch > 0:
+                steps_trained_in_current_epoch -= 1
+                continue
+
             model.train()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids': batch[0],
@@ -177,9 +203,15 @@ def train(args, train_dataset, model, tokenizer):
                         os.makedirs(output_dir)
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
+                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
+                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
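
For reference, the resume logic in the second hunk recovers the global step from the checkpoint directory name and converts it into whole epochs already trained plus steps to skip in the current epoch. A small worked sketch with hypothetical numbers (the path and steps_per_epoch value are made up for illustration):

model_name_or_path = "/tmp/xnli_out/checkpoint-750"  # hypothetical resume path
steps_per_epoch = 500  # stands in for len(train_dataloader) // args.gradient_accumulation_steps

global_step = int(model_name_or_path.split('-')[-1].split('/')[0])   # 750
epochs_trained = global_step // steps_per_epoch                      # 1 full epoch already done
steps_trained_in_current_epoch = global_step % steps_per_epoch       # skip the first 250 batches

With these values the outer loop starts at epoch 1 and the inner loop silently consumes the first 250 batches before any forward pass, which is exactly what the added skip block does.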
