4
4
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
5
# ------------------------------------------------------------------------------
6
6
7
- from __future__ import absolute_import
8
- from __future__ import division
9
- from __future__ import print_function
10
-
11
- import time
7
+ from __future__ import absolute_import , division , print_function
8
+
12
9
import logging
13
10
import os
11
+ import time
14
12
15
13
import numpy as np
16
14
import torch
17
-
18
15
from core .evaluate import accuracy
19
16
from core .inference import get_final_preds
20
17
from utils .transforms import flip_back
21
18
from utils .vis import save_debug_images
22
19
23
-
24
20
logger = logging .getLogger (__name__ )
25
21
26
22
27
- def train (config , train_loader , model , criterion , optimizer , epoch ,
28
- output_dir , tb_log_dir , writer_dict ):
23
+ def train (
24
+ config ,
25
+ train_loader ,
26
+ model ,
27
+ criterion ,
28
+ optimizer ,
29
+ epoch ,
30
+ output_dir ,
31
+ tb_log_dir ,
32
+ writer_dict ,
33
+ ):
29
34
batch_time = AverageMeter ()
30
35
data_time = AverageMeter ()
31
36
losses = AverageMeter ()
32
37
acc = AverageMeter ()
33
38
34
39
# switch to train mode
35
40
model .train ()
41
+ # freeze specified layers
42
+ extra = config .MODEL .EXTRA
43
+ if "FREEZE_LAYERS" in extra and extra ["FREEZE_LAYERS" ]:
44
+ frozen_layers = extra .FROZEN_LAYERS
45
+ for layer in frozen_layers :
46
+ eval ("model.module." + layer + ".requires_grad_(False)" )
36
47
37
48
end = time .time ()
38
49
for i , (input , target , target_weight , meta ) in enumerate (train_loader ):
@@ -63,39 +74,55 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
63
74
# measure accuracy and record loss
64
75
losses .update (loss .item (), input .size (0 ))
65
76
66
- _ , avg_acc , cnt , pred = accuracy (output .detach ().cpu ().numpy (),
67
- target .detach ().cpu ().numpy ())
77
+ _ , avg_acc , cnt , pred = accuracy (
78
+ output .detach ().cpu ().numpy (), target .detach ().cpu ().numpy ()
79
+ )
68
80
acc .update (avg_acc , cnt )
69
81
70
82
# measure elapsed time
71
83
batch_time .update (time .time () - end )
72
84
end = time .time ()
73
85
74
86
if i % config .PRINT_FREQ == 0 :
75
- msg = 'Epoch: [{0}][{1}/{2}]\t ' \
76
- 'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t ' \
77
- 'Speed {speed:.1f} samples/s\t ' \
78
- 'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t ' \
79
- 'Loss {loss.val:.5f} ({loss.avg:.5f})\t ' \
80
- 'Accuracy {acc.val:.3f} ({acc.avg:.3f})' .format (
81
- epoch , i , len (train_loader ), batch_time = batch_time ,
82
- speed = input .size (0 )/ batch_time .val ,
83
- data_time = data_time , loss = losses , acc = acc )
87
+ msg = (
88
+ "Epoch: [{0}][{1}/{2}]\t "
89
+ "Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t "
90
+ "Speed {speed:.1f} samples/s\t "
91
+ "Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t "
92
+ "Loss {loss.val:.5f} ({loss.avg:.5f})\t "
93
+ "Accuracy {acc.val:.3f} ({acc.avg:.3f})" .format (
94
+ epoch ,
95
+ i ,
96
+ len (train_loader ),
97
+ batch_time = batch_time ,
98
+ speed = input .size (0 ) / batch_time .val ,
99
+ data_time = data_time ,
100
+ loss = losses ,
101
+ acc = acc ,
102
+ )
103
+ )
84
104
logger .info (msg )
85
105
86
- writer = writer_dict ['writer' ]
87
- global_steps = writer_dict ['train_global_steps' ]
88
- writer .add_scalar ('train_loss' , losses .val , global_steps )
89
- writer .add_scalar ('train_acc' , acc .val , global_steps )
90
- writer_dict ['train_global_steps' ] = global_steps + 1
91
-
92
- prefix = '{}_{}' .format (os .path .join (output_dir , 'train' ), i )
93
- save_debug_images (config , input , meta , target , pred * 4 , output ,
94
- prefix )
95
-
96
-
97
- def validate (config , val_loader , val_dataset , model , criterion , output_dir ,
98
- tb_log_dir , writer_dict = None ):
106
+ writer = writer_dict ["writer" ]
107
+ global_steps = writer_dict ["train_global_steps" ]
108
+ writer .add_scalar ("train_loss" , losses .val , global_steps )
109
+ writer .add_scalar ("train_acc" , acc .val , global_steps )
110
+ writer_dict ["train_global_steps" ] = global_steps + 1
111
+
112
+ prefix = "{}_{}" .format (os .path .join (output_dir , "train" ), i )
113
+ save_debug_images (config , input , meta , target , pred * 4 , output , prefix )
114
+
115
+
116
+ def validate (
117
+ config ,
118
+ val_loader ,
119
+ val_dataset ,
120
+ model ,
121
+ criterion ,
122
+ output_dir ,
123
+ tb_log_dir ,
124
+ writer_dict = None ,
125
+ ):
99
126
batch_time = AverageMeter ()
100
127
losses = AverageMeter ()
101
128
acc = AverageMeter ()
@@ -104,10 +131,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
104
131
model .eval ()
105
132
106
133
num_samples = len (val_dataset )
107
- all_preds = np .zeros (
108
- (num_samples , config .MODEL .NUM_JOINTS , 3 ),
109
- dtype = np .float32
110
- )
134
+ all_preds = np .zeros ((num_samples , config .MODEL .NUM_JOINTS , 3 ), dtype = np .float32 )
111
135
all_boxes = np .zeros ((num_samples , 6 ))
112
136
image_path = []
113
137
filenames = []
@@ -132,15 +156,14 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
132
156
else :
133
157
output_flipped = outputs_flipped
134
158
135
- output_flipped = flip_back (output_flipped .cpu ().numpy (),
136
- val_dataset .flip_pairs )
159
+ output_flipped = flip_back (
160
+ output_flipped .cpu ().numpy (), val_dataset .flip_pairs
161
+ )
137
162
output_flipped = torch .from_numpy (output_flipped .copy ()).cuda ()
138
163
139
-
140
164
# feature is not aligned, shift flipped heatmap for higher accuracy
141
165
if config .TEST .SHIFT_HEATMAP :
142
- output_flipped [:, :, :, 1 :] = \
143
- output_flipped .clone ()[:, :, :, 0 :- 1 ]
166
+ output_flipped [:, :, :, 1 :] = output_flipped .clone ()[:, :, :, 0 :- 1 ]
144
167
145
168
output = (output + output_flipped ) * 0.5
146
169
@@ -152,51 +175,47 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
152
175
num_images = input .size (0 )
153
176
# measure accuracy and record loss
154
177
losses .update (loss .item (), num_images )
155
- _ , avg_acc , cnt , pred = accuracy (output .cpu ().numpy (),
156
- target .cpu ().numpy ())
178
+ _ , avg_acc , cnt , pred = accuracy (output .cpu ().numpy (), target .cpu ().numpy ())
157
179
158
180
acc .update (avg_acc , cnt )
159
181
160
182
# measure elapsed time
161
183
batch_time .update (time .time () - end )
162
184
end = time .time ()
163
185
164
- c = meta [' center' ].numpy ()
165
- s = meta [' scale' ].numpy ()
166
- score = meta [' score' ].numpy ()
186
+ c = meta [" center" ].numpy ()
187
+ s = meta [" scale" ].numpy ()
188
+ score = meta [" score" ].numpy ()
167
189
168
- preds , maxvals = get_final_preds (
169
- config , output .clone ().cpu ().numpy (), c , s )
190
+ preds , maxvals = get_final_preds (config , output .clone ().cpu ().numpy (), c , s )
170
191
171
- all_preds [idx : idx + num_images , :, 0 :2 ] = preds [:, :, 0 :2 ]
172
- all_preds [idx : idx + num_images , :, 2 :3 ] = maxvals
192
+ all_preds [idx : idx + num_images , :, 0 :2 ] = preds [:, :, 0 :2 ]
193
+ all_preds [idx : idx + num_images , :, 2 :3 ] = maxvals
173
194
# double check this all_boxes parts
174
- all_boxes [idx : idx + num_images , 0 :2 ] = c [:, 0 :2 ]
175
- all_boxes [idx : idx + num_images , 2 :4 ] = s [:, 0 :2 ]
176
- all_boxes [idx : idx + num_images , 4 ] = np .prod (s * 200 , 1 )
177
- all_boxes [idx : idx + num_images , 5 ] = score
178
- image_path .extend (meta [' image' ])
195
+ all_boxes [idx : idx + num_images , 0 :2 ] = c [:, 0 :2 ]
196
+ all_boxes [idx : idx + num_images , 2 :4 ] = s [:, 0 :2 ]
197
+ all_boxes [idx : idx + num_images , 4 ] = np .prod (s * 200 , 1 )
198
+ all_boxes [idx : idx + num_images , 5 ] = score
199
+ image_path .extend (meta [" image" ])
179
200
180
201
idx += num_images
181
202
182
203
if i % config .PRINT_FREQ == 0 :
183
- msg = 'Test: [{0}/{1}]\t ' \
184
- 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t ' \
185
- 'Loss {loss.val:.4f} ({loss.avg:.4f})\t ' \
186
- 'Accuracy {acc.val:.3f} ({acc.avg:.3f})' .format (
187
- i , len (val_loader ), batch_time = batch_time ,
188
- loss = losses , acc = acc )
204
+ msg = (
205
+ "Test: [{0}/{1}]\t "
206
+ "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t "
207
+ "Loss {loss.val:.4f} ({loss.avg:.4f})\t "
208
+ "Accuracy {acc.val:.3f} ({acc.avg:.3f})" .format (
209
+ i , len (val_loader ), batch_time = batch_time , loss = losses , acc = acc
210
+ )
211
+ )
189
212
logger .info (msg )
190
213
191
- prefix = '{}_{}' .format (
192
- os .path .join (output_dir , 'val' ), i
193
- )
194
- save_debug_images (config , input , meta , target , pred * 4 , output ,
195
- prefix )
214
+ prefix = "{}_{}" .format (os .path .join (output_dir , "val" ), i )
215
+ save_debug_images (config , input , meta , target , pred * 4 , output , prefix )
196
216
197
217
name_values , perf_indicator = val_dataset .evaluate (
198
- config , all_preds , output_dir , all_boxes , image_path ,
199
- filenames , imgnums
218
+ config , all_preds , output_dir , all_boxes , image_path , filenames , imgnums
200
219
)
201
220
202
221
model_name = config .MODEL .NAME
@@ -207,32 +226,16 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
207
226
_print_name_value (name_values , model_name )
208
227
209
228
if writer_dict :
210
- writer = writer_dict ['writer' ]
211
- global_steps = writer_dict ['valid_global_steps' ]
212
- writer .add_scalar (
213
- 'valid_loss' ,
214
- losses .avg ,
215
- global_steps
216
- )
217
- writer .add_scalar (
218
- 'valid_acc' ,
219
- acc .avg ,
220
- global_steps
221
- )
229
+ writer = writer_dict ["writer" ]
230
+ global_steps = writer_dict ["valid_global_steps" ]
231
+ writer .add_scalar ("valid_loss" , losses .avg , global_steps )
232
+ writer .add_scalar ("valid_acc" , acc .avg , global_steps )
222
233
if isinstance (name_values , list ):
223
234
for name_value in name_values :
224
- writer .add_scalars (
225
- 'valid' ,
226
- dict (name_value ),
227
- global_steps
228
- )
235
+ writer .add_scalars ("valid" , dict (name_value ), global_steps )
229
236
else :
230
- writer .add_scalars (
231
- 'valid' ,
232
- dict (name_values ),
233
- global_steps
234
- )
235
- writer_dict ['valid_global_steps' ] = global_steps + 1
237
+ writer .add_scalars ("valid" , dict (name_values ), global_steps )
238
+ writer_dict ["valid_global_steps" ] = global_steps + 1
236
239
237
240
return perf_indicator
238
241
@@ -242,24 +245,23 @@ def _print_name_value(name_value, full_arch_name):
242
245
names = name_value .keys ()
243
246
values = name_value .values ()
244
247
num_values = len (name_value )
245
- logger .info (
246
- '| Arch ' +
247
- ' ' .join (['| {}' .format (name ) for name in names ]) +
248
- ' |'
249
- )
250
- logger .info ('|---' * (num_values + 1 ) + '|' )
248
+ logger .info ("| Arch " + " " .join (["| {}" .format (name ) for name in names ]) + " |" )
249
+ logger .info ("|---" * (num_values + 1 ) + "|" )
251
250
252
251
if len (full_arch_name ) > 15 :
253
- full_arch_name = full_arch_name [:8 ] + ' ...'
252
+ full_arch_name = full_arch_name [:8 ] + " ..."
254
253
logger .info (
255
- '| ' + full_arch_name + ' ' +
256
- ' ' .join (['| {:.3f}' .format (value ) for value in values ]) +
257
- ' |'
254
+ "| "
255
+ + full_arch_name
256
+ + " "
257
+ + " " .join (["| {:.3f}" .format (value ) for value in values ])
258
+ + " |"
258
259
)
259
260
260
261
261
262
class AverageMeter (object ):
262
263
"""Computes and stores the average and current value"""
264
+
263
265
def __init__ (self ):
264
266
self .reset ()
265
267
0 commit comments