-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathtest_backbone_utils.py
336 lines (293 loc) · 13.5 KB
/
test_backbone_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import random
from copy import deepcopy
from itertools import chain
from typing import Mapping, Sequence
import pytest
import torch
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
assert isinstance(model, BackboneWithFPN)
y = model(x)
assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6)
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3])
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5])
@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1)
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2])
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6])
model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
assert isinstance(model_fpn, BackboneWithFPN)
model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
assert isinstance(model, torch.nn.Sequential)
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
return int(x)
# Needed by TestFXFeatureExtraction. Checking that node naming conventions
# are respected. Particularly the index postfix of repeated node names
class TestSubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = x + 1
x = x + 1
x = self.relu(x)
x = self.relu(x)
return x
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = TestSubModule()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.submodule(x)
x = x + 1
x = x + 1
x = self.relu(x)
x = self.relu(x)
return x
test_module_nodes = [
"x",
"submodule.add",
"submodule.add_1",
"submodule.relu",
"submodule.relu_1",
"add",
"add_1",
"relu",
"relu_1",
]
class TestFxFeatureExtraction:
inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
model_defaults = {"num_classes": 1}
leaf_modules = []
def _create_feature_extractor(self, *args, **kwargs):
"""
Apply leaf modules
"""
tracer_kwargs = {}
if "tracer_kwargs" not in kwargs:
tracer_kwargs = {"leaf_modules": self.leaf_modules}
else:
tracer_kwargs = kwargs.pop("tracer_kwargs")
return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)
def _get_return_nodes(self, model):
set_rng_seed(0)
exclude_nodes_filter = [
"getitem",
"floordiv",
"size",
"chunk",
"_assert",
"eq",
"dim",
"getattr",
]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
)
# Get rid of any nodes that don't return tensors as they cause issues
# when testing backward pass.
train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_build_fx_feature_extractor(self, model_name):
set_rng_seed(0)
model = models.get_model(model_name, **self.model_defaults).eval()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
# Check that it works with both a list and dict for return nodes
self._create_feature_extractor(
model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
)
self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
# Check must specify return nodes
with pytest.raises(ValueError):
self._create_feature_extractor(model)
# Check return_nodes and train_return_nodes / eval_return nodes
# mutual exclusivity
with pytest.raises(ValueError):
self._create_feature_extractor(
model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
)
# Check train_return_nodes / eval_return nodes must both be specified
with pytest.raises(ValueError):
self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
# Check invalid node name raises ValueError
with pytest.raises(ValueError):
# First just double check that this node really doesn't exist
if not any(n.startswith("l") or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
else: # otherwise skip this check
raise ValueError
def test_node_name_conventions(self):
model = TestModule()
train_nodes, _ = get_graph_node_names(model)
assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_forward_backward(self, model_name):
model = models.get_model(model_name, **self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
out = model(self.inp)
out_agg = 0
for node_out in out.values():
if isinstance(node_out, Sequence):
out_agg += sum(o.float().mean() for o in node_out if o is not None)
elif isinstance(node_out, Mapping):
out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
else:
# Assume that the only other alternative at this point is a Tensor
out_agg += node_out.float().mean()
out_agg.backward()
def test_feature_extraction_methods_equivalence(self):
model = models.resnet18(**self.model_defaults).eval()
return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}
ilg_model = IntermediateLayerGetter(model, return_layers).eval()
fx_model = self._create_feature_extractor(model, return_layers)
# Check that we have same parameters
for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
assert n1 == n2
assert p1.equal(p2)
# And that outputs match
with torch.no_grad():
ilg_out = ilg_model(self.inp)
fgn_out = fx_model(self.inp)
assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
for k in ilg_out.keys():
assert ilg_out[k].equal(fgn_out[k])
@pytest.mark.parametrize("model_name", models.list_models(models))
def test_jit_forward_backward(self, model_name):
set_rng_seed(0)
model = models.get_model(model_name, **self.model_defaults).train()
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
model = torch.jit.script(model)
fgn_out = model(self.inp)
out_agg = 0
for node_out in fgn_out.values():
if isinstance(node_out, Sequence):
out_agg += sum(o.float().mean() for o in node_out if o is not None)
elif isinstance(node_out, Mapping):
out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
else:
# Assume that the only other alternative at this point is a Tensor
out_agg += node_out.float().mean()
out_agg.backward()
def test_train_eval(self):
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.dropout = torch.nn.Dropout(p=1.0)
def forward(self, x):
x = x.float().mean()
x = self.dropout(x) # dropout
if self.training:
x += 100 # add
else:
x *= 0 # mul
x -= 0 # sub
return x
model = TestModel()
train_return_nodes = ["dropout", "add", "sub"]
eval_return_nodes = ["dropout", "mul", "sub"]
def checks(model, mode):
with torch.no_grad():
out = model(torch.ones(10, 10))
if mode == "train":
# Check that dropout is respected
assert out["dropout"].item() == 0
# Check that control flow dependent on training_mode is respected
assert out["sub"].item() == 100
assert "add" in out
assert "mul" not in out
elif mode == "eval":
# Check that dropout is respected
assert out["dropout"].item() == 1
# Check that control flow dependent on training_mode is respected
assert out["sub"].item() == 0
assert "mul" in out
assert "add" not in out
# Starting from train mode
model.train()
fx_model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
# Check that the models stay in their original training state
assert model.training
assert fx_model.training
# Check outputs
checks(fx_model, "train")
# Check outputs after switching to eval mode
fx_model.eval()
checks(fx_model, "eval")
# Starting from eval mode
model.eval()
fx_model = self._create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
# Check that the models stay in their original training state
assert not model.training
assert not fx_model.training
# Check outputs
checks(fx_model, "eval")
# Check outputs after switching to train mode
fx_model.train()
checks(fx_model, "train")
def test_leaf_module_and_function(self):
class LeafModule(torch.nn.Module):
def forward(self, x):
# This would raise a TypeError if it were not in a leaf module
int(x.shape[0])
return torch.nn.functional.relu(x + 4)
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 1, 3)
self.leaf_module = LeafModule()
def forward(self, x):
leaf_function(x.shape[0])
x = self.conv(x)
return self.leaf_module(x)
model = self._create_feature_extractor(
TestModule(),
return_nodes=["leaf_module"],
tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
).train()
# Check that LeafModule is not in the list of nodes
assert "relu" not in [str(n) for n in model.graph.nodes]
assert "leaf_module" in [str(n) for n in model.graph.nodes]
# Check forward
out = model(self.inp)
# And backward
out["leaf_module"].float().mean().backward()
def test_deepcopy(self):
# Non-regression test for https://github.com/pytorch/vision/issues/8634
model = models.efficientnet_b3(weights=None)
extractor = create_feature_extractor(model=model, return_nodes={"classifier.0": "out"})
extractor.eval()
extractor.train()
extractor = deepcopy(extractor)
extractor.eval()
extractor.train()