Skip to content

Commit a78d0d8

Browse files
authored
Add checkpoints used for preemption. (#3789)
1 parent c2ab0c5 commit a78d0d8

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

references/detection/train.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,19 @@ def main(args):
188188
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
189189
lr_scheduler.step()
190190
if args.output_dir:
191-
utils.save_on_master({
191+
checkpoint = {
192192
'model': model_without_ddp.state_dict(),
193193
'optimizer': optimizer.state_dict(),
194194
'lr_scheduler': lr_scheduler.state_dict(),
195195
'args': args,
196-
'epoch': epoch},
196+
'epoch': epoch
197+
}
198+
utils.save_on_master(
199+
checkpoint,
197200
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
201+
utils.save_on_master(
202+
checkpoint,
203+
os.path.join(args.output_dir, 'checkpoint.pth'))
198204

199205
# evaluate after every epoch
200206
evaluate(model, data_loader_test, device=device)

references/segmentation/train.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,19 @@ def main(args):
157157
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
158158
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
159159
print(confmat)
160+
checkpoint = {
161+
'model': model_without_ddp.state_dict(),
162+
'optimizer': optimizer.state_dict(),
163+
'lr_scheduler': lr_scheduler.state_dict(),
164+
'epoch': epoch,
165+
'args': args
166+
}
160167
utils.save_on_master(
161-
{
162-
'model': model_without_ddp.state_dict(),
163-
'optimizer': optimizer.state_dict(),
164-
'lr_scheduler': lr_scheduler.state_dict(),
165-
'epoch': epoch,
166-
'args': args
167-
},
168+
checkpoint,
168169
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
170+
utils.save_on_master(
171+
checkpoint,
172+
os.path.join(args.output_dir, 'checkpoint.pth'))
169173

170174
total_time = time.time() - start_time
171175
total_time_str = str(datetime.timedelta(seconds=int(total_time)))

0 commit comments

Comments
 (0)