Skip to content

Commit 6b480ad

Browse files
authored
fix bugs for ViT
1 parent 9b15251 commit 6b480ad

9 files changed

+92
-54
lines changed

image_classification/ViT/config.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
_C.DATA.DATA_PATH = '/dataset/imagenet/' # path to dataset
3434
_C.DATA.DATASET = 'imagenet2012' # dataset name
3535
_C.DATA.IMAGE_SIZE = 224 # input image size: 224 for pretrain, 384 for finetune
36-
_C.DATA.CROP_PCT = 1.0 # input image scale ratio, scale is applied before centercrop in eval mode
37-
_C.DATA.NUM_WORKERS = 4 # number of data loading threads
36+
_C.DATA.CROP_PCT = 0.875 # input image scale ratio, scale is applied before centercrop in eval mode
37+
_C.DATA.NUM_WORKERS = 2 # number of data loading threads
3838

3939
# model settings
4040
_C.MODEL = CN()
@@ -62,10 +62,10 @@
6262
_C.TRAIN.LAST_EPOCH = 0
6363
_C.TRAIN.NUM_EPOCHS = 300
6464
_C.TRAIN.WARMUP_EPOCHS = 3 #34 # ~ 10k steps for 4096 batch size
65-
_C.TRAIN.WEIGHT_DECAY = 0.01 #0.3 # 0.0 for finetune
65+
_C.TRAIN.WEIGHT_DECAY = 0.05 #0.3 # 0.0 for finetune
6666
_C.TRAIN.BASE_LR = 0.001 #0.003 for pretrain # 0.03 for finetune
6767
_C.TRAIN.WARMUP_START_LR = 1e-6 #0.0
68-
_C.TRAIN.END_LR = 1e-5
68+
_C.TRAIN.END_LR = 5e-4
6969
_C.TRAIN.GRAD_CLIP = 1.0
7070
_C.TRAIN.ACCUM_ITER = 2 #1
7171

@@ -84,13 +84,13 @@
8484
# misc
8585
_C.SAVE = "./output"
8686
_C.TAG = "default"
87-
_C.SAVE_FREQ = 20 # freq to save chpt
88-
_C.REPORT_FREQ = 50 # freq to logging info
89-
_C.VALIDATE_FREQ = 20 # freq to do validation
87+
_C.SAVE_FREQ = 10 # freq to save checkpoint
88+
_C.REPORT_FREQ = 100 # freq to logging info
89+
_C.VALIDATE_FREQ = 100 # freq to do validation
9090
_C.SEED = 0
9191
_C.EVAL = False # run evaluation only
9292
_C.LOCAL_RANK = 0
93-
_C.NGPUS = 1
93+
_C.NGPUS = -1
9494

9595

9696
def _update_config_from_file(config, cfg_file):

image_classification/ViT/main_multi_gpu.py

+32-22
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@
6060
else:
6161
config.SAVE = '{}/eval-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S'))
6262

63-
config.freeze()
64-
6563
if not os.path.exists(config.SAVE):
6664
os.makedirs(config.SAVE, exist_ok=True)
6765

