forked from torch-points3d/torch-points3d
-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathvotenet.py
150 lines (128 loc) · 5.9 KB
/
votenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import logging
import numpy as np
import torch
import os
from torch_geometric.data import Data
from torch_points3d.datasets.object_detection.box_data import BoxData
from torch_points3d.models.base_model import BaseModel
from torch_points3d.applications import models
import torch_points3d.modules.VoteNet as votenet_module
from torch_points3d.models.base_architectures import UnetBasedModel
from torch_points3d.datasets.segmentation import IGNORE_LABEL
# Module-level logger, named after this module so its records can be filtered by module path.
log = logging.getLogger(name=__name__)
class VoteNetModel(BaseModel):
__REQUIRED_DATA__ = [
"pos",
]
__REQUIRED_LABELS__ = [
"center_label",
"heading_class_label",
"heading_residual_label",
"size_class_label",
"size_residual_label",
"sem_cls_label",
"box_label_mask",
"vote_label",
"vote_label_mask",
]
def __init__(self, option, model_type, dataset, modules):
"""Initialize this model class.
Parameters:
opt -- training/test options
A few things can be done here.
- (required) call the initialization function of BaseModel
- define loss function, visualization images, model names, and optimizers
"""
super(VoteNetModel, self).__init__(option)
self._dataset = dataset
self._weight_classes = dataset.weight_classes
# 1 - CREATE BACKBONE MODEL
input_nc = dataset.feature_dimension
backbone_option = option.backbone
backbone_cls = getattr(models, backbone_option.model_type)
self.backbone_model = backbone_cls(architecture="unet", input_nc=input_nc, config=backbone_option)
# 2 - CREATE VOTING MODEL
voting_option = option.voting
voting_cls = getattr(votenet_module, voting_option.module_name)
self.voting_module = voting_cls(vote_factor=voting_option.vote_factor, seed_feature_dim=voting_option.feat_dim)
# 3 - CREATE PROPOSAL MODULE
num_classes = dataset.num_classes
proposal_option = option.proposal
proposal_cls = getattr(votenet_module, proposal_option.module_name)
self.proposal_cls_module = proposal_cls(
num_class=num_classes,
vote_aggregation_config=proposal_option.vote_aggregation,
num_heading_bin=proposal_option.num_heading_bin,
mean_size_arr=dataset.mean_size_arr,
num_proposal=proposal_option.num_proposal,
sampling=proposal_option.sampling,
)
# Loss params
self.loss_params = option.loss_params
self.loss_params.num_heading_bin = proposal_option.num_heading_bin
mean_size_arr = dataset.mean_size_arr
if isinstance(mean_size_arr, torch.Tensor):
mean_size_arr = mean_size_arr.numpy().tolist()
if isinstance(dataset.mean_size_arr, np.ndarray):
mean_size_arr = mean_size_arr.tolist()
self.loss_params.mean_size_arr = mean_size_arr
self.losses_has_been_added = False
self.loss_names = []
def set_input(self, data, device):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
"""
# Forward through backbone model
self.input = data.to(device)
def forward(self, *args, **kwargs):
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
data_features = self.backbone_model.forward(self.input)
data_votes = self.voting_module(data_features)
sampling_id_key = "sampling_id_0"
num_seeds = data_features.pos.shape[1]
seed_inds = getattr(data_features, sampling_id_key, None)[:, :num_seeds]
setattr(data_votes, "seed_inds", seed_inds) # [B,num_seeds]
outputs: votenet_module.VoteNetResults = self.proposal_cls_module(data_votes)
# Set output
self.output = outputs
if hasattr(self.input, "center_label"):
gt_center = self.input.center_label[:, :, 0:3]
self.output.assign_objects(
gt_center, self.input.box_label_mask, self.loss_params.near_threshold, self.loss_params.far_threshold
)
with torch.no_grad():
self._dump_visuals()
def _compute_losses(self):
if self._weight_classes is not None:
self._weight_classes = self._weight_classes.to(self.device)
losses = votenet_module.get_loss(self.input, self.output, self.loss_params, weight_classes=self._weight_classes)
for loss_name, loss in losses.items():
if torch.is_tensor(loss):
if not self.losses_has_been_added:
self.loss_names += [loss_name]
setattr(self, loss_name, loss)
self.losses_has_been_added = True
def _dump_visuals(self):
if True:
return
if not hasattr(self, "visual_count"):
self.visual_count = 0
pred_boxes = self.output.get_boxes(self._dataset, apply_nms=True)
gt_boxes = []
for idx in range(len(pred_boxes)):
# Ground truth
sample_boxes = self.input.instance_box_corners[idx]
sample_boxes = sample_boxes[self.input.box_label_mask[idx]]
sample_labels = self.input.sem_cls_label[idx]
gt_box_data = [BoxData(sample_labels[i].item(), sample_boxes[i]) for i in range(len(sample_boxes))]
gt_boxes.append(gt_box_data)
data_visual = Data(pos=self.input.pos, batch=self.input.batch, gt_boxes=gt_boxes, pred_boxes=pred_boxes)
if not os.path.exists("viz"):
os.mkdir("viz")
torch.save(data_visual.to("cpu"), "viz/data_%i.pt" % (self.visual_count))
self.visual_count += 1
def backward(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
self._compute_losses()
self.loss.backward()