
Commit 84f0b66

simple Finetuning (torch-points3d#455)
* add set weights
* set pretrained weight
* add pretrained possibilities
* pre trained seems to work
* small change
* update changelogs
* fix in trainer
* load weights with the same size only

1 parent 062d6b9 commit 84f0b66

File tree

6 files changed (+98 -5 lines):

- CHANGELOG.md
- conf/data/registration/testeth.yaml
- conf/models/registration/minkowski.yaml
- test/test_basemodel.py
- torch_points3d/models/base_model.py
- torch_points3d/trainer.py

CHANGELOG.md (+1 -1)

@@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added

 - Support for the IRALab benchmark (https://arxiv.org/abs/2003.12841), with data from the ETH, Canadian Planetary, Kaist and TUM datasets.
-
+- Possibility to load pretrained models by adding the path in the confs for finetuning.

 ### Bug fix

conf/data/registration/testeth.yaml (+4 -4)

@@ -4,8 +4,8 @@ data:
   dataroot: data
   first_subsampling: 0.02
   max_dist_overlap: 0.05
-  min_size_block: 1.5
-  max_size_block: 2
+  min_size_block: 5
+  max_size_block: 7
   num_pos_pairs: 30000
   min_points: 300
   num_points: 5000
@@ -21,10 +21,10 @@ data:
   ss_transform:
     - transform: CubeCrop
       params:
-        c: 1
+        c: 5
    - transform: CubeCrop
      params:
-        c: 1.5
+        c: 5.5

  train_transform:
    - transform: SaveOriginalPosId

conf/models/registration/minkowski.yaml (+42)

@@ -39,6 +39,48 @@ models:
     num_hn_samples: 256
     pos_thresh: 0.1
     neg_thresh: 1.4
+  MinkUNet_Fragment_pretrained:
+    class: minkowski.MinkowskiFragment
+    path_pretrained: "INSERT PATH"
+    weight_name: "latest"
+    conv_type: "SPARSE"
+    loss_mode: "match"
+    down_conv:
+      module_name: Res2BlockDown
+      dimension: 3
+      bn_momentum: 0.05
+      down_conv_nn:
+        [
+          [FEAT, 32],
+          [32, 64],
+          [64, 128],
+          [128, 256]
+        ]
+      kernel_size: [5, 3, 3, 3]
+      stride: [1, 2, 2, 2]
+      dilation: [1, 1, 1, 1]
+    up_conv:
+      module_name: Res2BlockUp
+      dimension: 3
+      bn_momentum: 0.05
+      up_conv_nn:
+        [
+          [256, 64],
+          [64 + 128, 64],
+          [64 + 64, 64],
+          [64 + 32, 64, 32]
+        ]
+      kernel_size: [3, 3, 3, 1]
+      stride: [2, 2, 2, 1]
+      dilation: [1, 1, 1, 1]
+    normalize_feature: True
+    metric_loss:
+      class: "ContrastiveHardestNegativeLoss"
+      params:
+        num_pos: 1024
+        num_hn_samples: 256
+        pos_thresh: 0.1
+        neg_thresh: 1.4


   Res16UNet32B:
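
The new MinkUNet_Fragment_pretrained entry duplicates the existing fragment model and adds the two keys that drive finetuning: path_pretrained (left as an "INSERT PATH" placeholder) and weight_name. As a minimal sketch of filling in that placeholder programmatically with OmegaConf — the checkpoint path below is hypothetical and this snippet is not part of the commit:

```python
# Sketch: point the new model entry at a real checkpoint before training.
# "/path/to/checkpoint.pt" is a hypothetical placeholder path.
from omegaconf import OmegaConf

conf = OmegaConf.load("conf/models/registration/minkowski.yaml")
model_conf = conf.models.MinkUNet_Fragment_pretrained
model_conf.path_pretrained = "/path/to/checkpoint.pt"  # replaces "INSERT PATH"
model_conf.weight_name = "latest"  # which entry of the checkpoint's "models" dict to load
print(OmegaConf.to_yaml(model_conf))
```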

test/test_basemodel.py (+30)

@@ -1,4 +1,5 @@
 import unittest
+import torch
 from omegaconf import OmegaConf, DictConfig
 from torch.nn import (
     Sequential,
@@ -49,6 +50,20 @@ def set_input(self, a):
         self.input = a


+class MockModel_(BaseModel):
+    __REQUIRED_DATA__ = ["x"]
+    __REQUIRED_LABELS__ = ["y"]
+
+    def __init__(self):
+        super(MockModel_, self).__init__(DictConfig({"conv_type": "Dummy"}))
+
+        self._channels = [12, 12, 12, 17]
+        self.nn = MLP(self._channels)
+
+    def set_input(self, a):
+        self.input = a
+
+
 class TestBaseModel(unittest.TestCase):
     def test_getinput(self):
         model = MockModel()
@@ -71,6 +86,21 @@ def test_enable_dropout_eval(self):
             self.assertEqual(model.nn[i][1].training, True)
             self.assertEqual(model.nn[i][2].training, False)

+    def test_load_pretrained_model(self):
+        """
+        test load_state_dict_with_same_shape
+        """
+        model1 = MockModel()
+        model2 = MockModel_()
+
+        w1 = model1.state_dict()
+
+        model2.load_state_dict_with_same_shape(w1)
+        w2 = model2.state_dict()
+        for k, p in w2.items():
+            if "nn.2." not in k:
+                torch.testing.assert_allclose(w1[k], p)
+
     def test_accumulated_gradient(self):
         params = load_model_config("segmentation", "pointnet2", "pointnet2ms")
         config_training = OmegaConf.load(os.path.join(DIR, "test_config/training_config.yaml"))

The "nn.2." keys are excluded from the comparison because MockModel_ mirrors MockModel but ends its MLP in 17 channels, so the final layer's weights evidently fail the shape check and are not copied; the loop asserts that every other, same-shape entry was transferred verbatim.

torch_points3d/models/base_model.py (+19)

@@ -1,6 +1,7 @@
 from collections import OrderedDict
 from abc import abstractmethod
 from typing import Optional, Dict, Any, List
+import os
 import torch
 from torch.optim.optimizer import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
@@ -133,6 +134,24 @@ def set_input(self, input, device):
         """
         raise NotImplementedError

+    def load_state_dict_with_same_shape(self, weights, strict=False):
+        model_state = self.state_dict()
+        filtered_weights = {k: v for k, v in weights.items() if k in model_state and v.size() == model_state[k].size()}
+        log.info("Loading weights:" + ", ".join(filtered_weights.keys()))
+        self.load_state_dict(filtered_weights, strict=strict)
+
+    def set_pretrained_weights(self):
+        path_pretrained = getattr(self.opt, "path_pretrained", None)
+        weight_name = getattr(self.opt, "weight_name", "latest")
+
+        if path_pretrained is not None:
+            if not os.path.exists(path_pretrained):
+                log.warning("The path does not exist, it will not load any model")
+            else:
+                log.info("load pretrained weights from {}".format(path_pretrained))
+                m = torch.load(path_pretrained)["models"][weight_name]
+                self.load_state_dict_with_same_shape(m, strict=False)
+
     def get_labels(self):
         """ returns a tensor of size ``[N_points]`` where each value is the label of a point
         """

torch_points3d/trainer.py (+2)

@@ -93,9 +93,11 @@ def _initialize_trainer(self):
         )
         self._model: BaseModel = instantiate_model(copy.deepcopy(self._cfg), self._dataset)
         self._model.instantiate_optimizers(self._cfg)
+        self._model.set_pretrained_weights()
         self._checkpoint.dataset_properties = self._dataset.used_properties

         log.info(self._model)
+
         self._model.log_optimizers()
         log.info("Model size = %i", sum(param.numel() for param in self._model.parameters() if param.requires_grad))
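
set_pretrained_weights() runs after instantiate_optimizers, and this ordering is safe because load_state_dict copies values into the model's existing parameter tensors rather than replacing them, so an optimizer created beforehand keeps referencing the same storage. A quick standalone check of that behaviour (stock PyTorch semantics, not code from this commit):

```python
# load_state_dict copies in place, so optimizer references stay valid.
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
new_weights = {k: torch.randn_like(v) for k, v in model.state_dict().items()}
model.load_state_dict(new_weights)
assert model.weight is optimizer.param_groups[0]["params"][0]  # same tensor object
```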
