# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import logging
import os

import numpy as np
import torch

from core.evaluate import accuracy, accuracy_classification, accuracy_landmark
from utils.vis import save_result_images, save_debug_images, save_images_landmark


logger = logging.getLogger(__name__)


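# NOTE: the routines below expect `criterion` to be an indexable pair of loss
# modules: criterion[0] scores the visibility/classification head against
# meta["visible"], and criterion[1] scores the landmark head against the
# flattened joint coordinates (64 values per sample). The concrete losses are
# configured elsewhere; the pairing sketched here is only a plausible example,
# not the project's confirmed setup.
def _example_build_criterion():
    # Assumed pairing: BCE-with-logits for visibility, MSE for coordinates.
    # CUDA is assumed, matching the .cuda() calls used throughout this file.
    classification_criterion = torch.nn.BCEWithLogitsLoss().cuda()
    landmark_criterion = torch.nn.MSELoss().cuda()
    return (classification_criterion, landmark_criterion)

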
def train(config, train_loader, model, criterion, optimizer, epoch,
          output_dir, tb_log_dir, writer_dict):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    loss_classifier = AverageMeter()
    loss_landmark = AverageMeter()
    acc = AverageMeter()
    acc_cls = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target, target_weight, meta) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # compute output of both heads
        classification, landmark = model(input)

        # the target from the data loader is replaced by the visibility labels
        #target2 = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True).view(classification.size(0),-1)
        target = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
        classloss = criterion[0](classification, target)

        # flatten the joint coordinates to one 64-value vector per sample
        target2 = meta["joints"].reshape(-1, 64).type(torch.FloatTensor).cuda(non_blocking=True)
        lmloss = criterion[1](landmark, target2)

        # only the landmark term is back-propagated here; the classification
        # loss is still computed and tracked for logging
        #loss = config.TRAIN.LOSS_WEIGHT[0]*classloss + config.TRAIN.LOSS_WEIGHT[1]*lmloss
        loss = config.TRAIN.LOSS_WEIGHT[1] * lmloss

        # compute gradient and do update step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        losses.update(loss.item(), input.size(0))
        loss_classifier.update(classloss.item(), input.size(0))
        loss_landmark.update(lmloss.item(), input.size(0))

        avg_acc, cnt = accuracy_landmark(landmark.detach().cpu().numpy(),
                                         target2.detach().cpu().numpy())
        acc.update(avg_acc, cnt)

        avg_acc, cnt = accuracy_classification(classification.detach().cpu().numpy(),
                                               target.detach().cpu().numpy())
        acc_cls.update(avg_acc, cnt)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}][{1}/{2}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Loss {loss.val:.5f} ({loss.avg:.5f}) ({classific.avg: .5f}+{lm.avg: .5f})\t' \
                  'Accuracy(landmark) {acc.val:.3f} ({acc.avg:.3f})\t' \
                  'Accuracy(classification) {acc_cls.val:.3f} ({acc_cls.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      speed=input.size(0)/batch_time.val,
                      data_time=data_time,
                      loss=losses, classific=loss_classifier, lm=loss_landmark,
                      acc=acc, acc_cls=acc_cls)
            logger.info(msg)


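# `tb_log_dir` and `writer_dict` are accepted by train()/validate() but are not
# used in this version. Below is a minimal sketch of how the tracked meters
# could be pushed to TensorBoard; the dict keys ('writer',
# 'train_global_steps') are an assumption, not a confirmed interface of this
# project.
def _example_log_train_step(writer_dict, losses, acc):
    if writer_dict:
        writer = writer_dict['writer']  # assumed to be a SummaryWriter
        global_steps = writer_dict['train_global_steps']
        writer.add_scalar('train_loss', losses.val, global_steps)
        writer.add_scalar('train_acc', acc.val, global_steps)
        writer_dict['train_global_steps'] = global_steps + 1

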
def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             tb_log_dir, writer_dict=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    # switch to evaluate mode
    model.eval()

    num_samples = len(val_dataset)
    all_preds = np.zeros(
        (num_samples, config.MODEL.NUM_JOINTS, 3),
        dtype=np.float32
    )
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        end = time.time()
        for i, (input, target, target_weight, meta) in enumerate(val_loader):
            # compute output
            classification, landmark = model(input)

            target = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
            classloss = criterion[0](classification, target)

            target2 = meta["joints"].reshape(-1, 64).type(torch.FloatTensor).cuda(non_blocking=True)
            lmloss = criterion[1](landmark, target2)

            #loss = config.TRAIN.LOSS_WEIGHT[0]*classloss + config.TRAIN.LOSS_WEIGHT[1]*lmloss
            loss = config.TRAIN.LOSS_WEIGHT[1] * lmloss

            num_images = input.size(0)
            # measure accuracy and record loss
            losses.update(loss.item(), num_images)
            avg_acc, cnt = accuracy_landmark(landmark.detach().cpu().numpy(),
                                             target2.detach().cpu().numpy())
            acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            idx += num_images

            if i % config.PRINT_FREQ == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, acc=acc)
                logger.info(msg)

    return acc.avg

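
# A rough sketch of how train() and validate() are typically wired together,
# using the accuracy returned by validate() for model selection. The epoch
# range and checkpointing details are assumptions for illustration only.
def _example_training_loop(config, train_loader, val_loader, val_dataset,
                           model, criterion, optimizer, final_output_dir,
                           tb_log_dir, writer_dict, begin_epoch, end_epoch):
    best_acc = 0.0
    for epoch in range(begin_epoch, end_epoch):
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        val_acc = validate(config, val_loader, val_dataset, model, criterion,
                           final_output_dir, tb_log_dir, writer_dict)
        # keep the weights with the best validation landmark accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(),
                       os.path.join(final_output_dir, 'model_best.pth'))
    return best_acc

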
def test(config, val_loader, val_dataset, model, criterion, output_dir,
         tb_log_dir, writer_dict=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_cls = AverageMeter()

    # switch to evaluate mode
    model.eval()

    num_samples = len(val_dataset)
    all_preds = np.zeros(
        (num_samples, config.MODEL.NUM_JOINTS, 3),
        dtype=np.float32
    )
    all_boxes = np.zeros((num_samples, 6))
    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        end = time.time()
        for i, (input, target, target_weight, meta) in enumerate(val_loader):
            # compute output
            classification, landmark = model(input)

            target = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True)
            classloss = criterion[0](classification, target)

            target2 = meta["joints"].reshape(-1, 64).type(torch.FloatTensor).cuda(non_blocking=True)
            lmloss = criterion[1](landmark, target2)

            #loss = config.TRAIN.LOSS_WEIGHT[0] * classloss + config.TRAIN.LOSS_WEIGHT[1] * lmloss
            loss = config.TRAIN.LOSS_WEIGHT[1] * lmloss

            num_images = input.size(0)
            # measure accuracy and record loss
            losses.update(loss.item(), num_images)

            avg_acc, cnt = accuracy_landmark(landmark.detach().cpu().numpy(),
                                             target2.detach().cpu().numpy())
            acc.update(avg_acc, cnt)

            avg_acc, cnt = accuracy_classification(classification.detach().cpu().numpy(),
                                                   target.detach().cpu().numpy())
            acc_cls.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # log and dump result visualizations for every batch
            if i % 1 == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                      'Accuracy(landmark) {acc.val:.3f} ({acc.avg:.3f})\t' \
                      'Accuracy(classification) {acc2.val:.3f} ({acc2.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, acc=acc, acc2=acc_cls)
                logger.info(msg)

                prefix = os.path.join(output_dir, 'result')

                save_images_landmark(meta, landmark.detach().cpu().numpy(),
                                     classification.detach().cpu().numpy(), prefix, i)

    return 0


# markdown format output
def _print_name_value(name_value, full_arch_name):
    names = name_value.keys()
    values = name_value.values()
    num_values = len(name_value)
    logger.info(
        '| Arch ' +
        ' '.join(['| {}'.format(name) for name in names]) +
        ' |'
    )
    logger.info('|---' * (num_values+1) + '|')

    if len(full_arch_name) > 15:
        full_arch_name = full_arch_name[:8] + '...'
    logger.info(
        '| ' + full_arch_name + ' ' +
        ' '.join(['| {:.3f}'.format(value) for value in values]) +
        ' |'
    )

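# A quick illustration of the markdown table _print_name_value writes to the
# log (the metric names and values here are made up for the example):
#
#   _print_name_value({'AP': 0.751, 'AP50': 0.906}, 'pose_hrnet_w32')
#
# logs three lines:
#
#   | Arch | AP | AP50 |
#   |---|---|---|
#   | pose_hrnet_w32 | 0.751 | 0.906 |

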
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count != 0 else 0
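

# Example of the weighted running average AverageMeter keeps (values are
# illustrative only):
#
#   meter = AverageMeter()
#   meter.update(0.50, n=32)   # batch of 32 samples with mean loss 0.50
#   meter.update(0.30, n=16)   # batch of 16 samples with mean loss 0.30
#   meter.val                  # -> 0.30 (most recent value)
#   meter.avg                  # -> (0.50*32 + 0.30*16) / 48 ≈ 0.433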