17
17
import torchvision .transforms as transforms
18
18
from torchvision .datasets import ImageFolder
19
19
20
+ from models .SmoothLabelCriterion import SmoothLabelCritierion
21
+ from warmup_scheduler import GradualWarmupScheduler
20
22
from utils import util
21
23
from utils import metrics
22
24
from models .resnet import res_net
23
25
from models .densenet import dense_net
24
26
25
27
26
28
def flops_params():
    """Print GFlops and parameter size (MB) for each benchmark model."""
    # Table-driven dispatch: model display name -> constructor.
    model_builders = {
        'densenet_121': dense_net.densenet121,
        'resnet-34': res_net.resnet34_v2,
    }
    for model_name, builder in model_builders.items():
        net = builder()
        gflops, params_size = metrics.compute_num_flops(net)
        print('{}: {:.3f} GFlops - {:.3f} MB'.format(model_name, gflops, params_size))
34
36
35
37
36
38
def load_data(data_root_dir):
    """Build train/test DataLoaders from ImageFolder datasets.

    Expects ``<data_root_dir>/train_imgs`` and ``<data_root_dir>/test_imgs``
    directory trees in torchvision ``ImageFolder`` layout.

    Args:
        data_root_dir: root directory containing the two image folders.

    Returns:
        (data_loaders, data_sizes): two dicts keyed by 'train'/'test' with
        the DataLoader and the dataset length for each phase.
    """
    # Channel-wise mean/std shared by both pipelines; (x - 0.5) / 0.5
    # maps [0, 1] tensors to [-1, 1].
    norm_mean = (0.5, 0.5, 0.5)
    norm_std = (0.5, 0.5, 0.5)

    # Training augmentation: random crop + flip + color jitter + random erasing.
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    # Test-time Ten Crop augmentation: TenCrop yields 10 PIL crops per image,
    # so ToTensor/Normalize must be applied per-crop inside Lambda wrappers;
    # the resulting batch has an extra "crops" dimension.
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.TenCrop(224),
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.Normalize(norm_mean, norm_std)(crop) for crop in crops]))
    ])

    data_loaders = {}
    data_sizes = {}
    for name in ['train', 'test']:
        data_dir = os.path.join(data_root_dir, name + '_imgs')

        if name == 'train':
            data_set = ImageFolder(data_dir, transform=train_transform)
            data_loader = DataLoader(data_set, batch_size=96, shuffle=True, num_workers=8)
        else:
            data_set = ImageFolder(data_dir, transform=test_transform)
            # Smaller batch: TenCrop expands every sample into 10 crops.
            # Fix: shuffle=False — shuffling the evaluation split only makes
            # iteration order nondeterministic; accuracy is unaffected.
            data_loader = DataLoader(data_set, batch_size=48, shuffle=False, num_workers=8)
        data_loaders[name] = data_loader
        data_sizes[name] = len(data_set)
    return data_loaders, data_sizes
@@ -69,7 +85,7 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
69
85
top5_acc_dict = {'train' : [], 'test' : []}
70
86
for epoch in range (num_epochs ):
71
87
72
- print ('{} - Epoch {}/{}' .format (model_name , epoch , num_epochs - 1 ))
88
+ print ('{} - Epoch {}/{}' .format (model_name , epoch + 1 , num_epochs ))
73
89
print ('-' * 10 )
74
90
75
91
# Each epoch has a training and test phase
@@ -95,7 +111,12 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
95
111
# forward
96
112
# track history if only in train
97
113
with torch .set_grad_enabled (phase == 'train' ):
98
- outputs = model (inputs )
114
+ if phase == 'test' :
115
+ N , N_crops , C , H , W = inputs .size ()
116
+ result = model (inputs .view (- 1 , C , H , W )) # fuse batch size and ncrops
117
+ outputs = result .view (N , N_crops , - 1 ).mean (1 ) # avg over crops
118
+ else :
119
+ outputs = model (inputs )
99
120
# print(outputs.shape)
100
121
# _, preds = torch.max(outputs, 1)
101
122
loss = criterion (outputs , labels )
@@ -115,6 +136,7 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
115
136
# running_corrects += torch.sum(preds == labels.data)
116
137
if phase == 'train' :
117
138
lr_scheduler .step ()
139
+ print ('lr: {}' .format (optimizer .param_groups [0 ]['lr' ]))
118
140
119
141
epoch_loss = running_loss / data_sizes [phase ]
120
142
epoch_top1_acc = running_top1_acc / len (data_loaders [phase ])
@@ -134,9 +156,10 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
134
156
if phase == 'test' and epoch_top5_acc > best_top5_acc :
135
157
best_top5_acc = epoch_top5_acc
136
158
137
- # 每训练一轮就保存
138
- # util.save_model(model.cpu(), '../data/models/%s_%d.pth' % (model_name, epoch))
139
- # model = model.to(device)
159
+ # 每训练10轮保存一次
160
+ if (epoch + 1 ) % 10 == 0 :
161
+ util .save_model (model .cpu (), '../data/models/%s_%d.pth' % (model_name , epoch + 1 ))
162
+ model = model .to (device )
140
163
141
164
time_elapsed = time .time () - since
142
165
print ('Training {} complete in {:.0f}m {:.0f}s' .format (model_name , time_elapsed // 60 , time_elapsed % 60 ))
@@ -162,22 +185,26 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
162
185
res_top1_acc = dict ()
163
186
res_top5_acc = dict ()
164
187
num_classes = 20
165
- for name in ['densenet_201' , 'resnet-101_v2' ]:
166
- if name == 'densenet_201' :
167
- model = dense_net .densenet201 (num_classes = num_classes )
188
+ num_epochs = 100
189
+ for name in ['densenet_121' , 'resnet-34' ]:
190
+ if name == 'densenet_121' :
191
+ model = dense_net .densenet121 (num_classes = num_classes )
168
192
else :
169
- model = res_net .resnet101_v2 (num_classes = num_classes )
193
+ model = res_net .resnet34_v2 (num_classes = num_classes )
170
194
model .eval ()
171
195
# print(model)
172
196
model = model .to (device )
173
197
174
- criterion = nn .CrossEntropyLoss ()
175
- optimizer = optim .Adam (model .parameters (), lr = 1e-3 , weight_decay = 1e-4 )
176
- lr_schduler = optim .lr_scheduler .StepLR (optimizer , step_size = 7 , gamma = 0.96 )
198
+ criterion = SmoothLabelCritierion (label_smoothing = 0.1 )
199
+ # criterion = nn.CrossEntropyLoss()
200
+ optimizer = optim .Adam (model .parameters (), lr = 3e-4 , weight_decay = 3e-5 )
201
+ scheduler = optim .lr_scheduler .CosineAnnealingLR (optimizer , num_epochs - 5 , eta_min = 0 )
202
+ lr_scheduler = GradualWarmupScheduler (optimizer , multiplier = 1 , total_epoch = 5 , after_scheduler = scheduler )
177
203
178
204
util .check_dir ('../data/models/' )
179
205
best_model , loss_dict , top1_acc_dict , top5_acc_dict = train_model (
180
- data_loaders , data_sizes , name , model , criterion , optimizer , lr_schduler , num_epochs = 50 , device = device )
206
+ data_loaders , data_sizes , name , model , criterion , optimizer , lr_scheduler ,
207
+ num_epochs = num_epochs , device = device )
181
208
# 保存最好的模型参数
182
209
# util.save_model(best_model.cpu(), '../data/models/best_%s.pth' % name)
183
210
0 commit comments