Commit f3d03e5

Transforms fixes + eval on multiple epochs (torch-points3d#233)
* Transforms fixes + eval on multiple epochs
* Adding some more tests
* Some missing changes
* Fix tests

Co-authored-by: chaton <thomas.chaton.ai@gmail.com>
1 parent 98e4289 · commit f3d03e5

22 files changed: +327 -140 lines

conf/eval.yaml

+8 -7

@@ -1,12 +1,13 @@
-num_workers: 2
-batch_size: 16
+num_workers: 6
+batch_size: 10
 cuda: 1
-weight_name: "latest" # Used during resume, select with model to load from [miou, macc, acc..., latest]
+weight_name: "latest" # Used during resume, select with model to load from [miou, macc, acc..., latest]
 enable_cudnn: True
-checkpoint_dir: "" # "{your_path}/outputs/2020-01-28/11-04-13" for example
+checkpoint_dir: "/home/nicolas/deeppointcloud-benchmarks/outputs/2020-04-14/21-54-19" # "{your_path}/outputs/2020-01-28/11-04-13" for example
 model_name: KPConvPaper
-precompute_multi_scale: True # Compute multiscate features on cpu for faster training / inference
+precompute_multi_scale: True # Compute multiscate features on cpu for faster training / inference
 enable_dropout: False
+voting_runs: 1

-
-
+tracker_options: # Extra options for the tracker
+  full_res: True
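
The two new entries feed straight into the evaluation script: voting_runs sets how many evaluation passes the tracker accumulates before metrics are computed, and everything under tracker_options is forwarded to the tracker as keyword arguments (see the eval.py changes below). A minimal sketch of that forwarding, reusing the tracker and model objects from eval.py; the dict literal is illustrative:

# Illustrative only: what the config above amounts to at runtime.
tracker_options = {"full_res": True}       # parsed from conf/eval.yaml
tracker.track(model, **tracker_options)    # i.e. tracker.track(model, full_res=True)
tracker.finalise(**tracker_options)        # i.e. tracker.finalise(full_res=True)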

conf/training/kpconv.yaml

+4 -5

@@ -1,9 +1,8 @@
 # Those arguments defines the training hyper-parameters
 training:
-    epochs: 200
-    num_workers: 10
-    batch_size: 10
-    shuffle: True
+    epochs: 550
+    num_workers: 8
+    batch_size: 8
     cuda: 1
     precompute_multi_scale: True # Compute multiscate features on cpu for faster training / inference
 optim:

@@ -33,7 +32,7 @@ wandb:
     project: s3dis
     log: True
     notes: "Fixed labels"
-    name: "kpconv-fixedlabels"
+    name: "kpconv-dynamicdataset-6"
     public: True # It will be display the model within wandb log, else not.

docs/src/api/transforms.rst

+1 -1

@@ -54,6 +54,6 @@ Transforms
 
 .. autofunction:: torch_points3d.core.data_transform.Random3AxisRotation
 
-.. autofunction:: src.core.data_transform.RandomCoordsFlip
+.. autofunction:: torch_points3d.core.data_transform.RandomCoordsFlip
 
 .. autofunction:: torch_points3d.core.data_transform.compute_planarity

eval.py

+68 -25

@@ -23,45 +23,79 @@
 log = logging.getLogger(__name__)
 
 
-def eval_epoch(model: BaseModel, dataset, device, tracker: BaseTracker, checkpoint: ModelCheckpoint):
+def eval_epoch(
+    model: BaseModel,
+    dataset,
+    device,
+    tracker: BaseTracker,
+    checkpoint: ModelCheckpoint,
+    voting_runs=1,
+    tracker_options={},
+):
     tracker.reset("val")
     loader = dataset.val_dataloader
-    with Ctq(loader) as tq_val_loader:
-        for data in tq_val_loader:
-            with torch.no_grad():
-                model.set_input(data, device)
-                model.forward()
+    for i in range(voting_runs):
+        with Ctq(loader) as tq_val_loader:
+            for data in tq_val_loader:
+                with torch.no_grad():
+                    model.set_input(data, device)
+                    model.forward()
 
-            tracker.track(model)
-            tq_val_loader.set_postfix(**tracker.get_metrics(), color=COLORS.VAL_COLOR)
+                tracker.track(model, **tracker_options)
+                tq_val_loader.set_postfix(**tracker.get_metrics(), color=COLORS.VAL_COLOR)
 
+    tracker.finalise(**tracker_options)
     tracker.print_summary()
 
 
-def test_epoch(model: BaseModel, dataset, device, tracker: BaseTracker, checkpoint: ModelCheckpoint):
+def test_epoch(
+    model: BaseModel,
+    dataset,
+    device,
+    tracker: BaseTracker,
+    checkpoint: ModelCheckpoint,
+    voting_runs=1,
+    tracker_options={},
+):
 
     loaders = dataset.test_dataloaders
 
     for loader in loaders:
         stage_name = loader.dataset.name
         tracker.reset(stage_name)
-        with Ctq(loader) as tq_test_loader:
-            for data in tq_test_loader:
-                with torch.no_grad():
-                    model.set_input(data, device)
-                    model.forward()
-
-                tracker.track(model)
-                tq_test_loader.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)
-
-        tracker.print_summary()
-
-
-def run(cfg, model, dataset: BaseDataset, device, tracker: BaseTracker, checkpoint: ModelCheckpoint):
+        for i in range(voting_runs):
+            with Ctq(loader) as tq_test_loader:
+                for data in tq_test_loader:
+                    with torch.no_grad():
+                        model.set_input(data, device)
+                        model.forward()
+
+                    tracker.track(model, **tracker_options)
+                    tq_test_loader.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)
+
+        tracker.finalise(**tracker_options)
+        tracker.print_summary()
+
+
+def run(
+    cfg,
+    model,
+    dataset: BaseDataset,
+    device,
+    tracker: BaseTracker,
+    checkpoint: ModelCheckpoint,
+    voting_runs=1,
+    tracker_options={},
+):
     if dataset.has_val_loader:
-        eval_epoch(model, dataset, device, tracker, checkpoint)
+        eval_epoch(
+            model, dataset, device, tracker, checkpoint, voting_runs=voting_runs, tracker_options=tracker_options
+        )
 
-    test_epoch(model, dataset, device, tracker, checkpoint)
+    if dataset.has_test_loaders:
+        test_epoch(
+            model, dataset, device, tracker, checkpoint, voting_runs=voting_runs, tracker_options=tracker_options
+        )
 
 
 @hydra.main(config_path="conf/eval.yaml")

@@ -98,7 +132,16 @@ def main(cfg):
     tracker: BaseTracker = dataset.get_tracker(model, dataset, False, False)
 
     # Run training / evaluation
-    run(cfg, model, dataset, device, tracker, checkpoint)
+    run(
+        cfg,
+        model,
+        dataset,
+        device,
+        tracker,
+        checkpoint,
+        voting_runs=cfg.voting_runs,
+        tracker_options=cfg.tracker_options,
+    )
 
 
 if __name__ == "__main__":
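
In short, each evaluation stage now wraps its data pass in a loop over voting_runs and only calls tracker.finalise() once every pass is done, so the final metrics are computed over everything the tracker accumulated. A compressed sketch of that control flow for a single stage, reusing the loader, model, tracker, device and options built in main() (a sketch only, not a drop-in replacement for the functions above):

# Compressed sketch of the new evaluation control flow (illustrative only).
tracker.reset("val")
for _ in range(voting_runs):                     # several passes over the same loader = "voting"
    for data in loader:
        with torch.no_grad():
            model.set_input(data, device)
            model.forward()
        tracker.track(model, **tracker_options)  # accumulate predictions on every pass
tracker.finalise(**tracker_options)              # compute the final metrics once, after all passes
tracker.print_summary()

The new options are plain Hydra config keys, so they can also be overridden on the command line with the usual key=value syntax (for example voting_runs=3) instead of editing conf/eval.yaml.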

poetry.lock

+26 -27

Some generated files are not rendered by default.

requirements.txt

+1 -1

@@ -35,7 +35,7 @@ markdown==3.2.1
 markupsafe==1.1.1
 matplotlib==3.2.1
 networkx==2.4
-numpy==1.18.2
+numpy==1.18.3
 nvidia-ml-py3==7.352.0
 oauthlib==3.1.0
 omegaconf==1.4.1

test/test_activate_dropout.py → test/test_basemodel.py (renamed)

+13

@@ -30,6 +30,9 @@ def __init__(self):
         self._channels = [12, 12, 12, 12]
         self.nn = MLP(self._channels)
 
+    def set_input(self, a):
+        self.input = a
+
 
 class TestSimpleBatch(unittest.TestCase):
     def test_enable_dropout_eval(self):

@@ -46,5 +49,15 @@ def test_enable_dropout_eval(self):
             self.assertEqual(model.nn[i][2].training, False)
 
 
+class TestBaseModel(unittest.TestCase):
+    def test_getinput(self):
+        model = MockModel()
+        with self.assertRaises(AttributeError):
+            model.get_input()
+
+        model.set_input(1)
+        self.assertEqual(model.get_input(), 1)
+
+
 if __name__ == "__main__":
     unittest.main()
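
The renamed test module now also pins down a small contract on the base model: get_input() must raise AttributeError until set_input() has been called, and afterwards return whatever was stored. A hypothetical sketch of that contract (not the library's actual implementation, which may store and batch the input differently):

# Hypothetical sketch of the behaviour test_getinput asserts.
class BaseModelSketch:
    def set_input(self, data):
        self.input = data    # remember the last input fed to the model

    def get_input(self):
        return self.input    # raises AttributeError if set_input() was never called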

test/test_segmentationtracker.py

+11 -1

@@ -70,7 +70,7 @@ def test_track(self):
             self.assertAlmostEqual(metrics[k], 0, 5)
 
     def test_ignore_label(self):
-        tracker = SegmentationTracker(MockDataset())
+        tracker = SegmentationTracker(MockDataset(), ignore_label=-100)
         tracker.reset("test")
         model = MockModel()
         model.iter = 3

@@ -79,6 +79,16 @@ def test_ignore_label(self):
         for k in ["test_acc", "test_miou", "test_macc"]:
             self.assertAlmostEqual(metrics[k], 100, 5)
 
+    def test_finalise(self):
+        tracker = SegmentationTracker(MockDataset(), ignore_label=-100)
+        tracker.reset("test")
+        model = MockModel()
+        model.iter = 3
+        tracker.track(model)
+        tracker.finalise()
+        with self.assertRaises(RuntimeError):
+            tracker.track(model)
+
 
 if __name__ == "__main__":
     unittest.main()
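
test_finalise encodes another contract: once a tracker has been finalised, further track() calls must raise RuntimeError until reset() starts a new stage. A hypothetical sketch of such a guard, assuming a simple boolean flag; the real SegmentationTracker keeps a confusion matrix and more state:

# Hypothetical guard matching what test_finalise expects (not the actual tracker code).
class TrackerSketch:
    def reset(self, stage):
        self._stage = stage
        self._finalised = False    # a fresh stage may accumulate again

    def track(self, model, **kwargs):
        if self._finalised:
            raise RuntimeError("Tracker was finalised, call reset() before tracking again")
        # ... accumulate predictions / confusion counts from the model here ...

    def finalise(self, **kwargs):
        # ... compute end-of-stage metrics (e.g. full-resolution scores when full_res=True) ...
        self._finalised = True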

test/test_sphere_sampling.py

+59

@@ -0,0 +1,59 @@
+import os
+import sys
+import unittest
+import numpy as np
+import torch
+from torch_geometric.data import Data
+
+ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
+sys.path.insert(0, ROOT)
+
+from torch_points3d.core.data_transform.transforms import RandomSphere, SphereSampling
+
+
+class TestRandomSphere(unittest.TestCase):
+    def setUp(self):
+
+        pos = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 0], [0, 0, 1], [1, 1, 1], [0, 1, 1]])
+        labels = np.array([0, 0, 0, 0, 0, 0])
+
+        self.data = Data(pos=torch.from_numpy(pos).float(), labels=torch.from_numpy(labels))
+
+    def test_neighbour_found_under_random_sampling(self):
+        random_sphere = RandomSphere(0.1, strategy="RANDOM")
+        data = random_sphere(self.data.clone())
+        assert data.labels.shape[0] == 1
+
+        random_sphere = RandomSphere(3, strategy="RANDOM")
+        data = random_sphere(self.data.clone())
+        assert data.labels.shape[0] == 6
+
+
+class TestSphereSampling(unittest.TestCase):
+    def setUp(self):
+        pos = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 0], [0, 0, 1], [1, 1, 1], [0, 1, 1]])
+        labels = torch.tensor([0, 1, 2, 0, 0, 0])
+        self.data = Data(pos=pos.float(), labels=labels)
+
+    def test_sphere(self):
+        sphere_sampling = SphereSampling(0.1, [0, 0, 0])
+        sampled = sphere_sampling(self.data)
+
+        self.assertIn(SphereSampling.KDTREE_KEY, self.data)
+        self.assertEqual(len(sampled.labels), 1)
+        self.assertEqual(sampled.labels[0], 2)
+
+    def test_align(self):
+        sphere_sampling = SphereSampling(0.1, [1, 0, 0])
+        sampled = sphere_sampling(self.data)
+        torch.testing.assert_allclose(sampled.pos, torch.tensor([[0.0, 0, 0]]))
+        self.assertEqual(sampled.labels[0], 0)
+
+        sphere_sampling = SphereSampling(0.1, [1, 0, 0], align_origin=False)
+        sampled = sphere_sampling(self.data)
+        torch.testing.assert_allclose(sampled.pos, torch.tensor([[1.0, 0, 0]]))
+        self.assertEqual(sampled.labels[0], 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
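
Outside the unit tests, the two transforms read as cropping operations on a point cloud. A small usage sketch built only from the calls exercised above (the radii, centres and the random cloud are made up for illustration):

# Illustrative usage of the transforms covered by the new tests.
import torch
from torch_geometric.data import Data
from torch_points3d.core.data_transform.transforms import RandomSphere, SphereSampling

cloud = Data(pos=torch.rand(1000, 3) * 5, labels=torch.zeros(1000, dtype=torch.long))

# Keep only the points inside a 1-unit sphere around the origin; by default the kept
# points are shifted so that the sphere centre becomes the new origin (align_origin=True).
cropped = SphereSampling(1.0, [0, 0, 0])(cloud)

# Crop a 2-unit sphere around a randomly picked location instead (the "RANDOM" strategy).
augmented = RandomSphere(2.0, strategy="RANDOM")(cloud.clone())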
