|
11 | 11 | from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
|
12 | 12 |
|
13 | 13 |
|
14 |
| -def get_available_models(): |
15 |
| - # TODO add a registration mechanism to torchvision.models |
16 |
| - return [ |
17 |
| - k |
18 |
| - for k, v in models.__dict__.items() |
19 |
| - if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight" |
20 |
| - ] |
21 |
| - |
22 |
| - |
23 | 14 | @pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
|
24 | 15 | def test_resnet_fpn_backbone(backbone_name):
|
25 | 16 | x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
|
@@ -135,10 +126,10 @@ def _get_return_nodes(self, model):
|
135 | 126 | eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
|
136 | 127 | return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
|
137 | 128 |
|
138 |
| - @pytest.mark.parametrize("model_name", get_available_models()) |
| 129 | + @pytest.mark.parametrize("model_name", models.list_models(models)) |
139 | 130 | def test_build_fx_feature_extractor(self, model_name):
|
140 | 131 | set_rng_seed(0)
|
141 |
| - model = models.__dict__[model_name](**self.model_defaults).eval() |
| 132 | + model = models.get_model(model_name, **self.model_defaults).eval() |
142 | 133 | train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
|
143 | 134 | # Check that it works with both a list and dict for return nodes
|
144 | 135 | self._create_feature_extractor(
|
@@ -172,9 +163,9 @@ def test_node_name_conventions(self):
|
172 | 163 | train_nodes, _ = get_graph_node_names(model)
|
173 | 164 | assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
|
174 | 165 |
|
175 |
| - @pytest.mark.parametrize("model_name", get_available_models()) |
| 166 | + @pytest.mark.parametrize("model_name", models.list_models(models)) |
176 | 167 | def test_forward_backward(self, model_name):
|
177 |
| - model = models.__dict__[model_name](**self.model_defaults).train() |
| 168 | + model = models.get_model(model_name, **self.model_defaults).train() |
178 | 169 | train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
|
179 | 170 | model = self._create_feature_extractor(
|
180 | 171 | model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
|
@@ -211,10 +202,10 @@ def test_feature_extraction_methods_equivalence(self):
|
211 | 202 | for k in ilg_out.keys():
|
212 | 203 | assert ilg_out[k].equal(fgn_out[k])
|
213 | 204 |
|
214 |
| - @pytest.mark.parametrize("model_name", get_available_models()) |
| 205 | + @pytest.mark.parametrize("model_name", models.list_models(models)) |
215 | 206 | def test_jit_forward_backward(self, model_name):
|
216 | 207 | set_rng_seed(0)
|
217 |
| - model = models.__dict__[model_name](**self.model_defaults).train() |
| 208 | + model = models.get_model(model_name, **self.model_defaults).train() |
218 | 209 | train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
|
219 | 210 | model = self._create_feature_extractor(
|
220 | 211 | model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
|
|
0 commit comments