Skip to content

Commit f87b304

Browse files
committed
registration trains
1 parent b880bea commit f87b304

File tree

6 files changed

+5
-9
lines changed

6 files changed

+5
-9
lines changed

conf/models/registration/pointnet2.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ models:
88
radii: [[0.2], [0.4]]
99
nsamples: [[64], [64]]
1010
down_conv_nn:
11-
[[[FEAT + 3, 64, 64, 128]], [[128+3, 128, 128, 256]]]
11+
[[[FEAT, 64, 64, 128]], [[128+3, 128, 128, 256]]]
1212
mlp_cls:
1313
nn: [128, 256, 32]
1414
dropout: 0.5

src/datasets/registration/general3dmatch.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,7 @@ def get_patch(self, idx):
134134
if(self.transform is not None):
135135
data_source = self.transform(data_source)
136136
data_target = self.transform(data_target)
137-
batch = Batch.from_data_list([data_source, data_target])
138-
batch.pair = batch.batch
139-
batch.batch = None
137+
batch = make_pair(data_source, data_target)
140138
batch = batch.contiguous().to(torch.float)
141139

142140
return batch
@@ -164,9 +162,7 @@ def get_fragment(self, idx):
164162
if(self.transform is not None):
165163
data_source = self.transform(data_source)
166164
data_target = self.transform(data_target)
167-
batch = Batch.from_data_list([data_source, data_target])
168-
batch.pair = batch.batch
169-
batch.batch = None
165+
batch = make_pair(data_source, data_target)
170166
batch.y = torch.from_numpy(match['pair'])
171167
return batch.contiguous().to(torch.float)
172168

src/datasets/registration/pair.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ def make_pair(data_source: Data, data_target: Data):
1717
for key_source in data_source.keys:
1818
batch[key_source+"_source"] = data_source[key_source]
1919
for key_target in data_target.keys:
20-
batch[key_target+"_source"] = data_target[key_target]
20+
batch[key_target+"_target"] = data_target[key_target]
2121
return batch.contiguous()

src/models/registration/pointnet2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def set_input(self, data):
4949
x = torch.cat([data.x_source, data.x_target], 0).transpose(1, 2).contiguous()
5050
else:
5151
x = None
52+
5253
pos = torch.cat([data.pos_source, data.pos_target], 0)
5354
rang = torch.arange(0, data.pos_source.shape[0])
5455
labels = torch.cat([rang, rang], 0)

test.pt

5.5 MB
Binary file not shown.

train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def train_epoch(
4545
with Ctq(train_loader) as tq_train_loader:
4646
for i, data in enumerate(tq_train_loader):
4747
data = data.to(device) # This takes time
48-
4948
model.set_input(data)
5049
t_data = time.time() - iter_data_time
5150

0 commit comments

Comments
 (0)