@@ -147,12 +145,14 @@ def validate(dataloader, model, criterion, total_batch, debug_steps=100):
147145
debug_steps: int, num of iters to log info
148146
Returns:
149147
val_loss_meter.avg
150-
val_acc_meter.avg
148+
val_acc1_meter.avg
149+
val_acc5_meter.avg
151150
val_time
152151
"""
153152
model.eval()
154153
val_loss_meter = AverageMeter()
155-
val_acc_meter = AverageMeter()
154+
val_acc1_meter = AverageMeter()
155+
val_acc5_meter = AverageMeter()
156156
time_st = time.time()
157157

158158
with paddle.no_grad():
@@ -164,27 +164,32 @@ def validate(dataloader, model, criterion, total_batch, debug_steps=100):
164164
loss = criterion(output, label)
165165

166166
pred = F.softmax(output)
167-
acc = paddle.metric.accuracy(pred, label.unsqueeze(1))
167+
acc1 = paddle.metric.accuracy(pred, label.unsqueeze(1))
168+
acc5 = paddle.metric.accuracy(pred, label.unsqueeze(1), k=5)
168169

169170
dist.all_reduce(loss)
170-
dist.all_reduce(acc)
171+
dist.all_reduce(acc1)
172+
dist.all_reduce(acc5)
171173
loss = loss / dist.get_world_size()
172-
acc = acc / dist.get_world_size()
174+
acc1 = acc1 / dist.get_world_size()
175+
acc5 = acc5 / dist.get_world_size()
173176

174177
batch_size = paddle.to_tensor(image.shape[0])
175178
dist.all_reduce(batch_size)
176179

177180
val_loss_meter.update(loss.numpy()[0], batch_size.numpy()[0])
178-
val_acc_meter.update(acc.numpy()[0], batch_size.numpy()[0])
181+
val_acc1_meter.update(acc1.numpy()[0], batch_size.numpy()[0])
182+
val_acc5_meter.update(acc5.numpy()[0], batch_size.numpy()[0])
179183

180184
if batch_id % debug_steps == 0:
181185
logger.info(
182186
f"Val Step[{batch_id:04d}/{total_batch:04d}], " +
183187
f"Avg Loss: {val_loss_meter.avg:.4f}, " +
184-
f"Avg Acc: {val_acc_meter.avg:.4f}")
188+
f"Avg Acc@1: {val_acc1_meter.avg:.4f}, "+
189+
f"Avg Acc@5: {val_acc5_meter.avg:.4f}")
185190

186191
val_time = time.time() - time_st
187-
return val_loss_meter.avg, val_acc_meter.avg, val_time
192+
return val_loss_meter.avg, val_acc1_meter.avg, val_acc5_meter.avg, val_time
188193

189194

190195
def main_worker(*args):
@@ -288,13 +293,15 @@ def main_worker(*args):
288293
# 6. Validation
289294
if config.EVAL:
290295
logger.info('----- Start Validating')
291-
val_loss, val_acc, val_time = validate(dataloader=dataloader_val,
292-
model=model,
293-
criterion=criterion,
294-
total_batch=total_batch_val,
295-
debug_steps=config.REPORT_FREQ)
296+
val_loss, val_acc1, val_acc5, val_time = validate(
297+
dataloader=dataloader_val,
298+
model=model,
299+
criterion=criterion,
300+
total_batch=total_batch_val,
301+
debug_steps=config.REPORT_FREQ)
296302
logger.info(f"Validation Loss: {val_loss:.4f}, " +
297-
f"Validation Acc: {val_acc:.4f}, " +
303+
f"Validation Acc@1: {val_acc1:.4f}, " +
304+
f"Validation Acc@5: {val_acc5:.4f}, " +
298305
f"time: {val_time:.2f}")
299306
return
300307

@@ -320,14 +327,16 @@ def main_worker(*args):
320327
# validation
321328
if epoch % config.VALIDATE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
322329
logger.info(f'----- Validation after Epoch: {epoch}')
323-
val_loss, val_acc, val_time = validate(dataloader=dataloader_val,
324-
model=model,
325-
criterion=criterion,
326-
total_batch=total_batch_val,
327-
debug_steps=config.REPORT_FREQ)
330+
val_loss, val_acc1, val_acc5, val_time = validate(
331+
dataloader=dataloader_val,
332+
model=model,
333+
criterion=criterion,
334+
total_batch=total_batch_val,
335+
debug_steps=config.REPORT_FREQ)
328336
logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
329337
f"Validation Loss: {val_loss:.4f}, " +
330-
f"Validation Acc: {val_acc:.4f}, " +
338+
f"Validation Acc@1: {val_acc1:.4f}, " +
339+
f"Validation Acc@5: {val_acc5:.4f}, " +
331340
f"time: {val_time:.2f}")
332341
# model save
333342
if local_rank == 0:
@@ -343,6 +352,7 @@ def main_worker(*args):
343352
def main():
344353
dataset_train = get_dataset(config, mode='train')
345354
dataset_val = get_dataset(config, mode='val')
355+
config.NGPUS = len(paddle.static.cuda_places()) if config.NGPUS == -1 else config.NGPUS
346356
dist.spawn(main_worker, args=(dataset_train, dataset_val, ), nprocs=config.NGPUS)
347357

348358

image_classification/ViT/main_single_gpu.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -145,12 +144,14 @@ def validate(dataloader, model, criterion, total_batch, debug_steps=100):
145144
debug_steps: int, num of iters to log info
146145
Returns:
147146
val_loss_meter.avg
148-
val_acc_meter.avg
147+
val_acc1_meter.avg
148+
val_acc5_meter.avg
149149
val_time
150150
"""
151151
model.eval()
152152
val_loss_meter = AverageMeter()
153-
val_acc_meter = AverageMeter()
153+
val_acc1_meter = AverageMeter()
154+
val_acc5_meter = AverageMeter()
154155
time_st = time.time()
155156

156157
with paddle.no_grad():
@@ -162,20 +163,23 @@ def validate(dataloader, model, criterion, total_batch, debug_steps=100):
162163
loss = criterion(output, label)
163164

164165
pred = F.softmax(output)
165-
acc = paddle.metric.accuracy(pred, label.unsqueeze(1))
166+
acc1 = paddle.metric.accuracy(pred, label.unsqueeze(1))
167+
acc5 = paddle.metric.accuracy(pred, label.unsqueeze(1), k=5)
166168

167169
batch_size = image.shape[0]
168170
val_loss_meter.update(loss.numpy()[0], batch_size)
169-
val_acc_meter.update(acc.numpy()[0], batch_size)
171+
val_acc1_meter.update(acc1.numpy()[0], batch_size)
172+
val_acc5_meter.update(acc5.numpy()[0], batch_size)
170173

