
Commit 37ef91e

perf(densenet): update data preprocessing, training models, and training parameters
1 parent 06871b1 commit 37ef91e

File tree

1 file changed (+49, -22 lines)

py/densenet/train.py

Lines changed: 49 additions & 22 deletions
@@ -17,40 +17,56 @@
 import torchvision.transforms as transforms
 from torchvision.datasets import ImageFolder

+from models.SmoothLabelCriterion import SmoothLabelCritierion
+from warmup_scheduler import GradualWarmupScheduler
 from utils import util
 from utils import metrics
 from models.resnet import res_net
 from models.densenet import dense_net


 def flops_params():
-    for name in ['densenet_201', 'resnet-101_v2']:
-        if name == 'densenet_201':
-            model = dense_net.densenet201()
+    for name in ['densenet_121', 'resnet-34']:
+        if name == 'densenet_121':
+            model = dense_net.densenet121()
         else:
-            model = res_net.resnet101_v2()
+            model = res_net.resnet34_v2()
         gflops, params_size = metrics.compute_num_flops(model)
         print('{}: {:.3f} GFlops - {:.3f} MB'.format(name, gflops, params_size))


 def load_data(data_root_dir):
-    transform = transforms.Compose([
-        # transforms.ToPILImage(),
+    train_transform = transforms.Compose([
         transforms.Resize(256),
-        transforms.RandomCrop((224, 224)),
+        transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
+        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
         transforms.ToTensor(),
+        transforms.RandomErasing(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])

+    # test phase: TenCrop evaluation
+    test_transform = transforms.Compose([
+        transforms.Resize(256),
+        transforms.TenCrop(224),
+        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
+        transforms.Lambda(lambda crops: torch.stack(
+            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(crop) for crop in crops]))
+    ])
+
     data_loaders = {}
     data_sizes = {}
     for name in ['train', 'test']:
         data_dir = os.path.join(data_root_dir, name + '_imgs')
         # print(data_dir)

-        data_set = ImageFolder(data_dir, transform=transform)
-        data_loader = DataLoader(data_set, batch_size=96, shuffle=True, num_workers=8)
+        if name == 'train':
+            data_set = ImageFolder(data_dir, transform=train_transform)
+            data_loader = DataLoader(data_set, batch_size=96, shuffle=True, num_workers=8)
+        else:
+            data_set = ImageFolder(data_dir, transform=test_transform)
+            data_loader = DataLoader(data_set, batch_size=48, shuffle=True, num_workers=8)
         data_loaders[name] = data_loader
         data_sizes[name] = len(data_set)
     return data_loaders, data_sizes
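
Note on the TenCrop pipeline above: unlike the other transforms, transforms.TenCrop returns a tuple of ten crops (four corners and the center, each plus its horizontal flip), so the two Lambda steps stack them into a single [10, C, H, W] tensor per sample, and the DataLoader then yields batches of shape [N, 10, C, H, W]. A minimal standalone sketch of that shape flow, using a dummy image:

import torch
import torchvision.transforms as transforms
from PIL import Image

# same test-time pipeline as in the hunk above
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),  # returns a tuple of 10 PIL crops
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    transforms.Lambda(lambda crops: torch.stack(
        [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(crop) for crop in crops]))
])

img = Image.new('RGB', (320, 280))  # dummy input image
out = test_transform(img)
print(out.shape)  # torch.Size([10, 3, 224, 224]); after batching: [N, 10, 3, 224, 224]
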
@@ -69,7 +85,7 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
     top5_acc_dict = {'train': [], 'test': []}
     for epoch in range(num_epochs):

-        print('{} - Epoch {}/{}'.format(model_name, epoch, num_epochs - 1))
+        print('{} - Epoch {}/{}'.format(model_name, epoch + 1, num_epochs))
         print('-' * 10)

         # Each epoch has a training and test phase
@@ -95,7 +111,12 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
                 # forward
                 # track history if only in train
                 with torch.set_grad_enabled(phase == 'train'):
-                    outputs = model(inputs)
+                    if phase == 'test':
+                        N, N_crops, C, H, W = inputs.size()
+                        result = model(inputs.view(-1, C, H, W))  # fuse batch size and ncrops
+                        outputs = result.view(N, N_crops, -1).mean(1)  # avg over crops
+                    else:
+                        outputs = model(inputs)
                     # print(outputs.shape)
                     # _, preds = torch.max(outputs, 1)
                     loss = criterion(outputs, labels)
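
The extra crop dimension from TenCrop is what the new test branch folds away: the ten crops are flattened into the batch dimension for a single forward pass, then the resulting logits are averaged back to one prediction per image. A shape-only sketch, with a hypothetical toy model standing in for DenseNet:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 20))  # toy stand-in

inputs = torch.randn(4, 10, 3, 224, 224)       # [N, N_crops, C, H, W]
N, N_crops, C, H, W = inputs.size()

result = model(inputs.view(-1, C, H, W))       # fuse batch and crops: [40, 20]
outputs = result.view(N, N_crops, -1).mean(1)  # average over crops: [4, 20]
print(outputs.shape)  # torch.Size([4, 20])
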
@@ -115,6 +136,7 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
                 # running_corrects += torch.sum(preds == labels.data)
             if phase == 'train':
                 lr_scheduler.step()
+                print('lr: {}'.format(optimizer.param_groups[0]['lr']))

             epoch_loss = running_loss / data_sizes[phase]
             epoch_top1_acc = running_top1_acc / len(data_loaders[phase])
@@ -134,9 +156,10 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
             if phase == 'test' and epoch_top5_acc > best_top5_acc:
                 best_top5_acc = epoch_top5_acc

-        # save after every epoch
-        # util.save_model(model.cpu(), '../data/models/%s_%d.pth' % (model_name, epoch))
-        # model = model.to(device)
+        # save a checkpoint every 10 epochs
+        if (epoch + 1) % 10 == 0:
+            util.save_model(model.cpu(), '../data/models/%s_%d.pth' % (model_name, epoch + 1))
+            model = model.to(device)

     time_elapsed = time.time() - since
     print('Training {} complete in {:.0f}m {:.0f}s'.format(model_name, time_elapsed // 60, time_elapsed % 60))
@@ -162,22 +185,26 @@ def train_model(data_loaders, data_sizes, model_name, model, criterion, optimize
     res_top1_acc = dict()
     res_top5_acc = dict()
     num_classes = 20
-    for name in ['densenet_201', 'resnet-101_v2']:
-        if name == 'densenet_201':
-            model = dense_net.densenet201(num_classes=num_classes)
+    num_epochs = 100
+    for name in ['densenet_121', 'resnet-34']:
+        if name == 'densenet_121':
+            model = dense_net.densenet121(num_classes=num_classes)
         else:
-            model = res_net.resnet101_v2(num_classes=num_classes)
+            model = res_net.resnet34_v2(num_classes=num_classes)
         model.eval()
         # print(model)
         model = model.to(device)

-        criterion = nn.CrossEntropyLoss()
-        optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
-        lr_schduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.96)
+        criterion = SmoothLabelCritierion(label_smoothing=0.1)
+        # criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=3e-5)
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs - 5, eta_min=0)
+        lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=5, after_scheduler=scheduler)

         util.check_dir('../data/models/')
         best_model, loss_dict, top1_acc_dict, top5_acc_dict = train_model(
-            data_loaders, data_sizes, name, model, criterion, optimizer, lr_schduler, num_epochs=50, device=device)
+            data_loaders, data_sizes, name, model, criterion, optimizer, lr_scheduler,
+            num_epochs=num_epochs, device=device)
         # save the best model's parameters
         # util.save_model(best_model.cpu(), '../data/models/best_%s.pth' % name)
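
Two notes on the new training setup. First, SmoothLabelCritierion (spelling as in the source) is a repo-local module whose implementation is not part of this diff; presumably it is the standard label-smoothed cross-entropy, which mixes the one-hot target with a uniform distribution at weight label_smoothing. A hypothetical minimal sketch of that formulation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingLoss(nn.Module):
    """Hypothetical stand-in: cross-entropy against a smoothed target distribution."""

    def __init__(self, label_smoothing=0.1):
        super().__init__()
        self.label_smoothing = label_smoothing

    def forward(self, logits, target):
        num_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)
        # smoothed target: 1 - eps on the true class, eps / (K - 1) elsewhere
        smooth = torch.full_like(log_probs, self.label_smoothing / (num_classes - 1))
        smooth.scatter_(-1, target.unsqueeze(-1), 1.0 - self.label_smoothing)
        return torch.mean(torch.sum(-smooth * log_probs, dim=-1))

criterion = LabelSmoothingLoss(label_smoothing=0.1)
loss = criterion(torch.randn(4, 20), torch.tensor([0, 3, 7, 19]))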
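
Second, the new schedule warms the learning rate up over the first 5 epochs and then hands off to cosine annealing over the remaining num_epochs - 5. GradualWarmupScheduler comes from the third-party warmup_scheduler package (ildoonet/pytorch-gradual-warmup-lr); with multiplier=1 it ramps from zero up to the optimizer's base lr. A sketch of the composition, assuming that package's API and a toy model:

import torch
import torch.optim as optim
from warmup_scheduler import GradualWarmupScheduler  # third-party package

model = torch.nn.Linear(10, 2)  # toy stand-in for DenseNet
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=3e-5)

num_epochs = 100
# cosine annealing over the 95 post-warmup epochs
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs - 5, eta_min=0)
# linear warmup for 5 epochs, then defer to the cosine schedule
lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=5,
                                      after_scheduler=cosine)

for epoch in range(num_epochs):
    # ... train one epoch, then step the scheduler as train.py does ...
    lr_scheduler.step()
    print('epoch {:3d}: lr = {:.6f}'.format(epoch + 1, optimizer.param_groups[0]['lr']))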
