60
60
else :
61
61
config .SAVE = '{}/eval-{}' .format (config .SAVE , time .strftime ('%Y%m%d-%H-%M-%S' ))
62
62
63
- config .freeze ()
64
-
65
63
if not os .path .exists (config .SAVE ):
66
64
os .makedirs (config .SAVE , exist_ok = True )
67
65
@@ -147,12 +145,14 @@ def validate(dataloader, model, criterion, total_batch, debug_steps=100):
147
145
debug_steps: int, num of iters to log info
148
146
Returns:
149
147
val_loss_meter.avg
150
- val_acc_meter.avg
148
+ val_acc1_meter.avg
149
+ val_acc5_meter.avg
151
150
val_time
152
151
"""
153
152
model .eval ()
154
153
val_loss_meter = AverageMeter ()
155
- val_acc_meter = AverageMeter ()
154
+ val_acc1_meter = AverageMeter ()
155
+ val_acc5_meter = AverageMeter ()
156
156
time_st = time .time ()
157
157
158
158
with paddle .no_grad ():
@@ -164,27 +164,32 @@ def validate(dataloader, model, criterion, total_batch, debug_steps=100):
164
164
loss = criterion (output , label )
165
165
166
166
pred = F .softmax (output )
167
- acc = paddle .metric .accuracy (pred , label .unsqueeze (1 ))
167
+ acc1 = paddle .metric .accuracy (pred , label .unsqueeze (1 ))
168
+ acc5 = paddle .metric .accuracy (pred , label .unsqueeze (1 ), k = 5 )
168
169
169
170
dist .all_reduce (loss )
170
- dist .all_reduce (acc )
171
+ dist .all_reduce (acc1 )
172
+ dist .all_reduce (acc5 )
171
173
loss = loss / dist .get_world_size ()
172
- acc = acc / dist .get_world_size ()
174
+ acc1 = acc1 / dist .get_world_size ()
175
+ acc5 = acc5 / dist .get_world_size ()
173
176
174
177
batch_size = paddle .to_tensor (image .shape [0 ])
175
178
dist .all_reduce (batch_size )
176
179
177
180
val_loss_meter .update (loss .numpy ()[0 ], batch_size .numpy ()[0 ])
178
- val_acc_meter .update (acc .numpy ()[0 ], batch_size .numpy ()[0 ])
181
+ val_acc1_meter .update (acc1 .numpy ()[0 ], batch_size .numpy ()[0 ])
182
+ val_acc5_meter .update (acc5 .numpy ()[0 ], batch_size .numpy ()[0 ])
179
183
180
184
if batch_id % debug_steps == 0 :
181
185
logger .info (
182
186
f"Val Step[{ batch_id :04d} /{ total_batch :04d} ], " +
183
187
f"Avg Loss: { val_loss_meter .avg :.4f} , " +
184
- f"Avg Acc: { val_acc_meter .avg :.4f} " )
188
+ f"Avg Acc@1: { val_acc1_meter .avg :.4f} , " +
189
+ f"Avg Acc@5: { val_acc5_meter .avg :.4f} " )
185
190
186
191
val_time = time .time () - time_st
187
- return val_loss_meter .avg , val_acc_meter .avg , val_time
192
+ return val_loss_meter .avg , val_acc1_meter . avg , val_acc5_meter .avg , val_time
188
193
189
194
190
195
def main_worker (* args ):
@@ -288,13 +293,15 @@ def main_worker(*args):
288
293
# 6. Validation
289
294
if config .EVAL :
290
295
logger .info ('----- Start Validating' )
291
- val_loss , val_acc , val_time = validate (dataloader = dataloader_val ,
292
- model = model ,
293
- criterion = criterion ,
294
- total_batch = total_batch_val ,
295
- debug_steps = config .REPORT_FREQ )
296
+ val_loss , val_acc1 , val_acc5 , val_time = validate (
297
+ dataloader = dataloader_val ,
298
+ model = model ,
299
+ criterion = criterion ,
300
+ total_batch = total_batch_val ,
301
+ debug_steps = config .REPORT_FREQ )
296
302
logger .info (f"Validation Loss: { val_loss :.4f} , " +
297
- f"Validation Acc: { val_acc :.4f} , " +
303
+ f"Validation Acc@1: { val_acc1 :.4f} , " +
304
+ f"Validation Acc@5: { val_acc5 :.4f} , " +
298
305
f"time: { val_time :.2f} " )
299
306
return
300
307
@@ -320,14 +327,16 @@ def main_worker(*args):
320
327
# validation
321
328
if epoch % config .VALIDATE_FREQ == 0 or epoch == config .TRAIN .NUM_EPOCHS :
322
329
logger .info (f'----- Validation after Epoch: { epoch } ' )
323
- val_loss , val_acc , val_time = validate (dataloader = dataloader_val ,
324
- model = model ,
325
- criterion = criterion ,
326
- total_batch = total_batch_val ,
327
- debug_steps = config .REPORT_FREQ )
330
+ val_loss , val_acc1 , val_acc5 , val_time = validate (
331
+ dataloader = dataloader_val ,
332
+ model = model ,
333
+ criterion = criterion ,
334
+ total_batch = total_batch_val ,
335
+ debug_steps = config .REPORT_FREQ )
328
336
logger .info (f"----- Epoch[{ epoch :03d} /{ config .TRAIN .NUM_EPOCHS :03d} ], " +
329
337
f"Validation Loss: { val_loss :.4f} , " +
330
- f"Validation Acc: { val_acc :.4f} , " +
338
+ f"Validation Acc@1: { val_acc1 :.4f} , " +
339
+ f"Validation Acc@5: { val_acc5 :.4f} , " +
331
340
f"time: { val_time :.2f} " )
332
341
# model save
333
342
if local_rank == 0 :
@@ -343,6 +352,7 @@ def main_worker(*args):
343
352
def main():
    """Build the datasets, settle the process count and launch the workers.

    The train and val datasets are created with ``get_dataset`` and passed
    as arguments to ``main_worker``, which is started through ``dist.spawn``
    with ``config.NGPUS`` processes.  A configured value of ``-1`` is
    replaced by the number of places returned by
    ``paddle.static.cuda_places()`` before spawning.
    """
    train_set = get_dataset(config, mode='train')
    val_set = get_dataset(config, mode='val')

    # -1 means "take the count from paddle.static.cuda_places()"; any other
    # value is written back unchanged.
    ngpus = config.NGPUS
    if ngpus == -1:
        ngpus = len(paddle.static.cuda_places())
    config.NGPUS = ngpus

    dist.spawn(main_worker, args=(train_set, val_set, ), nprocs=config.NGPUS)
347
357
348
358
0 commit comments