We are hoping to get your thoughts about the API prior to finalizing it.
TorchVision currently provides pre-trained models which can serve as a starting point for transfer learning or be used as-is in Computer Vision applications. The typical way to instantiate a pre-trained model and make a prediction looks roughly as follows.
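A minimal sketch of that legacy workflow, assuming the standard ImageNet normalization constants (the image-loading lines are illustrative only):

```python
import torch
from torchvision import models, transforms

# Legacy API: the boolean flag downloads and loads the pre-trained weights
model = models.resnet50(pretrained=True)
model.eval()

# Preprocessing has to be assembled by hand to match the training recipe
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# img = PIL.Image.open("example.jpg")
# batch = preprocess(img).unsqueeze(0)
# with torch.no_grad():
#     prediction = model(batch).softmax(dim=1)
```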
The weights of each model are associated with meta-data. The type of information we store depends on the task of the model (Classification, Detection, Segmentation, etc.). Typical information includes a link to the training recipe, the interpolation mode, the categories, and validation metrics. These values are programmatically accessible via the `meta` attribute:
```python
from torchvision.prototype.models import ResNet50_Weights

# Iterate over all meta-data records of the V2 weights
for k, v in ResNet50_Weights.ImageNet1K_V2.meta.items():
    print(k, v)
```
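Individual records can also be read directly. A sketch (the exact key name, `"categories"`, is an assumption based on the description above):

```python
from torchvision.prototype.models import ResNet50_Weights

# "categories" is a hypothetical key; the meta-data is described as storing the categories
categories = ResNet50_Weights.ImageNet1K_V2.meta["categories"]
print(len(categories))  # 1000 classes for ImageNet-1K
```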
Additionally, each weights entry is associated with the necessary preprocessing transforms. All current preprocessing transforms are JIT-scriptable and can be accessed via the `transforms` attribute. Prior to using them with the data, the transforms need to be initialized/constructed. This lazy initialization scheme keeps the solution memory efficient. The input of the transforms can be either a `PIL.Image` or a `Tensor` read using `torchvision.io`.
```python
from torchvision.prototype.models import ResNet50_Weights

# Initializing preprocessing at standard 224x224 resolution
preprocess = ResNet50_Weights.ImageNet1K_V2.transforms()

# Once initialized the callable can accept the image data:
# img_preprocessed = preprocess(img)
```
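Because the initialized transform is JIT-scriptable, it should survive `torch.jit.script` unchanged; a quick check, as a sketch:

```python
import torch
from torchvision.prototype.models import ResNet50_Weights

preprocess = ResNet50_Weights.ImageNet1K_V2.transforms()

# Scripting should succeed because the preset transforms are JIT-scriptable
scripted_preprocess = torch.jit.script(preprocess)
```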
Associating the weights with their meta-data and preprocessing will boost transparency, improve reproducibility and make it easier to document how a set of weights was produced.
The ability to directly link the weights with their properties (meta-data, preprocessing callables, etc.) is the reason why our implementation uses Enums instead of Strings. Nevertheless, for cases where only the name of the weights is available, we offer a method capable of mapping weight names to their Enums:
```python
from torchvision.prototype.models import get_weight

# Resolve the weights Enum from its string name
weights = get_weight("ResNet50_Weights.ImageNet1K_V2")
```
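The returned Enum can then be passed straight to a model builder; a sketch:

```python
from torchvision.prototype.models import get_weight, resnet50

# Useful when the weight name comes from a config file or a CLI flag
weights = get_weight("ResNet50_Weights.ImageNet1K_V2")
model = resnet50(weights=weights)
```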
In the new API, the boolean `pretrained` and `pretrained_backbone` parameters, which were previously used to load weights to the full model or to its backbone, are deprecated. The current implementation is fully backwards compatible, as it seamlessly maps the old parameters to the new ones. Passing the old parameters to the new builders emits the following deprecation warnings:
```python
>>> model = torchvision.prototype.models.resnet50(pretrained=True)
UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead.
UserWarning:
Arguments other than a weight enum or `None` for 'weights' are deprecated.
The current behavior is equivalent to passing `weights=ResNet50_Weights.ImageNet1K_V1`.
You can also use `weights=ResNet50_Weights.default` to get the most up-to-date weights.
```
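For comparison, a minimal sketch of the warning-free calls under the new API (the enum values come from the warning text above):

```python
from torchvision.prototype.models import resnet50, ResNet50_Weights

# New-style equivalent of the deprecated resnet50(pretrained=True)
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V1)

# New-style equivalent of the deprecated resnet50(pretrained=False)
model = resnet50(weights=None)
```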
Additionally, the builder methods require the use of keyword parameters. Passing positional parameters is deprecated and emits the following warning:
```python
>>> model = torchvision.prototype.models.resnet50(None)
UserWarning:
Using 'weights' as positional parameter(s) is deprecated.
Please use keyword parameter(s) instead.
```
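The fix is simply to pass the argument by keyword:

```python
from torchvision.prototype.models import resnet50

# Keyword form; emits no deprecation warning
model = resnet50(weights=None)
```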
For alternative ways to install the nightly, have a look at the PyTorch download page.
## Accessing state-of-the-art model weights with the new API
If you are still unconvinced about giving the new API a try, here is one more reason to do so. We’ve recently refreshed our [training recipe](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/) and achieved SOTA accuracy for many of our models. The improved weights can easily be accessed via the new API (see the sketch after the table below). Here is a quick overview of the model improvements:
|Model|Old Acc@1|New Acc@1|
|---|---|---|
|Wide ResNet50 2|78.468|81.602|
|Wide ResNet101 2|78.848|82.51|
* At the time of writing, the RegNet refresh work is in progress, see [PR 5107](https://github.com/pytorch/vision/pull/5107).
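Loading one of the refreshed weights is a one-liner under the new API; a sketch using the `ImageNet1K_V2` enum shown earlier:

```python
from torchvision.prototype.models import resnet50, ResNet50_Weights

# The V2 weights correspond to the refreshed, higher-accuracy training recipe
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V2)
```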
Please spare a few minutes to provide your feedback on the new API, as this is crucial for graduating it from prototype and including it in the next release. You can do this on the dedicated [GitHub issue](https://github.com/pytorch/vision/issues/5088). We are looking forward to reading your comments!