Commit 5b5653f

add fcos module in training & validation

1 parent aea55fa · commit 5b5653f

File tree: 7 files changed, +280 −57 lines


experiments/coco/hrnet/w32_256x192_adam_lr1e-3.yaml (+4, −5)
@@ -4,12 +4,11 @@ CUDNN:
   DETERMINISTIC: false
   ENABLED: true
 DATA_DIR: ''
-GPUS: (0,1)
+GPUS: (0,1,2,3)
 OUTPUT_DIR: 'output'
 LOG_DIR: 'log'
 WORKERS: 24
-PRINT_FREQ: 100
-
+PRINT_FREQ: 50
 DATASET:
   COLOR_RGB: true
   DATASET: 'coco'
@@ -23,7 +22,7 @@ DATASET:
   TEST_SET: 'val2017'
   TRAIN_SET: 'train2017'
 MODEL:
-  INIT_WEIGHTS: true
+  INIT_WEIGHTS: True
   NAME: pose_hrnet
   NUM_JOINTS: 17
   PRETRAINED: 'models/pytorch/imagenet/hrnet_w32-36af842e.pth'
@@ -91,7 +90,7 @@ MODEL:
 LOSS:
   USE_TARGET_WEIGHT: true
 TRAIN:
-  BATCH_SIZE_PER_GPU: 32
+  BATCH_SIZE_PER_GPU: 64
   SHUFFLE: true
   BEGIN_EPOCH: 0
   END_EPOCH: 210
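(Assuming standard data parallelism over the listed GPUs, these two changes raise the effective global batch from 2 × 32 = 64 to 4 × 64 = 256 images per step, and logging now happens every 50 iterations instead of every 100.)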

lib/core/function.py (+93, −26)
@@ -19,49 +19,71 @@
 from core.inference import get_final_preds
 from utils.transforms import flip_back
 from utils.vis import save_debug_images
-
+import pdb

 logger = logging.getLogger(__name__)


-def train(config, train_loader, model, criterion, optimizer, epoch,
+def train(config, train_loader, model, criterion, regress_loss, optimizer, epoch,
           output_dir, tb_log_dir, writer_dict):
     batch_time = AverageMeter()
     data_time = AverageMeter()
-    losses = AverageMeter()
+    final_losses = AverageMeter()
+    reg_losses = AverageMeter()
+    mse_losses = AverageMeter()
     acc = AverageMeter()

     # switch to train mode
     model.train()

     end = time.time()
-    for i, (input, target, target_weight, meta) in enumerate(train_loader):
+    for i, (input, target, target_weight, cord, meta) in enumerate(train_loader):
         # measure data loading time
         data_time.update(time.time() - end)

         # compute output
-        outputs = model(input)
+        outputs, locs = model(input)

         target = target.cuda(non_blocking=True)
         target_weight = target_weight.cuda(non_blocking=True)

         if isinstance(outputs, list):
-            loss = criterion(outputs[0], target, target_weight)
-            for output in outputs[1:]:
-                loss += criterion(output, target, target_weight)
+            loc = locs[0]
+            loc_x = loc[:, 0:17, :, :]
+            loc_y = loc[:, 17:, :, :]
+            loc = torch.cat((torch.unsqueeze(loc_x, 0),
+                             torch.unsqueeze(loc_y, 0)))
+            mse_loss = criterion(outputs[0], target, target_weight)
+            reg_loss = regress_loss(loc, cord, target_weight)
+            for output, loc in zip(outputs[1:], locs[1:]):
+                loc_x = loc[:, 0:17, :, :]
+                loc_y = loc[:, 17:, :, :]
+                loc = torch.cat((torch.unsqueeze(loc_x, 0),
+                                 torch.unsqueeze(loc_y, 0)))
+                mse_loss += criterion(output, target, target_weight)
+                reg_loss += regress_loss(loc, cord, target_weight)
         else:
             output = outputs
-            loss = criterion(output, target, target_weight)
+            loc = locs
+            loc_x = loc[:, 0:17, :, :]
+            loc_y = loc[:, 17:, :, :]
+            loc = torch.cat((torch.unsqueeze(loc_x, 0),
+                             torch.unsqueeze(loc_y, 0)))
+            mse_loss = criterion(output, target, target_weight)
+            reg_loss = regress_loss(loc, cord, target_weight)
+        final_loss = mse_loss + 0.00001 * reg_loss

         # loss = criterion(output, target, target_weight)

         # compute gradient and do update step
         optimizer.zero_grad()
-        loss.backward()
+        final_loss.backward()
         optimizer.step()

         # measure accuracy and record loss
-        losses.update(loss.item(), input.size(0))
+        final_losses.update(final_loss.item(), input.size(0))
+        reg_losses.update(reg_loss.item(), input.size(0))
+        mse_losses.update(mse_loss.item(), input.size(0))

         _, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
                                          target.detach().cpu().numpy())
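This commit does not include the definition of regress_loss, so the sketch below is only a guess at a compatible implementation: a smooth-L1 penalty over the stacked (2, N, 17, H, W) offset tensor built above, masked by target_weight in the same spirit as JointsMSELoss. The class name JointsOffsetLoss and the assumed layout of cord are illustrative, not part of the repository.

```python
import torch.nn as nn


class JointsOffsetLoss(nn.Module):
    """Hypothetical sketch of regress_loss; not the committed implementation."""

    def __init__(self, use_target_weight=True):
        super(JointsOffsetLoss, self).__init__()
        self.criterion = nn.SmoothL1Loss()
        self.use_target_weight = use_target_weight

    def forward(self, loc, cord, target_weight):
        # loc, cord: (2, N, num_joints, H, W) stacked x/y offset maps,
        # matching the torch.cat(...unsqueeze...) layout built in train().
        cord = cord.cuda(non_blocking=True)
        if self.use_target_weight:
            # target_weight: (N, num_joints, 1) -> broadcast the mask over
            # the 2 offset channels and the spatial dims so joints with
            # zero weight contribute no loss.
            w = target_weight.unsqueeze(0).unsqueeze(-1)
            return self.criterion(loc * w, cord * w)
        return self.criterion(loc, cord)
```

Under those assumptions, the training script would construct this loss once and pass it as the new regress_loss argument to both train() and validate().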
@@ -73,19 +95,23 @@ def train(config, train_loader, model, criterion, optimizer, epoch,

         if i % config.PRINT_FREQ == 0:
             msg = 'Epoch: [{0}][{1}/{2}]\t' \
-                  'Time {batch_time.val:.3f}s ({batch_time.sum:.3f}s)\t' \
+                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                   'Speed {speed:.1f} samples/s\t' \
                   'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                   'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
+                  'Reg Loss {Rloss.val:.5f} ({Rloss.avg:.5f})\t' \
+                  'MSE Loss {MSEloss.val:.5f} ({MSEloss.avg:.5f})\t' \
                   'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                       epoch, i, len(train_loader), batch_time=batch_time,
                       speed=input.size(0)/batch_time.val,
-                      data_time=data_time, loss=losses, acc=acc)
+                      data_time=data_time, loss=final_losses, Rloss=reg_losses, MSEloss=mse_losses, acc=acc)
             logger.info(msg)

             writer = writer_dict['writer']
             global_steps = writer_dict['train_global_steps']
-            writer.add_scalar('train_loss', losses.val, global_steps)
+            writer.add_scalar('train_loss', final_losses.val, global_steps)
+            writer.add_scalar('train_reg_loss', reg_losses.val, global_steps)
+            writer.add_scalar('train_mse_loss', mse_losses.val, global_steps)
             writer.add_scalar('train_acc', acc.val, global_steps)
             writer_dict['train_global_steps'] = global_steps + 1

@@ -94,10 +120,12 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
                               prefix)


-def validate(config, val_loader, val_dataset, model, criterion, output_dir,
+def validate(config, val_loader, val_dataset, model, criterion, regress_loss, output_dir,
              tb_log_dir, writer_dict=None):
     batch_time = AverageMeter()
-    losses = AverageMeter()
+    final_losses = AverageMeter()
+    reg_losses = AverageMeter()
+    mse_losses = AverageMeter()
     acc = AverageMeter()

     # switch to evaluate mode
@@ -115,22 +143,39 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     idx = 0
     with torch.no_grad():
         end = time.time()
-        for i, (input, target, target_weight, meta) in enumerate(val_loader):
+        for i, (input, target, target_weight, cord, meta) in enumerate(val_loader):
             # compute output
-            outputs = model(input)
+            outputs, locs = model(input)
+            loc_x = locs[:, 0:17, :, :]
+            loc_y = locs[:, 17:, :, :]
+            locs = torch.cat((torch.unsqueeze(loc_x, 0),
+                              torch.unsqueeze(loc_y, 0)))
+            xlocs = torch.squeeze(locs[0], 0)
+            ylocs = torch.squeeze(locs[1], 0)
+
             if isinstance(outputs, list):
                 output = outputs[-1]
+                xloc = xlocs[-1]
+                yloc = ylocs[-1]
             else:
                 output = outputs
+                xloc = xlocs
+                yloc = ylocs

             if config.TEST.FLIP_TEST:
                 input_flipped = input.flip(3)
-                outputs_flipped = model(input_flipped)
+                outputs_flipped, locs_flipped = model(input_flipped)
+                xlocs_flipped = torch.squeeze(locs_flipped[:, 0:17, :, :], 0)
+                ylocs_flipped = torch.squeeze(locs_flipped[:, 17:, :, :], 0)

                 if isinstance(outputs_flipped, list):
                     output_flipped = outputs_flipped[-1]
+                    xloc_flipped = xlocs_flipped[-1]
+                    yloc_flipped = ylocs_flipped[-1]
                 else:
                     output_flipped = outputs_flipped
+                    xloc_flipped = xlocs_flipped
+                    yloc_flipped = ylocs_flipped

                 output_flipped = flip_back(output_flipped.cpu().numpy(),
                                            val_dataset.flip_pairs)
@@ -142,15 +187,24 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
                         output_flipped.clone()[:, :, :, 0:-1]

             output = (output + output_flipped) * 0.5
+            xloc = (xloc + xloc_flipped) * 0.5
+            yloc = (yloc + yloc_flipped) * 0.5

             target = target.cuda(non_blocking=True)
             target_weight = target_weight.cuda(non_blocking=True)

-            loss = criterion(output, target, target_weight)
+            loc = torch.cat((torch.unsqueeze(xloc, 0),
+                             torch.unsqueeze(yloc, 0)))
+
+            mse_loss = criterion(output, target, target_weight)
+            reg_loss = regress_loss(loc, cord, target_weight)
+            final_loss = mse_loss + 0.01 * reg_loss

             num_images = input.size(0)
             # measure accuracy and record loss
-            losses.update(loss.item(), num_images)
+            final_losses.update(final_loss.item(), num_images)
+            reg_losses.update(reg_loss.item(), num_images)
+            mse_losses.update(mse_loss.item(), num_images)
             _, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
                                              target.cpu().numpy())
@@ -165,7 +219,9 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             score = meta['score'].numpy()

             preds, maxvals = get_final_preds(
-                config, output.clone().cpu().numpy(), c, s)
+                config, output.clone().cpu().numpy(),
+                xloc.clone().cpu().numpy(),
+                yloc.clone().cpu().numpy(), c, s)

             all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
             all_preds[idx:idx + num_images, :, 2:3] = maxvals
@@ -182,9 +238,11 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
                 msg = 'Test: [{0}/{1}]\t' \
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
+                      'MSE Loss {mse_loss.val:.4f} ({mse_loss.avg:.4f})\t' \
+                      'Reg Loss {reg_loss.val:.4f} ({reg_loss.avg:.4f})\t' \
                       'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                           i, len(val_loader), batch_time=batch_time,
-                          loss=losses, acc=acc)
+                          loss=final_losses, mse_loss=mse_losses, reg_loss=reg_losses, acc=acc)
                 logger.info(msg)

                 prefix = '{}_{}'.format(
@@ -209,8 +267,18 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             writer = writer_dict['writer']
             global_steps = writer_dict['valid_global_steps']
             writer.add_scalar(
-                'valid_loss',
-                losses.avg,
+                'valid_final_loss',
+                final_losses.avg,
+                global_steps
+            )
+            writer.add_scalar(
+                'valid_mse_loss',
+                mse_losses.avg,
+                global_steps
+            )
+            writer.add_scalar(
+                'valid_reg_loss',
+                reg_losses.avg,
                 global_steps
             )
             writer.add_scalar(
@@ -236,7 +304,6 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
     return perf_indicator


-# markdown format output
 def _print_name_value(name_value, full_arch_name):
     names = name_value.keys()
     values = name_value.values()

lib/core/inference.py (+7, −1)
@@ -46,7 +46,7 @@ def get_max_preds(batch_heatmaps):
     return preds, maxvals


-def get_final_preds(config, batch_heatmaps, center, scale):
+def get_final_preds(config, batch_heatmaps, reg_x, reg_y, center, scale):
     coords, maxvals = get_max_preds(batch_heatmaps)

     heatmap_height = batch_heatmaps.shape[2]
@@ -57,6 +57,12 @@ def get_final_preds(config, batch_heatmaps, center, scale):
     for n in range(coords.shape[0]):
         for p in range(coords.shape[1]):
             hm = batch_heatmaps[n][p]
+            shift_x = reg_x[n][p][int(
+                coords[n][p][1])][int(coords[n][p][0])]
+            shift_y = reg_y[n][p][int(
+                coords[n][p][1])][int(coords[n][p][0])]
+            coords[n][p][0] = coords[n][p][0] + shift_x
+            coords[n][p][1] = coords[n][p][1] + shift_y
             px = int(math.floor(coords[n][p][0] + 0.5))
             py = int(math.floor(coords[n][p][1] + 0.5))
             if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1:
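As a toy sanity check of this refinement (illustrative values, plain NumPy, independent of the repository), the snippet below mimics what the loop does for a single image and joint: take the heatmap argmax, then shift it by the offsets stored at that same cell in reg_x/reg_y.

```python
import numpy as np

# One image, one joint, 4x4 heatmap; argmax lands at (x=2, y=1).
heatmap = np.zeros((1, 1, 4, 4), dtype=np.float32)
heatmap[0, 0, 1, 2] = 1.0

# Constant offset maps for illustration: +0.25 in x, -0.5 in y.
reg_x = np.full((1, 1, 4, 4), 0.25, dtype=np.float32)
reg_y = np.full((1, 1, 4, 4), -0.5, dtype=np.float32)

y, x = np.unravel_index(heatmap[0, 0].argmax(), heatmap[0, 0].shape)
shift_x = reg_x[0, 0, int(y), int(x)]  # offset read at the argmax cell
shift_y = reg_y[0, 0, int(y), int(x)]

print(x + shift_x, y + shift_y)  # -> 2.25 0.5
```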
