@@ -15,7 +15,7 @@
 import numpy as np
 import torch

-from core.evaluate import accuracy
+from core.evaluate import accuracy, accuracy_classification, accuracy_landmark
 from core.inference import get_final_preds
 from utils.transforms import flip_back
 from utils.vis import save_result_images, save_debug_images
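The two new names imported above, accuracy_classification and accuracy_landmark, come from core/evaluate.py, which is not part of this diff. As a minimal, hypothetical sketch (assuming the classification head emits per-joint visibility scores in [0, 1] and the target is a 0/1 mask), accuracy_classification could look roughly like this:

# Hypothetical sketch -- not the actual core/evaluate.py implementation.
# Assumes pred and target are numpy arrays of shape (batch, num_joints).
import numpy as np

def accuracy_classification(pred, target, threshold=0.5):
    """Return (avg_acc, cnt): the fraction of thresholded visibility
    predictions that match the binary target, and the count of entries."""
    pred_labels = (pred >= threshold).astype(target.dtype)
    cnt = int(target.size)
    avg_acc = float((pred_labels == target).sum()) / max(cnt, 1)
    return avg_acc, cnt

The call sites in test() below only rely on the (avg_acc, cnt) return shape, so the exact metric definition here is an assumption.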
@@ -29,6 +29,9 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
     batch_time = AverageMeter()
     data_time = AverageMeter()
     losses = AverageMeter()
+    loss_classifier = AverageMeter()
+    loss_heatmap = AverageMeter()
+    loss_landmark = AverageMeter()
     acc = AverageMeter()

     # switch to train mode
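The next hunk unpacks three heads from the model instead of a single heatmap output. The shapes noted below are assumptions inferred from how each tensor is consumed later in this diff, not something the change itself states:

# Assumed multi-head forward contract for `heatmap, classification, landmark = model(input)`
# (shapes are hypothetical, inferred from the code below):
#   heatmap        -- per-joint heatmaps (or a list of them for multi-stage models),
#                     scored against `target` with criterion[0]
#   classification -- per-joint visibility scores, roughly (batch, num_joints),
#                     scored against meta["visible"] with criterion[1]
#   landmark       -- flattened joint coordinates, (batch, 64),
#                     scored against meta["joints"].reshape(-1, 64) with criterion[2]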
@@ -40,19 +43,27 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
         data_time.update(time.time() - end)

         # compute output
-        outputs = model(input)
+        heatmap, classification, landmark = model(input)

         target = target.cuda(non_blocking=True)
         target_weight = target_weight.cuda(non_blocking=True)

-        if isinstance(outputs, list):
-            loss = criterion(outputs[0], target, target_weight)
-            for output in outputs[1:]:
-                loss += criterion(output, target, target_weight)
+        if isinstance(heatmap, list):
+            heatloss = criterion[0](heatmap[0], target, target_weight)
+            for output in heatmap[1:]:
+                heatloss += criterion[0](output, target, target_weight)
         else:
-            output = outputs
-            loss = criterion(output, target, target_weight)
+            output = heatmap
+            heatloss = criterion[0](output, target, target_weight)

+        # target2 = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True).view(classification.size(0), -1)
+        target2 = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
+        classloss = criterion[1](classification, target2)
+
+        target3 = meta["joints"].reshape(-1, 64).type(torch.FloatTensor).cuda(non_blocking=True)
+        lmloss = criterion[2](landmark, target3)
+
+        loss = config.TRAIN.LOSS_WEIGHT[1] * classloss + config.TRAIN.LOSS_WEIGHT[2] * lmloss
         # loss = criterion(output, target, target_weight)

         # compute gradient and do update step
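criterion is now indexed as criterion[0] (heatmap loss), criterion[1] (visibility classification loss) and criterion[2] (landmark regression loss), and the combined loss is scaled by the new config.TRAIN.LOSS_WEIGHT entry. Neither the criterion construction nor the config default appears in this diff; the sketch below shows one way the training script could wire it up, with the specific loss classes, weights, and config keys all being assumptions:

# Sketch only: assumed setup in the training entry point, not this repo's code.
import torch.nn as nn
from core.loss import JointsMSELoss  # assumed HRNet-style heatmap loss

def build_criterion(cfg):
    """Assumed factory returning the 3-element criterion list indexed in train()/validate()/test().
    cfg.TRAIN.LOSS_WEIGHT is assumed to be a 3-element list, e.g. [1.0, 0.1, 0.01] in the YAML."""
    return [
        JointsMSELoss(use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda(),  # [0] heatmaps
        nn.BCELoss().cuda(),  # [1] per-joint visibility (assumes sigmoid outputs)
        nn.MSELoss().cuda(),  # [2] 64-dim flattened landmark regression
    ]

Note that in train() only the classification and landmark terms enter the weighted loss, while the heatmap term is still computed and logged via loss_heatmap.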
@@ -62,6 +73,9 @@ def train(config, train_loader, model, criterion, optimizer, epoch,

         # measure accuracy and record loss
         losses.update(loss.item(), input.size(0))
+        loss_classifier.update(classloss.item(), input.size(0))
+        loss_heatmap.update(heatloss.item(), input.size(0))
+        loss_landmark.update(lmloss.item(), input.size(0))

         _, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
                                          target.detach().cpu().numpy())
@@ -76,19 +90,15 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
-                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
-                  'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
+                  'Loss {loss.val:.5f} ({loss.avg:.5f}) ({classific.avg:.5f}+{lm.avg:.5f})\t' \
+                  'Accuracy(heatmap) {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      speed=input.size(0)/batch_time.val,
-                      data_time=data_time, loss=losses, acc=acc)
+                      data_time=data_time,
+                      loss=losses, classific=loss_classifier, lm=loss_landmark,
+                      acc=acc)
             logger.info(msg)

-            writer = writer_dict['writer']
-            global_steps = writer_dict['train_global_steps']
-            writer.add_scalar('train_loss', losses.val, global_steps)
-            writer.add_scalar('train_acc', acc.val, global_steps)
-            writer_dict['train_global_steps'] = global_steps + 1
-
             prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), i)
             save_debug_images(config, input, meta, target, pred*4, output,
                               prefix)
@@ -117,11 +127,11 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
         end = time.time()
         for i, (input, target, target_weight, meta) in enumerate(val_loader):
             # compute output
-            outputs = model(input)
-            if isinstance(outputs, list):
-                output = outputs[-1]
+            heatmap, classification, landmark = model(input)
+            if isinstance(heatmap, list):
+                output = heatmap[-1]
             else:
-                output = outputs
+                output = heatmap

             if config.TEST.FLIP_TEST:
                 input_flipped = input.flip(3)
@@ -147,7 +157,11 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             target = target.cuda(non_blocking=True)
             target_weight = target_weight.cuda(non_blocking=True)

-            loss = criterion(output, target, target_weight)
+            target2 = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
+            target3 = meta["joints"].reshape(-1, 64).type(torch.FloatTensor).cuda(non_blocking=True)
+
+            loss = config.TRAIN.LOSS_WEIGHT[1] * criterion[1](classification, target2) \
+                   + config.TRAIN.LOSS_WEIGHT[2] * criterion[2](landmark, target3)

             num_images = input.size(0)
             # measure accuracy and record loss
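validate() builds the same two auxiliary targets as train(): meta["visible"] for the classification head and meta["joints"] flattened to 64 values per sample for the landmark head. The layout below is an assumption read off the reshape (for example 32 joints with an (x, y) pair each), not something the diff specifies:

# Assumed target layout (hypothetical: 64 = 32 joints x 2 coordinates):
#   target2 = meta["visible"]                 -> (batch, num_joints) 0/1 visibility flags,
#                                                compared with `classification` by criterion[1]
#   target3 = meta["joints"].reshape(-1, 64)  -> (batch, 64) flattened (x, y) pairs,
#                                                compared with `landmark` by criterion[2]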
@@ -201,6 +215,7 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
     batch_time = AverageMeter()
     losses = AverageMeter()
     acc = AverageMeter()
+    acc_clas = AverageMeter()

     # switch to evaluate mode
     model.eval()
@@ -219,11 +234,11 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
         end = time.time()
         for i, (input, target, target_weight, meta) in enumerate(val_loader):
            # compute output
-            outputs = model(input)
-            if isinstance(outputs, list):
-                output = outputs[-1]
+            heatmap, classification, _ = model(input)
+            if isinstance(heatmap, list):
+                output = heatmap[-1]
             else:
-                output = outputs
+                output = heatmap

             if config.TEST.FLIP_TEST:
                 input_flipped = input.flip(3)
@@ -249,16 +264,21 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
             target = target.cuda(non_blocking=True)
             target_weight = target_weight.cuda(non_blocking=True)

-            loss = criterion(output, target, target_weight)
+            target_class = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
+
+            loss = config.TRAIN.LOSS_WEIGHT[0] * criterion[0](output, target, target_weight) + criterion[1](classification, target_class)

             num_images = input.size(0)
             # measure accuracy and record loss
             losses.update(loss.item(), num_images)
             _, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
                                              target.cpu().numpy())
-
             acc.update(avg_acc, cnt)

+            avg_acc, cnt = accuracy_classification(classification.cpu().numpy(),
+                                                   target_class.cpu().numpy())
+            acc_clas.update(avg_acc, cnt)
+
             # measure elapsed time
             batch_time.update(time.time() - end)
             end = time.time()
@@ -285,9 +305,10 @@ def test(config, val_loader, val_dataset, model, criterion, output_dir,
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
-                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
+                      'Accuracy(heatmap) {acc.val:.3f} ({acc.avg:.3f})\t' \
+                      'Accuracy(class) {acc2.val:.3f} ({acc2.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
-                          loss=losses, acc=acc)
+                          loss=losses, acc=acc, acc2=acc_clas)
                logger.info(msg)

                prefix = os.path.join(output_dir, 'result')