|
18 | 18 | from core.evaluate import accuracy
|
19 | 19 | from core.inference import get_final_preds
|
20 | 20 | from utils.transforms import flip_back
|
21 |
| -from utils.vis import save_debug_images |
| 21 | +from utils.vis import save_result_images, save_debug_images |
22 | 22 |
|
23 | 23 |
|
24 | 24 | logger = logging.getLogger(__name__)
|
@@ -194,47 +194,108 @@ def validate(config, val_loader, val_dataset, model, criterion, output_dir,
|
194 | 194 | save_debug_images(config, input, meta, target, pred*4, output,
|
195 | 195 | prefix)
|
196 | 196 |
|
197 |
| - name_values, perf_indicator = val_dataset.evaluate( |
198 |
| - config, all_preds, output_dir, all_boxes, image_path, |
199 |
| - filenames, imgnums |
200 |
| - ) |
| 197 | + return 0 |
201 | 198 |
|
202 |
| - model_name = config.MODEL.NAME |
203 |
| - if isinstance(name_values, list): |
204 |
| - for name_value in name_values: |
205 |
| - _print_name_value(name_value, model_name) |
206 |
| - else: |
207 |
| - _print_name_value(name_values, model_name) |
def test(config, val_loader, val_dataset, model, criterion, output_dir,
         tb_log_dir, writer_dict=None):
    """Run the model over ``val_loader`` and save per-batch result images.

    Unlike ``validate``, this does not call the dataset's ``evaluate`` —
    it only logs loss/accuracy per batch and writes result visualizations.

    Args:
        config: experiment config node (reads MODEL.NUM_JOINTS, TEST.FLIP_TEST,
            TEST.SHIFT_HEATMAP).
        val_loader: DataLoader yielding (input, target, target_weight, meta).
        val_dataset: dataset providing ``flip_pairs`` for flip testing.
        model: network to evaluate.
        criterion: loss taking (output, target, target_weight).
        output_dir: directory where 'result' images are written.
        tb_log_dir: tensorboard log dir (unused here; kept for signature
            parity with ``validate``).
        writer_dict: optional tensorboard writers (unused here).

    Returns:
        int: always 0 (no evaluation metric is computed).

    NOTE(review): unconditionally moves tensors to CUDA, so a GPU is required.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    # switch to evaluate mode
    model.eval()

    num_samples = len(val_dataset)
    # Accumulated predictions/boxes are kept for potential downstream use,
    # though nothing in this function consumes them after the loop.
    all_preds = np.zeros(
        (num_samples, config.MODEL.NUM_JOINTS, 3),
        dtype=np.float32
    )
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    idx = 0
    with torch.no_grad():
        end = time.time()
        for i, (input, target, target_weight, meta) in enumerate(val_loader):
            # compute output
            outputs = model(input)
            if isinstance(outputs, list):
                output = outputs[-1]
            else:
                output = outputs

            if config.TEST.FLIP_TEST:
                input_flipped = input.flip(3)
                outputs_flipped = model(input_flipped)

                if isinstance(outputs_flipped, list):
                    output_flipped = outputs_flipped[-1]
                else:
                    output_flipped = outputs_flipped

                output_flipped = flip_back(output_flipped.cpu().numpy(),
                                           val_dataset.flip_pairs)
                output_flipped = torch.from_numpy(output_flipped.copy()).cuda()

                # feature is not aligned, shift flipped heatmap for higher accuracy
                if config.TEST.SHIFT_HEATMAP:
                    output_flipped[:, :, :, 1:] = \
                        output_flipped.clone()[:, :, :, 0:-1]

                # average original and flipped heatmaps
                output = (output + output_flipped) * 0.5

            target = target.cuda(non_blocking=True)
            target_weight = target_weight.cuda(non_blocking=True)

            loss = criterion(output, target, target_weight)

            num_images = input.size(0)
            # measure accuracy and record loss
            losses.update(loss.item(), num_images)
            _, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
                                             target.cpu().numpy())

            acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()

            preds, maxvals = get_final_preds(
                config, output.clone().cpu().numpy(), c, s)

            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = maxvals
            # double check this all_boxes parts
            all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
            all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
            # scale is in units of 200px — presumably the dataset convention;
            # TODO confirm against the dataset implementation
            all_boxes[idx:idx + num_images, 4] = np.prod(s*200, 1)
            all_boxes[idx:idx + num_images, 5] = score
            image_path.extend(meta['image'])

            idx += num_images

            # log and save on every batch (the original `if i % 1 == 0`
            # guard was always true, so it is dropped)
            msg = 'Test: [{0}/{1}]\t' \
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                  'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time,
                      loss=losses, acc=acc)
            logger.info(msg)

            prefix = os.path.join(output_dir, 'result')

            save_result_images(config, input, meta, target, pred*4, output,
                               prefix, i)

    return 0
238 | 299 |
|
239 | 300 |
|
240 | 301 | # markdown format output
|
|
0 commit comments