From fdb64fd7c97f4526028c07f496c260465930e39c Mon Sep 17 00:00:00 2001
From: CHAEN <nuguziii@naver.com>
Date: Fri, 28 Feb 2020 14:42:45 +0900
Subject: [PATCH 1/3] Add landmark accuracy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lib/core/evaluate.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/lib/core/evaluate.py b/lib/core/evaluate.py
index 1058a5cc..86d156c2 100644
--- a/lib/core/evaluate.py
+++ b/lib/core/evaluate.py
@@ -92,23 +92,27 @@ def accuracy_classification(output, target, thres=0.0):
     avg_acc = np.mean(acc)
     return avg_acc, cnt
 
-def accuracy_landmark(output, target):
+def accuracy_landmark(output, target, thres=6.0):
     '''
     Calculate accuracy according to PCK,
     but uses ground truth heatmap rather than x,y locations
     First value to be returned is average accuracy across 'idxs',
     followed by individual accuracies
     '''
+    batch = output.shape[0]
+    acc = np.zeros(batch)
 
-    cnt = output.shape[0]
-    acc = np.zeros(cnt)
+    output = output.reshape(batch, 32, 2)
+    target = target.reshape(batch, 32, 2)
 
-    target = target.reshape(cnt, -1)
+    diff = np.sqrt(np.square(output[:,:,0] - target[:,:,0]) + np.square(output[:,:,1] - target[:,:,1]))
 
-    for i in range(cnt):
-        acc[i] = sum(output[i]==target[i])/32
+    for i in range(batch):
+        cur = diff[i]
+        cur[cur < thres] = 0
+        acc[i] = (32-np.count_nonzero(cur))/32
 
     avg_acc = np.mean(acc)
-    return avg_acc, cnt
+    return avg_acc, batch

From 1a57331678662a4ec5681accc85f7a7181cfaa52 Mon Sep 17 00:00:00 2001
From: CHAEN <nuguziii@naver.com>
Date: Fri, 28 Feb 2020 14:43:22 +0900
Subject: [PATCH 2/3] Add function, pose_hrnet landmark version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lib/core/function_plus.py     | 256 ++++++++++++++++
 lib/models/__init__.py        |   1 +
 lib/models/pose_hrnet_plus.py | 529 ++++++++++++++++++++++++++++++++++
 tools/test.py                 |   6 +-
 tools/train.py                |  10 +-
 5 files changed, 796 insertions(+), 6 deletions(-)
 create mode 100644 lib/core/function_plus.py
 create mode 100644 lib/models/pose_hrnet_plus.py

diff --git a/lib/core/function_plus.py b/lib/core/function_plus.py
new file mode 100644
index 00000000..84b4ccf1
--- /dev/null
+++ b/lib/core/function_plus.py
@@ -0,0 +1,256 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import logging +import os + +import numpy as np +import torch + +from core.evaluate import accuracy, accuracy_classification, accuracy_landmark +from utils.vis import save_result_images, save_debug_images, save_images_landmark + + +logger = logging.getLogger(__name__) + + +def train(config, train_loader, model, criterion, optimizer, epoch, + output_dir, tb_log_dir, writer_dict): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + loss_classifier = AverageMeter() + loss_landmark = AverageMeter() + acc = AverageMeter() + acc_cls = AverageMeter() + + # switch to train mode + model.train() + + end = time.time() + for i, (input, target, target_weight, meta) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + # compute output + classification, landmark = model(input) + + #target2 = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True).view(classification.size(0),-1) + target = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True) + classloss = criterion[0](classification, target) + + target2 = meta["joints"].reshape(-1,64).type(torch.FloatTensor).cuda(non_blocking=True) + lmloss = criterion[1](landmark, target2) + + #loss = config.TRAIN.LOSS_WEIGHT[0]*classloss + config.TRAIN.LOSS_WEIGHT[1]*lmloss + loss = config.TRAIN.LOSS_WEIGHT[1] * lmloss + + # compute gradient and do update step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure accuracy and record loss + losses.update(loss.item(), input.size(0)) + loss_classifier.update(classloss.item(), input.size(0)) + loss_landmark.update(lmloss.item(), input.size(0)) + + avg_acc, cnt= accuracy_landmark(landmark.detach().cpu().numpy(), + target2.detach().cpu().numpy()) + acc.update(avg_acc, cnt) + + avg_acc, cnt = accuracy_classification(classification.detach().cpu().numpy(), + target.detach().cpu().numpy()) + acc_cls.update(avg_acc, cnt) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % config.PRINT_FREQ == 0: + msg = 'Epoch: [{0}][{1}/{2}]\t' \ + 'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \ + 'Speed {speed:.1f} samples/s\t' \ + 'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \ + 'Loss {loss.val:.5f} ({loss.avg:.5f}) ({classific.avg: .5f}+{lm.avg: .5f})\t' \ + 'Accuracy(landmark) {acc.val:.3f} ({acc.avg:.3f})\t'\ + 'Accuracy(classification) {acc_cls.val:.3f} ({acc_cls.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time, + speed=input.size(0)/batch_time.val, + data_time=data_time, + loss=losses, classific=loss_classifier, lm=loss_landmark, + acc=acc, acc_cls=acc_cls) + logger.info(msg) + + +def validate(config, val_loader, val_dataset, model, criterion, output_dir, + tb_log_dir, writer_dict=None): + batch_time = AverageMeter() + losses = AverageMeter() + acc = AverageMeter() + + # switch to evaluate mode + model.eval() + + num_samples = len(val_dataset) + all_preds = np.zeros( + (num_samples, config.MODEL.NUM_JOINTS, 3), + dtype=np.float32 + ) + all_boxes = np.zeros((num_samples, 6)) + image_path = [] + filenames = [] + imgnums = [] + idx = 0 + with torch.no_grad(): + end = time.time() + for i, (input, target, target_weight, meta) in enumerate(val_loader): + # compute output + 
classification, landmark = model(input) + + target = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True) + classloss = criterion[0](classification, target) + + target2 = meta["joints"].reshape(-1, 64).type(torch.FloatTensor).cuda(non_blocking=True) + lmloss = criterion[1](landmark, target2) + + #loss = config.TRAIN.LOSS_WEIGHT[0]*classloss + config.TRAIN.LOSS_WEIGHT[1]*lmloss + loss = config.TRAIN.LOSS_WEIGHT[1] * lmloss + + num_images = input.size(0) + # measure accuracy and record loss + losses.update(loss.item(), num_images) + avg_acc, cnt = accuracy_landmark(landmark.detach().cpu().numpy(), + target2.detach().cpu().numpy()) + acc.update(avg_acc, cnt) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + idx += num_images + + if i % config.PRINT_FREQ == 0: + msg = 'Test: [{0}/{1}]\t' \ + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \ + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \ + 'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, + loss=losses, acc=acc) + logger.info(msg) + + return acc.avg + +def test(config, val_loader, val_dataset, model, criterion, output_dir, + tb_log_dir, writer_dict=None): + batch_time = AverageMeter() + losses = AverageMeter() + acc = AverageMeter() + acc_cls = AverageMeter() + + # switch to evaluate mode + model.eval() + + num_samples = len(val_dataset) + all_preds = np.zeros( + (num_samples, config.MODEL.NUM_JOINTS, 3), + dtype=np.float32 + ) + all_boxes = np.zeros((num_samples, 6)) + image_path = [] + filenames = [] + imgnums = [] + idx = 0 + with torch.no_grad(): + end = time.time() + for i, (input, target, target_weight, meta) in enumerate(val_loader): + # compute output + classification, landmark = model(input) + + target = meta["visible"].type(torch.FloatTensor).cuda(non_blocking=True) + classloss = criterion[0](classification, target) + + target2 = meta["joints"].reshape(-1, 64).type(torch.FloatTensor).cuda(non_blocking=True) + lmloss = criterion[1](landmark, target2) + + #loss = config.TRAIN.LOSS_WEIGHT[0] * classloss + config.TRAIN.LOSS_WEIGHT[1] * lmloss + loss = config.TRAIN.LOSS_WEIGHT[1] * lmloss + + num_images = input.size(0) + # measure accuracy and record loss + losses.update(loss.item(), num_images) + + avg_acc, cnt = accuracy_landmark(landmark.detach().cpu().numpy(), + target2.detach().cpu().numpy()) + acc.update(avg_acc, cnt) + + avg_acc, cnt = accuracy_classification(classification.detach().cpu().numpy(), + target.detach().cpu().numpy()) + acc_cls.update(avg_acc, cnt) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % 1 == 0: + msg = 'Test: [{0}/{1}]\t' \ + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \ + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \ + 'Accuracy {acc.val:.3f} ({acc.avg:.3f})\t'\ + 'Accuracy {acc2.val:.3f} ({acc2.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, + loss=losses, acc=acc, acc2=acc_cls) + logger.info(msg) + + prefix = os.path.join(output_dir, 'result') + + save_images_landmark(meta, landmark.detach().cpu().numpy(), classification.detach().cpu().numpy(), prefix, i) + + return 0 + + +# markdown format output +def _print_name_value(name_value, full_arch_name): + names = name_value.keys() + values = name_value.values() + num_values = len(name_value) + logger.info( + '| Arch ' + + ' '.join(['| {}'.format(name) for name in names]) + + ' |' + ) + logger.info('|---' * (num_values+1) + '|') + + if len(full_arch_name) > 15: + full_arch_name = full_arch_name[:8] 
+ '...' + logger.info( + '| ' + full_arch_name + ' ' + + ' '.join(['| {:.3f}'.format(value) for value in values]) + + ' |' + ) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count if self.count != 0 else 0 diff --git a/lib/models/__init__.py b/lib/models/__init__.py index e3b7f1a7..af04302e 100644 --- a/lib/models/__init__.py +++ b/lib/models/__init__.py @@ -14,3 +14,4 @@ import models.pose_resnet import models.pose_hrnet +import models.pose_hrnet_plus diff --git a/lib/models/pose_hrnet_plus.py b/lib/models/pose_hrnet_plus.py new file mode 100644 index 00000000..8cd37a95 --- /dev/null +++ b/lib/models/pose_hrnet_plus.py @@ -0,0 +1,529 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, 
num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d( + num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM + ), + ) + + layers = [] + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + stride, + downsample + ) + ) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index] + ) + ) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels) + ) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, 1, 0, bias=False + ), + nn.BatchNorm2d(num_inchannels[i]), + nn.Upsample(scale_factor=2**(j-i), mode='nearest') + ) + ) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False + ), + nn.BatchNorm2d(num_outchannels_conv3x3) + ) + ) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False + ), + nn.BatchNorm2d(num_outchannels_conv3x3), + nn.ReLU(True) + ) + ) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return 
nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class PoseHighResolutionNet(nn.Module): + + def __init__(self, cfg, **kwargs): + self.inplanes = 64 + extra = cfg.MODEL.EXTRA + super(PoseHighResolutionNet, self).__init__() + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(Bottleneck, 64, 4) + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=False) + + self.final_layer_heatmap = nn.Conv2d( + in_channels=pre_stage_channels[0], + out_channels=cfg.MODEL.NUM_JOINTS, + kernel_size=extra.FINAL_CONV_KERNEL, + stride=1, + padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0 + ) + + self.getFeature = nn.Conv2d( + in_channels=pre_stage_channels[0], + out_channels=cfg.MODEL.NUM_JOINTS, + kernel_size=extra.FINAL_CONV_KERNEL, + stride=1, + padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0 + ) + self.classifier = nn.Linear(cfg.MODEL.NUM_JOINTS, cfg.MODEL.NUM_JOINTS) + self.softmax = nn.Softmax(dim=1) + + self.getFeature2 = nn.Conv2d( + in_channels=pre_stage_channels[0], + out_channels=cfg.MODEL.NUM_JOINTS, + kernel_size=extra.FINAL_CONV_KERNEL, + stride=1, + padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0 + ) + self.final_layer_coord = nn.Linear(cfg.MODEL.NUM_JOINTS*2, cfg.MODEL.NUM_JOINTS*2) + + self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS'] + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i 
in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, 1, 1, bias=False + ), + nn.BatchNorm2d(num_channels_cur_layer[i]), + nn.ReLU(inplace=True) + ) + ) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False + ), + nn.BatchNorm2d(outchannels), + nn.ReLU(inplace=True) + ) + ) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output + ) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + features1 = self.getFeature(y_list[0]) + gap = F.adaptive_avg_pool2d(features1, (1,1)) + flatten = gap.view(gap.size(0),-1) + classification = self.classifier(flatten) + classification = self.softmax(classification) + + features2 = self.getFeature2(y_list[0]) + features2 = F.adaptive_avg_pool2d(features2, (1,2)).view(gap.size(0),-1) + landmark = self.final_layer_coord(features2) + + return classification, landmark + + def init_weights(self, pretrained=''): + logger.info('=> init 
weights from normal distribution')
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                nn.init.normal_(m.weight, std=0.001)
+                for name, _ in m.named_parameters():
+                    if name in ['bias']:
+                        nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.ConvTranspose2d):
+                nn.init.normal_(m.weight, std=0.001)
+                for name, _ in m.named_parameters():
+                    if name in ['bias']:
+                        nn.init.constant_(m.bias, 0)
+
+        if os.path.isfile(pretrained):
+            pretrained_state_dict = torch.load(pretrained)
+            logger.info('=> loading pretrained model {}'.format(pretrained))
+
+            need_init_state_dict = {}
+            for name, m in pretrained_state_dict.items():
+                if name.split('.')[0] in self.pretrained_layers \
+                   or self.pretrained_layers[0] == '*':
+                    need_init_state_dict[name] = m
+            self.load_state_dict(need_init_state_dict, strict=False)
+        elif pretrained:
+            logger.error('=> please download pre-trained models first!')
+            raise ValueError('{} does not exist!'.format(pretrained))
+
+
+def get_pose_net(cfg, is_train, **kwargs):
+    model = PoseHighResolutionNet(cfg, **kwargs)
+
+    if is_train and cfg.MODEL.INIT_WEIGHTS:
+        model.init_weights(cfg.MODEL.PRETRAINED)
+
+    return model
diff --git a/tools/test.py b/tools/test.py
index e8ed9394..873919e4 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -14,6 +14,7 @@
 import pprint
 
 import torch
+import torch.nn as nn
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
 import torch.optim
@@ -25,7 +26,7 @@
 from config import cfg
 from config import update_config
 from core.loss import JointsMSELoss, JointsCELoss
-from core.function import test
+from core.function_plus import test
 from utils.utils import create_logger
 
 import dataset
@@ -104,6 +105,7 @@ def main():
 
     # classifierLoss = nn.MSELoss(reduction='mean').cuda()
     classifierLoss = JointsCELoss().cuda()
+    lmloss = nn.MSELoss(reduction='mean').cuda()
 
     # Data loading code
     test_dataset = eval('dataset.'
+ cfg.DATASET.DATASET)(
@@ -122,7 +124,7 @@
     )
 
     # evaluate on validation set
-    test(cfg, test_loader, test_dataset, model, [heatmapLoss, classifierLoss],
+    test(cfg, test_loader, test_dataset, model, [classifierLoss, lmloss],
         final_output_dir, tb_log_dir)
 
diff --git a/tools/train.py b/tools/train.py
index 54cbee75..634fcfde 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -27,8 +27,10 @@
 from config import cfg
 from config import update_config
 from core.loss import JointsMSELoss, JointsCELoss
-from core.function import train
-from core.function import validate
+#from core.function import train
+#from core.function import validate
+from core.function_plus import train
+from core.function_plus import validate
 from utils.utils import get_optimizer
 from utils.utils import save_checkpoint
 from utils.utils import create_logger
@@ -183,13 +185,13 @@ def main():
         lr_scheduler.step()
 
         # train for one epoch
-        train(cfg, train_loader, model, [heatmapLoss, classifierLoss, lmloss], optimizer, epoch,
+        train(cfg, train_loader, model, [classifierLoss, lmloss], optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
 
         # evaluate on validation set
         perf_indicator = validate(
-            cfg, valid_loader, valid_dataset, model, [heatmapLoss, classifierLoss, lmloss],
+            cfg, valid_loader, valid_dataset, model, [classifierLoss, lmloss],
             final_output_dir, tb_log_dir, writer_dict
         )
 

From 4c35d46bf182380c421e0406222b90809c662258 Mon Sep 17 00:00:00 2001
From: CHAEN <nuguziii@naver.com>
Date: Fri, 28 Feb 2020 14:43:38 +0900
Subject: [PATCH 3/3] Add landmark visualize result
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 lib/utils/vis.py | 75 +++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 71 insertions(+), 4 deletions(-)

diff --git a/lib/utils/vis.py b/lib/utils/vis.py
index 3557c32d..ecb8172c 100755
--- a/lib/utils/vis.py
+++ b/lib/utils/vis.py
@@ -84,10 +84,20 @@ def save_image_with_joints(batch_joints, file_name, meta):
             joint_coord = pano.getOriginalCoord(image, joint_coord)
             joint_gt_coord = pano.getOriginalCoord(image, joint_gt_coord)
 
-            #if joint_vis[0]:
-            cv2.putText(image, teeth_num[i+1], (int(joint_coord[0]), int(joint_coord[1])), cv2.FONT_ITALIC, 0.4, [255, 0, 0], 1)
-            cv2.circle(image, (int(joint_coord[0]), int(joint_coord[1])), 2, [255, 0, 0], 2)
-            cv2.circle(image, (int(joint_gt_coord[0]), int(joint_gt_coord[1])), 2, [0,255, 0], 2)
+            if joint_vis[0]:
+                cv2.putText(image, teeth_num[i+1], (int(joint_coord[0])+4, int(joint_coord[1])+4), cv2.FONT_ITALIC, 0.4,
+                            [255, 0, 0], 1)
+                cv2.circle(image, (int(joint_coord[0]), int(joint_coord[1])), 2, [255, 0, 0], 2)
+                cv2.putText(image, teeth_num[i + 1], (int(joint_gt_coord[0])+4, int(joint_gt_coord[1])+4), cv2.FONT_ITALIC, 0.4,
+                            [0,255, 0], 1)
+                cv2.circle(image, (int(joint_gt_coord[0]), int(joint_gt_coord[1])), 1, [0,255, 0], 2)
+            else:
+                cv2.putText(image, teeth_num[i + 1], (int(joint_coord[0])+4, int(joint_coord[1])+4), cv2.FONT_ITALIC, 0.4,
+                            [255, 255, 0], 1)
+                cv2.circle(image, (int(joint_coord[0]), int(joint_coord[1])), 2, [255, 255, 0], 2)
+                cv2.putText(image, teeth_num[i + 1], (int(joint_gt_coord[0])+4, int(joint_gt_coord[1])+4), cv2.FONT_ITALIC, 0.4,
+                            [0, 255, 255], 1)
+                cv2.circle(image, (int(joint_gt_coord[0]), int(joint_gt_coord[1])), 1, [0, 255, 255], 2)
 
     cv2.imwrite(file_name, image)
 
@@ -208,3 +218,60 @@ def save_result_images(config, input, meta, target, joints_pred, output,
         save_batch_heatmaps(
             input, output, heatDir
         )
+
+def
save_images_landmark(meta, batch_landmark, batch_classification, prefix, i):
+    '''
+    batch_landmark: [1, 64] predicted (x, y) pairs,
+    batch_classification: [1, 32] per-tooth visibility scores,
+    meta['joints']: [1, 32, 3] ground truth joints,
+    i: batch index used in the output file name
+    '''
+
+    if not os.path.exists(prefix):
+        os.makedirs(prefix)
+
+    file_name = '{}_{}.jpg'.format(os.path.join(prefix, 'test'), i)
+
+    pano = PANODataset(None)
+    image = cv2.imread(meta["filename"][0], cv2.IMREAD_COLOR)
+
+    joints_gt = meta['joints'][0]
+
+    cv2.putText(image, 'pred', (250, 100), cv2.FONT_ITALIC, 1.5, [255, 0, 0], 3)
+    cv2.circle(image, (200, 100), 3, [255, 0, 0], 2)
+    cv2.putText(image, 'gt', (250, 170), cv2.FONT_ITALIC, 1.5, [0, 255, 0], 3)
+    cv2.circle(image, (200, 170), 3, [0, 255, 0], 2)
+
+    batch_landmark = batch_landmark.reshape(1, 32, 2)
+
+    teeth_num = {1: '18', 2: '17', 3: '16', 4: '15', 5: '14', 6: '13', 7: '12', 8: '11',
+                 9: '21', 10: '22', 11: '23', 12: '24', 13: '25', 14: '26', 15: '27', 16: '28',
+                 17: '48', 18: '47', 19: '46', 20: '45', 21: '44', 22: '43', 23: '42', 24: '41',
+                 25: '31', 26: '32', 27: '33', 28: '34', 29: '35', 30: '36', 31: '37', 32: '38'}
+
+    for j, joint_gt in enumerate(joints_gt):
+
+        joint_gt_coord = joint_gt[:2]
+        joint_gt_coord = pano.getOriginalCoord(image, joint_gt_coord)
+
+        joint_lm = batch_landmark[0][j]
+        joint_lm = pano.getOriginalCoord(image, joint_lm)
+
+        cls = batch_classification[0][j]
+
+        if cls > 0:
+            cv2.putText(image, teeth_num[j+1], (int(joint_lm[0])+4, int(joint_lm[1])+4), cv2.FONT_ITALIC, 0.4,
+                        [255, 0, 0], 1)
+            cv2.circle(image, (int(joint_lm[0]), int(joint_lm[1])), 2, [255, 0, 0], 2)
+            cv2.putText(image, teeth_num[j + 1], (int(joint_gt_coord[0])+4, int(joint_gt_coord[1])+4), cv2.FONT_ITALIC, 0.4,
+                        [0,255, 0], 1)
+            cv2.circle(image, (int(joint_gt_coord[0]), int(joint_gt_coord[1])), 1, [0,255, 0], 2)
+        else:
+            cv2.putText(image, teeth_num[j + 1], (int(joint_lm[0])+4, int(joint_lm[1])+4), cv2.FONT_ITALIC, 0.4,
+                        [255, 255, 0], 1)
+            cv2.circle(image, (int(joint_lm[0]), int(joint_lm[1])), 2, [255, 255, 0], 2)
+            cv2.putText(image, teeth_num[j + 1], (int(joint_gt_coord[0])+4, int(joint_gt_coord[1])+4), cv2.FONT_ITALIC, 0.4,
+                        [0, 255, 255], 1)
+            cv2.circle(image, (int(joint_gt_coord[0]), int(joint_gt_coord[1])), 1, [0, 255, 255], 2)
+
+    cv2.imwrite(file_name, image)
\ No newline at end of file
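
Notes (review aid, not part of the patches above):

PATCH 1/3 turns accuracy_landmark into a PCK-style metric: a joint counts as
correct when the Euclidean distance between its predicted and ground-truth
(x, y) pair is below `thres` (6 px by default). A minimal self-contained
check, assuming lib/ is on the import path as the tools/ scripts expect:

    import numpy as np
    from core.evaluate import accuracy_landmark

    # two samples, each 32 (x, y) landmarks flattened to 64 values,
    # matching the (batch, 64) tensors passed in function_plus.py
    pred = np.zeros((2, 64))
    gt = np.zeros((2, 64))
    gt[0, 0] = 5.0    # sample 0, joint 0 off by 5 px  -> correct (< 6.0)
    gt[1, 0] = 10.0   # sample 1, joint 0 off by 10 px -> incorrect

    avg_acc, cnt = accuracy_landmark(pred, gt, thres=6.0)
    print(avg_acc, cnt)   # (1.0 + 31/32) / 2 = 0.984375, cnt == 2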
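PATCH 2/3 adds a coordinate-regression head in pose_hrnet_plus.py: getFeature2
produces a 32-channel map, adaptive_avg_pool2d squeezes it to a 1x2 grid, and
the flattened 64 values feed final_layer_coord (Linear(64, 64)). A shape-check
sketch of that path, with a hypothetical feature-map size:

    import torch
    import torch.nn.functional as F

    b, num_joints = 2, 32                          # NUM_JOINTS = 32 teeth
    feat = torch.randn(b, num_joints, 96, 192)     # hypothetical getFeature2 output
    pooled = F.adaptive_avg_pool2d(feat, (1, 2))   # -> (b, 32, 1, 2)
    flat = pooled.view(b, -1)                      # -> (b, 64)
    landmark = torch.nn.Linear(num_joints * 2, num_joints * 2)(flat)
    print(landmark.shape)                          # torch.Size([2, 64])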