171174
if batch_id % debug_steps == 0:
172175
logger.info(
173176
f"Val Step[{batch_id:04d}/{total_batch:04d}], " +
174177
f"Avg Loss: {val_loss_meter.avg:.4f}, " +
175-
f"Avg Acc: {val_acc_meter.avg:.4f}")
178+
f"Avg Acc@1: {val_acc1_meter.avg:.4f}, " +
179+
f"Avg Acc@5: {val_acc5_meter.avg:.4f}")
176180

177181
val_time = time.time() - time_st
178-
return val_loss_meter.avg, val_acc_meter.avg, val_time
182+
return val_loss_meter.avg, val_acc1_meter.avg, val_acc5_meter.avg, val_time
179183

180184

181185
def main():
@@ -257,13 +261,15 @@ def main():
257261
# 7. Validation
258262
if config.EVAL:
259263
logger.info('----- Start Validating')
260-
val_loss, val_acc, val_time = validate(dataloader=dataloader_val,
261-
model=model,
262-
criterion=criterion,
263-
total_batch=len(dataloader_val),
264-
debug_steps=config.REPORT_FREQ)
264+
val_loss, val_acc1, val_acc5, val_time = validate(
265+
dataloader=dataloader_val,
266+
model=model,
267+
criterion=criterion,
268+
total_batch=len(dataloader_val),
269+
debug_steps=config.REPORT_FREQ)
265270
logger.info(f"Validation Loss: {val_loss:.4f}, " +
266-
f"Validation Acc: {val_acc:.4f}, " +
271+
f"Validation Acc@1: {val_acc1:.4f}, " +
272+
f"Validation Acc@5: {val_acc5:.4f}, " +
267273
f"time: {val_time:.2f}")
268274
return
269275
# 8. Start training and validation
@@ -288,14 +294,16 @@ def main():
288294
# validation
289295
if epoch % config.VALIDATE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
290296
logger.info(f'----- Validation after Epoch: {epoch}')
291-
val_loss, val_acc, val_time = validate(dataloader=dataloader_val,
292-
model=model,
293-
criterion=criterion,
294-
total_batch=len(dataloader_val),
295-
debug_steps=config.REPORT_FREQ)
297+
val_loss, val_acc1, val_acc5, val_time = validate(
298+
dataloader=dataloader_val,
299+
model=model,
300+
criterion=criterion,
301+
total_batch=len(dataloader_val),
302+
debug_steps=config.REPORT_FREQ)
296303
logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
297304
f"Validation Loss: {val_loss:.4f}, " +
298-
f"Validation Acc: {val_acc:.4f}, " +
305+
f"Validation Acc@1: {val_acc1:.4f}, " +
306+
f"Validation Acc@5: {val_acc5:.4f}, " +
299307
f"time: {val_time:.2f}")
300308
# model save
301309
if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
CUDA_VISIBLE_DEVICES=0,1,2,3 \
2+
python main_multi_gpu.py \
3+
-cfg='./configs/vit_base_patch16_384.yaml' \
4+
-dataset='imagenet2012' \
5+
-batch_size=4 \
6+
-data_path='/dataset/imagenet' \
7+
-eval \
8+
-pretrained='./vit_base_patch16_384'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
CUDA_VISIBLE_DEVICES=0,1,2,3 \
2+
python main_multi_gpu.py \
3+
-cfg='./configs/vit_large_patch16_224.yaml' \
4+
-dataset='imagenet2012' \
5+
-batch_size=4 \
6+
-data_path='/dataset/imagenet' \
7+
-eval \
8+
-pretrained='./vit_large_patch16_224'

image_classification/ViT/run_eval_multi.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 \
22
python main_multi_gpu.py \
33
-cfg='./configs/vit_base_patch16_224.yaml' \
44
-dataset='imagenet2012' \
5-
-batch_size=512 \
5+
-batch_size=8 \
66
-data_path='/dataset/imagenet' \
77
-eval \
88
-pretrained='./vit_base_patch16_224' \
+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
CUDA_VISIBLE_DEVICES=4,5,6,7 \
1+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
22
python main_multi_gpu.py \
33
-cfg='./configs/vit_base_patch16_224.yaml' \
44
-dataset='imagenet2012' \
5-
-batch_size=4 \
5+
-batch_size=32 \
66
-data_path='/dataset/imagenet' \
7-
-ngpus=4
7+
-ngpus=8
+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
#init
1+
# init

image_classification/ViT/transformer.py

+4
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,7 @@ def forward(self, x):
354354
x, self_attn = self.transformer(x)
355355
logits = self.classifier(x[:, 0]) # take only cls_token as classifier
356356
return logits, self_attn
357+
358+
def flops(self):
359+
flops = 0
360+
flops += self.transformer.flops()
return flops

0 commit comments

Comments (0)