Skip to content

Commit 0a919db

Browse files
authoredAug 1, 2022
Add registration mechanism for models (#6333)
* Model registration mechanism. * Add overwrite options to the dataset prototype registration mechanism. * Adding example models. * Fix module filtering * Fix linter * Fix docs * Make name optional if same as model builder * Apply updates from code-review. * fix minor bug * Adding getter for model weight enum * Support both strings and callables on get_model_weight. * linter fixes * Fixing mypy. * Renaming `get_model_weight` to `get_model_weights` * Registering all classification models. * Registering all video models. * Registering all detection models. * Registering all optical flow models. * Fixing mypy. * Registering all segmentation models. * Registering all quantization models. * Fixing linter * Registering all prototype depth perception models. * Adding tests and updating existing tests. * Fix linters * Fix tests. * Add beta annotation on docs. * Fix tests. * Apply changes from code-review. * Adding documentation. * Fix docs.
1 parent 6387051 commit 0a919db

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+374
-120
lines changed
 

‎docs/source/models.rst

+40
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,46 @@ behavior, such as batch normalization. To switch between these modes, use
120120
# Set model to eval mode
121121
model.eval()
122122
123+
Model Registration Mechanism
124+
----------------------------
125+
126+
.. betastatus:: registration mechanism
127+
128+
As of v0.14, TorchVision offers a new model registration mechanism which allows retreaving models
129+
and weights by their names. Here are a few examples on how to use them:
130+
131+
.. code:: python
132+
133+
# List available models
134+
all_models = list_models()
135+
classification_models = list_models(module=torchvision.models)
136+
137+
# Initialize models
138+
m1 = get_model("mobilenet_v3_large", weights=None)
139+
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
140+
141+
# Fetch weights
142+
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
143+
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
144+
145+
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
146+
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
147+
148+
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
149+
assert weights_enum == weights_enum2
150+
151+
Here are the available public methods of the model registration mechanism:
152+
153+
.. currentmodule:: torchvision.models
154+
.. autosummary::
155+
:toctree: generated/
156+
:template: function.rst
157+
158+
get_model
159+
get_model_weights
160+
get_weight
161+
list_models
162+
123163
Using models from Hub
124164
---------------------
125165

‎test/test_backbone_utils.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,6 @@
1111
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
1212

1313

14-
def get_available_models():
15-
# TODO add a registration mechanism to torchvision.models
16-
return [
17-
k
18-
for k, v in models.__dict__.items()
19-
if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
20-
]
21-
22-
2314
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
2415
def test_resnet_fpn_backbone(backbone_name):
2516
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
@@ -135,10 +126,10 @@ def _get_return_nodes(self, model):
135126
eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
136127
return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
137128

138-
@pytest.mark.parametrize("model_name", get_available_models())
129+
@pytest.mark.parametrize("model_name", models.list_models(models))
139130
def test_build_fx_feature_extractor(self, model_name):
140131
set_rng_seed(0)
141-
model = models.__dict__[model_name](**self.model_defaults).eval()
132+
model = models.get_model(model_name, **self.model_defaults).eval()
142133
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
143134
# Check that it works with both a list and dict for return nodes
144135
self._create_feature_extractor(
@@ -172,9 +163,9 @@ def test_node_name_conventions(self):
172163
train_nodes, _ = get_graph_node_names(model)
173164
assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
174165

175-
@pytest.mark.parametrize("model_name", get_available_models())
166+
@pytest.mark.parametrize("model_name", models.list_models(models))
176167
def test_forward_backward(self, model_name):
177-
model = models.__dict__[model_name](**self.model_defaults).train()
168+
model = models.get_model(model_name, **self.model_defaults).train()
178169
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
179170
model = self._create_feature_extractor(
180171
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
@@ -211,10 +202,10 @@ def test_feature_extraction_methods_equivalence(self):
211202
for k in ilg_out.keys():
212203
assert ilg_out[k].equal(fgn_out[k])
213204

214-
@pytest.mark.parametrize("model_name", get_available_models())
205+
@pytest.mark.parametrize("model_name", models.list_models(models))
215206
def test_jit_forward_backward(self, model_name):
216207
set_rng_seed(0)
217-
model = models.__dict__[model_name](**self.model_defaults).train()
208+
model = models.get_model(model_name, **self.model_defaults).train()
218209
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
219210
model = self._create_feature_extractor(
220211
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes

0 commit comments

Comments
 (0)
Please sign in to comment.