14
14
15
15
import numpy as np
16
16
import torch
17
+ import cv2
17
18
18
19
from core .evaluate import accuracy
19
20
from core .inference import get_final_preds
@@ -94,8 +95,23 @@ def train(config, train_loader, model, criterion, optimizer, epoch,
94
95
prefix )
95
96
96
97
98
+ def compute_joints (batch_image , batch_joints , batch_joints_vis ):
99
+ for k in range (batch_image .size (0 )):
100
+ image_tensor = batch_image [k ]
101
+ image = image_tensor .mul (255 ).clamp (0 , 255 ).byte ().permute (1 , 2 , 0 ).cpu ().numpy ()
102
+ image = cv2 .cvtColor (image , cv2 .COLOR_BGR2RGB )
103
+
104
+ joints = batch_joints [k ]
105
+ joints_vis = batch_joints_vis [k ]
106
+ for joint in joints :
107
+ cv2 .circle (image , (int (joint [0 ]), int (joint [1 ])), 2 , [0 , 0 , 255 ], 2 )
108
+
109
+ cv2 .imshow ("im" , image )
110
+ cv2 .waitKey ()
111
+
112
+
97
113
def validate (config , val_loader , val_dataset , model , criterion , output_dir ,
98
- tb_log_dir , writer_dict = None ):
114
+ tb_log_dir , writer_dict = None , predict_only = False ):
99
115
batch_time = AverageMeter ()
100
116
losses = AverageMeter ()
101
117
acc = AverageMeter ()
@@ -116,6 +132,8 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
116
132
with torch .no_grad ():
117
133
end = time .time ()
118
134
for i , (input , target , target_weight , meta ) in enumerate (val_loader ):
135
+ img = input .data [0 ].mul (255 ).clamp (0 , 255 ).byte ().permute (1 , 2 , 0 ).cpu ().numpy ()
136
+
119
137
# compute output
120
138
outputs = model (input )
121
139
if isinstance (outputs , list ):
@@ -147,12 +165,12 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
147
165
148
166
output = (output + output_flipped ) * 0.5
149
167
168
+ num_images = input .size (0 )
169
+
150
170
target = target .cuda (non_blocking = True )
151
171
target_weight = target_weight .cuda (non_blocking = True )
152
172
153
173
loss = criterion (output , target , target_weight )
154
-
155
- num_images = input .size (0 )
156
174
# measure accuracy and record loss
157
175
losses .update (loss .item (), num_images )
158
176
_ , avg_acc , cnt , pred = accuracy (output .cpu ().numpy (),
@@ -181,6 +199,7 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
181
199
image_path .extend (meta ['image' ])
182
200
183
201
idx += num_images
202
+ compute_joints (input , pred * 4 , meta ['joints_vis' ])
184
203
185
204
if i % config .PRINT_FREQ == 0 :
186
205
msg = 'Test: [{0}/{1}]\t ' \
@@ -196,18 +215,20 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
196
215
)
197
216
save_debug_images (config , input , meta , target , pred * 4 , output ,
198
217
prefix )
218
+ name_values = None
219
+ perf_indicator = None
220
+ if not predict_only :
221
+ name_values , perf_indicator = val_dataset .evaluate (
222
+ config , all_preds , output_dir , all_boxes , image_path ,
223
+ filenames , imgnums
224
+ )
199
225
200
- name_values , perf_indicator = val_dataset .evaluate (
201
- config , all_preds , output_dir , all_boxes , image_path ,
202
- filenames , imgnums
203
- )
204
-
205
- model_name = config .MODEL .NAME
206
- if isinstance (name_values , list ):
207
- for name_value in name_values :
208
- _print_name_value (name_value , model_name )
209
- else :
210
- _print_name_value (name_values , model_name )
226
+ model_name = config .MODEL .NAME
227
+ if isinstance (name_values , list ):
228
+ for name_value in name_values :
229
+ _print_name_value (name_value , model_name )
230
+ else :
231
+ _print_name_value (name_values , model_name )
211
232
212
233
if writer_dict :
213
234
writer = writer_dict ['writer' ]
@@ -222,19 +243,22 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
222
243
acc .avg ,
223
244
global_steps
224
245
)
225
- if isinstance (name_values , list ):
226
- for name_value in name_values :
246
+
247
+ if not predict_only :
248
+ if isinstance (name_values , list ):
249
+ for name_value in name_values :
250
+ writer .add_scalars (
251
+ 'valid' ,
252
+ dict (name_value ),
253
+ global_steps
254
+ )
255
+ else :
227
256
writer .add_scalars (
228
257
'valid' ,
229
- dict (name_value ),
258
+ dict (name_values ),
230
259
global_steps
231
260
)
232
- else :
233
- writer .add_scalars (
234
- 'valid' ,
235
- dict (name_values ),
236
- global_steps
237
- )
261
+
238
262
writer_dict ['valid_global_steps' ] = global_steps + 1
239
263
240
264
return perf_indicator
0 commit comments