forked from leoxiaobin/deep-high-resolution-net.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path run_and_visualize.py
151 lines (130 loc) · 4.59 KB
/
run_and_visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import pprint
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import _init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
from core.function import validate
from utils.utils import create_logger
from custom_dataset import CustomDataset
import dataset
import models
def parse_args():
parser = argparse.ArgumentParser(description="Run and visualize keypoints network for image or video.")
# general
parser.add_argument("--cfg",
help="Configuration file name",
required=True,
type=str)
parser.add_argument("opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER)
parser.add_argument("--modelDir",
help="model directory",
type=str,
default="")
parser.add_argument("--logDir",
help="log directory",
type=str,
default="")
parser.add_argument("--dataDir",
help="data directory",
type=str,
default="")
parser.add_argument("--prevModelDir",
help="prev Model directory",
type=str,
default="")
parser.add_argument("--visualize",
help="Visualize the results",
type=bool,
default=False)
parser.add_argument("--input",
help="Input image file",
type=str,
default="")
parser.add_argument("--video",
help="Input video file",
type=str,
default="")
args = parser.parse_args()
return args
def _resolve_attr(root, dotted_name):
    """Return ``root.<dotted_name>``, following each dot-separated attribute in turn.

    Used instead of ``eval('pkg.' + name)``: the result is the same for plain
    dotted attribute names, but a config value is no longer evaluated as Python code.

    Raises:
        AttributeError: if any component of ``dotted_name`` does not exist.
    """
    obj = root
    for part in dotted_name.split('.'):
        obj = getattr(obj, part)
    return obj


def _build_dataset(cfg, args):
    """Create the evaluation dataset selected by the command-line arguments.

    ``--input`` takes precedence and is wrapped in ``CustomDataset``. ``--video``
    is not implemented yet. Otherwise the class ``dataset.<cfg.DATASET.DATASET>``
    is instantiated on ``cfg.DATASET.TEST_SET``.

    Raises:
        NotImplementedError: if ``--video`` is given without ``--input``.
    """
    # Standard ImageNet per-channel mean/std, applied after ToTensor().
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if args.input:
        # TODO: Write a way to handle single images
        # TODO: Handle visualization
        return CustomDataset(args.input, transform)

    if args.video:
        # TODO: Write a way to handle videos image by image
        # TODO: Handle visualization
        # Fail explicitly: without a dataset the DataLoader below cannot be built.
        raise NotImplementedError(
            'Video input is not supported yet: {}'.format(args.video))

    # Original dataset way. The literal False is presumably the constructor's
    # is_train flag — confirm against the dataset classes.
    dataset_cls = _resolve_attr(dataset, cfg.DATASET.DATASET)
    return dataset_cls(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False, transform
    )


def _load_model(cfg, final_output_dir, logger):
    """Build the ``models.<cfg.MODEL.NAME>`` pose network and load its weights.

    Weights come from ``cfg.TEST.MODEL_FILE`` when it is set (loaded with
    ``strict=False``), otherwise from ``<final_output_dir>/final_state.pth``
    (loaded with the default strict matching).

    Returns:
        The model wrapped in ``torch.nn.DataParallel`` over ``cfg.GPUS`` and
        moved to CUDA.
    """
    get_pose_net = _resolve_attr(models, cfg.MODEL.NAME + '.get_pose_net')
    model = get_pose_net(cfg, is_train=False)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        incompatible = model.load_state_dict(
            torch.load(cfg.TEST.MODEL_FILE), strict=False
        )
        # strict=False silently skips mismatched keys; report them so a wrong or
        # only partially matching checkpoint does not go unnoticed. getattr keeps
        # this safe on torch versions whose load_state_dict returns None.
        missing = getattr(incompatible, 'missing_keys', None)
        unexpected = getattr(incompatible, 'unexpected_keys', None)
        if missing:
            logger.warning('=> missing keys in checkpoint: {}'.format(missing))
        if unexpected:
            logger.warning('=> unexpected keys in checkpoint: {}'.format(unexpected))
    else:
        model_state_file = os.path.join(final_output_dir, 'final_state.pth')
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    return torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()


def main():
    """Evaluate the pose network on the dataset selected from the command line.

    Parses arguments, merges them into the global ``cfg``, applies the cuDNN
    settings, builds the evaluation dataset and the model, then calls
    ``validate()`` on them.
    """
    args = parse_args()
    update_config(cfg, args)

    # Create a logger
    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')
    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    if args.visualize:
        # TODO: Handle visualization
        logger.warning('=> --visualize is not implemented yet; ignoring it')

    # Build the data first so an unsupported input fails before the model load.
    valid_dataset = _build_dataset(cfg, args)
    model = _load_model(cfg, final_output_dir, logger)

    # define loss function (criterion); it is passed to validate() below
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).cuda()

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True,
    )

    # evaluate on validation set
    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir)
# Run the evaluation only when executed as a script; importing the module does nothing.
if __name__ == "__main__":
    main()