19
19
from core .inference import get_final_preds
20
20
from utils .transforms import flip_back
21
21
from utils .vis import save_debug_images
22
-
22
+ import pdb
23
23
24
24
logger = logging .getLogger (__name__ )
25
25
26
26
27
- def train (config , train_loader , model , criterion , optimizer , epoch ,
27
+ def train (config , train_loader , model , criterion , regress_loss , optimizer , epoch ,
28
28
output_dir , tb_log_dir , writer_dict ):
29
29
batch_time = AverageMeter ()
30
30
data_time = AverageMeter ()
31
- losses = AverageMeter ()
31
+ final_losses = AverageMeter ()
32
+ reg_losses = AverageMeter ()
33
+ mse_losses = AverageMeter ()
32
34
acc = AverageMeter ()
33
35
34
36
# switch to train mode
35
37
model .train ()
36
38
37
39
end = time .time ()
38
- for i , (input , target , target_weight , meta ) in enumerate (train_loader ):
40
+ for i , (input , target , target_weight , cord , meta ) in enumerate (train_loader ):
39
41
# measure data loading time
40
42
data_time .update (time .time () - end )
41
43
42
44
# compute output
43
- outputs = model (input )
45
+ outputs , locs = model (input )
44
46
45
47
target = target .cuda (non_blocking = True )
46
48
target_weight = target_weight .cuda (non_blocking = True )
47
49
48
50
if isinstance (outputs , list ):
49
- loss = criterion (outputs [0 ], target , target_weight )
50
- for output in outputs [1 :]:
51
- loss += criterion (output , target , target_weight )
51
+ loc = locs [0 ]
52
+ loc_x = loc [:, 0 :17 , :, :]
53
+ loc_y = loc [:, 17 :, :, :]
54
+ loc = torch .cat ((torch .unsqueeze ((loc_x ), 0 ),
55
+ torch .unsqueeze ((loc_y ), 0 )))
56
+ mse_loss = criterion (outputs [0 ], target , target_weight )
57
+ mse_reg_loss = regress_loss (loc , cord , target_weight )
58
+ for output , loc in outputs [1 :], locs [1 :]:
59
+ loc_x = loc [:, 0 :17 , :, :]
60
+ loc_y = loc [:, 17 :, :, :]
61
+ loc = torch .cat ((torch .unsqueeze ((loc_x ), 0 ),
62
+ torch .unsqueeze ((loc_y ), 0 )))
63
+ mse_loss += criterion (output , target , target_weight )
64
+ reg_loss += regress_loss (loc , cord , target_weight )
52
65
else :
53
66
output = outputs
54
- loss = criterion (output , target , target_weight )
67
+ loc = locs
68
+ loc_x = loc [:, 0 :17 , :, :]
69
+ loc_y = loc [:, 17 :, :, :]
70
+ loc = torch .cat ((torch .unsqueeze ((loc_x ), 0 ),
71
+ torch .unsqueeze ((loc_y ), 0 )))
72
+ mse_loss = criterion (output , target , target_weight )
73
+ reg_loss = regress_loss (loc , cord , target_weight )
74
+ final_loss = mse_loss + 0.00001 * reg_loss
55
75
56
76
# loss = criterion(output, target, target_weight)
57
77
58
78
# compute gradient and do update step
59
79
optimizer .zero_grad ()
60
- loss .backward ()
80
+ final_loss .backward ()
61
81
optimizer .step ()
62
82
63
83
# measure accuracy and record loss
64
- losses .update (loss .item (), input .size (0 ))
84
+ final_losses .update (final_loss .item (), input .size (0 ))
85
+ reg_losses .update (reg_loss .item (), input .size (0 ))
86
+ mse_losses .update (mse_loss .item (), input .size (0 ))
65
87
66
88
_ , avg_acc , cnt , pred = accuracy (output .detach ().cpu ().numpy (),
67
89
target .detach ().cpu ().numpy ())
@@ -73,19 +95,23 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
73
95
74
96
if i % config .PRINT_FREQ == 0 :
75
97
msg = 'Epoch: [{0}][{1}/{2}]\t ' \
76
- 'Time {batch_time.val:.3f}s ({batch_time.sum :.3f}s)\t ' \
98
+ 'Time {batch_time.val:.3f}s ({batch_time.avg :.3f}s)\t ' \
77
99
'Speed {speed:.1f} samples/s\t ' \
78
100
'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t ' \
79
101
'Loss {loss.val:.5f} ({loss.avg:.5f})\t ' \
102
+ 'Reg Loss {Rloss.val:.5f} ({Rloss.avg:.5f})\t ' \
103
+ 'MSE Loss {MSEloss.val:.5f} ({MSEloss.avg:.5f})\t ' \
80
104
'Accuracy {acc.val:.3f} ({acc.avg:.3f})' .format (
81
105
epoch , i , len (train_loader ), batch_time = batch_time ,
82
106
speed = input .size (0 )/ batch_time .val ,
83
- data_time = data_time , loss = losses , acc = acc )
107
+ data_time = data_time , loss = final_losses , Rloss = reg_losses , MSEloss = mse_losses , acc = acc )
84
108
logger .info (msg )
85
109
86
110
writer = writer_dict ['writer' ]
87
111
global_steps = writer_dict ['train_global_steps' ]
88
- writer .add_scalar ('train_loss' , losses .val , global_steps )
112
+ writer .add_scalar ('train_loss' , final_losses .val , global_steps )
113
+ writer .add_scalar ('train_reg_loss' , reg_losses .val , global_steps )
114
+ writer .add_scalar ('train_mse_loss' , mse_losses .val , global_steps )
89
115
writer .add_scalar ('train_acc' , acc .val , global_steps )
90
116
writer_dict ['train_global_steps' ] = global_steps + 1
91
117
@@ -94,10 +120,12 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
94
120
prefix )
95
121
96
122
97
- def validate (config , val_loader , val_dataset , model , criterion , output_dir ,
123
+ def validate (config , val_loader , val_dataset , model , criterion , regress_loss , output_dir ,
98
124
tb_log_dir , writer_dict = None ):
99
125
batch_time = AverageMeter ()
100
- losses = AverageMeter ()
126
+ final_losses = AverageMeter ()
127
+ reg_losses = AverageMeter ()
128
+ mse_losses = AverageMeter ()
101
129
acc = AverageMeter ()
102
130
103
131
# switch to evaluate mode
@@ -115,22 +143,39 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
115
143
idx = 0
116
144
with torch .no_grad ():
117
145
end = time .time ()
118
- for i , (input , target , target_weight , meta ) in enumerate (val_loader ):
146
+ for i , (input , target , target_weight , cord , meta ) in enumerate (val_loader ):
119
147
# compute output
120
- outputs = model (input )
148
+ outputs , locs = model (input )
149
+ loc_x = locs [:, 0 :17 , :, :]
150
+ loc_y = locs [:, 17 :, :, :]
151
+ locs = torch .cat ((torch .unsqueeze ((loc_x ), 0 ),
152
+ torch .unsqueeze ((loc_y ), 0 )))
153
+ xlocs = torch .squeeze ((locs [0 ]), 0 )
154
+ ylocs = torch .squeeze ((locs [1 ]), 0 )
155
+
121
156
if isinstance (outputs , list ):
122
157
output = outputs [- 1 ]
158
+ xloc = xlocs [- 1 ]
159
+ yloc = ylocs [- 1 ]
123
160
else :
124
161
output = outputs
162
+ xloc = xlocs
163
+ yloc = ylocs
125
164
126
165
if config .TEST .FLIP_TEST :
127
166
input_flipped = input .flip (3 )
128
- outputs_flipped = model (input_flipped )
167
+ outputs_flipped , locs_flipped = model (input_flipped )
168
+ xlocs_flipped = torch .squeeze ((locs [0 ]), 0 )
169
+ ylocs_flipped = torch .squeeze ((locs [1 ]), 0 )
129
170
130
171
if isinstance (outputs_flipped , list ):
131
172
output_flipped = outputs_flipped [- 1 ]
173
+ xloc_flipped = xlocs_flipped [- 1 ]
174
+ yloc_flipped = xlocs_flipped [- 1 ]
132
175
else :
133
176
output_flipped = outputs_flipped
177
+ xloc_flipped = xlocs_flipped
178
+ yloc_flipped = ylocs_flipped
134
179
135
180
output_flipped = flip_back (output_flipped .cpu ().numpy (),
136
181
val_dataset .flip_pairs )
@@ -142,15 +187,24 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
142
187
output_flipped .clone ()[:, :, :, 0 :- 1 ]
143
188
144
189
output = (output + output_flipped ) * 0.5
190
+ xloc = (xloc + xloc_flipped )* 0.5
191
+ yloc = (yloc + yloc_flipped )* 0.5
145
192
146
193
target = target .cuda (non_blocking = True )
147
194
target_weight = target_weight .cuda (non_blocking = True )
148
195
149
- loss = criterion (output , target , target_weight )
196
+ loc = torch .cat ((torch .unsqueeze ((xloc ), 0 ),
197
+ torch .unsqueeze ((yloc ), 0 )))
198
+
199
+ mse_loss = criterion (output , target , target_weight )
200
+ reg_loss = regress_loss (loc , cord , target_weight )
201
+ final_loss = mse_loss + 0.01 * reg_loss
150
202
151
203
num_images = input .size (0 )
152
204
# measure accuracy and record loss
153
- losses .update (loss .item (), num_images )
205
+ final_losses .update (final_loss .item (), input .size (0 ))
206
+ reg_losses .update (reg_loss .item (), input .size (0 ))
207
+ mse_losses .update (mse_loss .item (), input .size (0 ))
154
208
_ , avg_acc , cnt , pred = accuracy (output .cpu ().numpy (),
155
209
target .cpu ().numpy ())
156
210
@@ -165,7 +219,9 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
165
219
score = meta ['score' ].numpy ()
166
220
167
221
preds , maxvals = get_final_preds (
168
- config , output .clone ().cpu ().numpy (), c , s )
222
+ config , output .clone ().cpu ().numpy (),
223
+ xloc .clone ().cpu ().numpy (),
224
+ yloc .clone ().cpu ().numpy (), c , s )
169
225
170
226
all_preds [idx :idx + num_images , :, 0 :2 ] = preds [:, :, 0 :2 ]
171
227
all_preds [idx :idx + num_images , :, 2 :3 ] = maxvals
@@ -182,9 +238,11 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
182
238
msg = 'Test: [{0}/{1}]\t ' \
183
239
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t ' \
184
240
'Loss {loss.val:.4f} ({loss.avg:.4f})\t ' \
241
+ 'MSE Loss {mse_loss.val:.4f} ({mse_loss.avg:.4f})\t ' \
242
+ 'Reg Loss {reg_loss.val:.4f} ({reg_loss.avg:.4f})\t ' \
185
243
'Accuracy {acc.val:.3f} ({acc.avg:.3f})' .format (
186
244
i , len (val_loader ), batch_time = batch_time ,
187
- loss = losses , acc = acc )
245
+ loss = final_losses , mse_loss = mse_losses , reg_loss = reg_losses , acc = acc )
188
246
logger .info (msg )
189
247
190
248
prefix = '{}_{}' .format (
@@ -209,8 +267,18 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
209
267
writer = writer_dict ['writer' ]
210
268
global_steps = writer_dict ['valid_global_steps' ]
211
269
writer .add_scalar (
212
- 'valid_loss' ,
213
- losses .avg ,
270
+ 'valid_final_loss' ,
271
+ final_losses .avg ,
272
+ global_steps
273
+ )
274
+ writer .add_scalar (
275
+ 'valid_mse_loss' ,
276
+ mse_losses .avg ,
277
+ global_steps
278
+ )
279
+ writer .add_scalar (
280
+ 'valid_reg_loss' ,
281
+ reg_losses .avg ,
214
282
global_steps
215
283
)
216
284
writer .add_scalar (
@@ -236,7 +304,6 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
236
304
return perf_indicator
237
305
238
306
239
- # markdown format output
240
307
def _print_name_value (name_value , full_arch_name ):
241
308
names = name_value .keys ()
242
309
values = name_value .values ()
0 commit comments