Commit 36b5590 (1 parent: 88649f3)
Update 2021-12-22-introducing-torchvision-new-multi-weight-support-api.md

1 file changed: _posts/2021-12-22-introducing-torchvision-new-multi-weight-support-api.md (+29 −32 lines)
@@ -17,7 +17,7 @@ We are hoping to get your thoughts about the API prior finalizing it. To collect
 
 TorchVision currently provides pre-trained models which can serve as a starting point for transfer learning or be used as-is in Computer Vision applications. The typical way to instantiate a pre-trained model and make a prediction is:
 
-```
+```python
 import torch
 
 from PIL import Image
@@ -50,7 +50,7 @@ score = prediction[class_id].item()
 with open("imagenet_classes.txt", "r") as f:
     categories = [s.strip() for s in f.readlines()]
 category_name = categories[class_id]
-print(f"{category_name}: {100 * score}%")
+print(f"{category_name}: {100 * score}%")
 
 ```
 
@@ -66,7 +66,7 @@ The new API addresses the above limitations and reduces the amount of boilerplat
 
 Let’s see how we can achieve exactly the same results as above using the new API:
 
-```
+```python
 from PIL import Image
 from torchvision.prototype import models as PM
 
@@ -89,8 +89,7 @@ prediction = model(batch).squeeze(0).softmax(0)
 class_id = prediction.argmax().item()
 score = prediction[class_id].item()
 category_name = weights.meta["categories"][class_id]
-print(f"{category_name}: {100 * score}%")
-
+print(f"{category_name}: {100 * score}%")
 ```
 
 As we can see, the new API eliminates the aforementioned limitations. Let’s explore the new features in detail.
@@ -101,7 +100,7 @@ At the heart of the new API, we have the ability to define multiple different we
 
 Here is an example of initializing models with different weights:
 
-```
+```python
 from torchvision.prototype.models import resnet50, ResNet50_Weights
 
 # Legacy weights with accuracy 76.130%
@@ -115,14 +114,13 @@ model = resnet50(weights=ResNet50_Weights.default)
 
 # No weights - random initialization
 model = resnet50(weights=None)
-
 ```
 
 ### Associated meta-data & preprocessing transforms
 
 The weights of each model are associated with meta-data. The type of information we store depends on the task of the model (Classification, Detection, Segmentation etc.). Typical information includes a link to the training recipe, the interpolation mode, the categories and the validation metrics. These values are programmatically accessible via the `meta` attribute:
 
-```
+```python
 from torchvision.prototype.models import ResNet50_Weights
 
 # Accessing a single record
@@ -131,12 +129,11 @@ size = ResNet50_Weights.ImageNet1K_V2.meta["size"]
 # Iterating the items of the meta-data dictionary
 for k, v in ResNet50_Weights.ImageNet1K_V2.meta.items():
     print(k, v)
-
 ```
 
 Additionally, each weights entry is associated with the necessary preprocessing transforms. All current preprocessing transforms are JIT-scriptable and can be accessed via the `transforms` attribute. Prior to using them with the data, the transforms need to be initialized/constructed. This lazy initialization scheme is used to ensure the solution is memory efficient. The input of the transforms can be either a `PIL.Image` or a `Tensor` read using `torchvision.io`.
 
-```
+```python
 from torchvision.prototype.models import ResNet50_Weights
 
 # Initializing preprocessing at standard 224x224 resolution
@@ -147,7 +144,6 @@ preprocess = ResNet50_Weights.ImageNet1K.transforms(crop_size=400, resize_size=4
 
 # Once initialized the callable can accept the image data:
 # img_preprocessed = preprocess(img)
-
 ```
 
 Associating the weights with their meta-data and preprocessing will boost transparency, improve reproducibility and make it easier to document how a set of weights was produced.
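The lazy construction described in this hunk can be mimicked with `functools.partial`. The preset class below is hypothetical (torchvision's real presets perform actual image transforms); the point is only that the weights entry stores an uninitialized factory, and the caller decides when, and with which parameters, to build it:

```python
from functools import partial


class ImageClassificationPreset:
    # Stand-in for a real preprocessing preset; constructing an instance is
    # the "initialization" step, so nothing heavy is allocated beforehand.
    def __init__(self, crop_size=224, resize_size=256):
        self.crop_size = crop_size
        self.resize_size = resize_size

    def __call__(self, img):
        # A real preset would resize, center-crop and normalize here.
        return (img, self.resize_size, self.crop_size)


# The weights entry stores only an uninitialized factory:
transforms = partial(ImageClassificationPreset)

# Standard construction, mirroring `weights.transforms()`:
preprocess = transforms()
# Overriding parameters, mirroring `transforms(crop_size=400, resize_size=400)`:
preprocess_hires = transforms(crop_size=400, resize_size=400)

print(preprocess("img"))        # ('img', 256, 224)
print(preprocess_hires("img"))  # ('img', 400, 400)
```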
@@ -156,7 +152,7 @@ Associating the weights with their meta-data and preprocessing will boost transp
 
 The ability to link the weights directly with their properties (meta-data, preprocessing callables etc.) is the reason why our implementation uses Enums instead of Strings. Nevertheless, for cases where only the name of the weights is available, we offer a method capable of linking Weight names to their Enums:
 
-```
+```python
 from torchvision.prototype.models import get_weight
 
 # Weights can be retrieved by name:
@@ -165,14 +161,13 @@ assert get_weight("ResNet50_Weights.ImageNet1K_V2") == ResNet50_Weights.ImageNet
 
 # Including using the default alias:
 assert get_weight("ResNet50_Weights.default") == ResNet50_Weights.ImageNet1K_V2
-
 ```
 
 ## Deprecations
 
 In the new API the boolean `pretrained` and `pretrained_backbone` parameters, which were previously used to load weights to the full model or to its backbone, are deprecated. The current implementation is fully backwards compatible, as it seamlessly maps the old parameters to the new ones. Using the old parameters with the new builders emits the following deprecation warnings:
 
-```
+```python
 >>> model = torchvision.prototype.models.resnet50(pretrained=True)
 UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead.
 UserWarning:
@@ -183,7 +178,7 @@ You can also use `weights=ResNet50_Weights.default` to get the most up-to-date w
 
 Additionally, the builder methods require using keyword parameters. The use of positional parameters is deprecated and emits the following warning:
 
-```
+```python
 >>> model = torchvision.prototype.models.resnet50(None)
 UserWarning:
 Using 'weights' as positional parameter(s) is deprecated.
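The backwards-compatibility behaviour documented in these warnings can be sketched in a few lines. This is a simplified stand-in, not torchvision's implementation: the builder takes a keyword-only `weights` argument, intercepts the legacy `pretrained` flag via `**kwargs`, warns, and maps it onto the new parameter:

```python
import warnings

DEFAULT_WEIGHTS = "ImageNet1K_V2"  # stand-in for ResNet50_Weights.default


def resnet50(*, weights=None, **kwargs):
    # Keyword-only signature; this sketch rejects positional use outright,
    # whereas the real builders currently only warn about it.
    if "pretrained" in kwargs:
        warnings.warn(
            "The parameter 'pretrained' is deprecated, please use 'weights' instead.",
            DeprecationWarning,
        )
        # Seamlessly map the legacy boolean onto the new parameter.
        weights = DEFAULT_WEIGHTS if kwargs.pop("pretrained") else None
    return {"weights": weights}  # stand-in for the constructed model


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = resnet50(pretrained=True)

print(model["weights"])  # ImageNet1K_V2
print(str(caught[0].message))
```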
@@ -217,27 +212,31 @@ For alternative ways to install the nightly have a look on the PyTorch [download
 ## Accessing state-of-the-art model weights with the new API
 
 If you are still unconvinced about giving the new API a try, here is one more reason to do so. We’ve recently refreshed our [training recipe](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/) and achieved SOTA accuracy for many of our models. The improved weights can easily be accessed via the new API. Here is a quick overview of the model improvements:
-[Image: chart.png]
+<div class="text-center">
+<img src="{{ site.baseurl }}/assets/images/torchvision_chart1.png" width="100%">
+</div>
 
 |Model |Old Acc@1 |New Acc@1 |
 |--- |--- |--- |
 |EfficientNet B1 |78.642 |79.838 |
 |MobileNetV3 Large |74.042 |75.274 |
 |Quantized ResNet50 |75.92 |80.282 |
 |Quantized ResNeXt101 32x8d |78.986 |82.574 |
-|RegNet X 400mf * |72.834 |74.864 |
-|RegNet X 800mf * |75.212 |77.522 |
-|RegNet X 1 6gf * |77.04 |79.668 |
-|RegNet X 3 2gf * |78.364 |81.198 |
-|RegNet X 8gf * |79.344 |81.682 |
-|RegNet X 16gf * |80.058 |82.72 |
-|RegNet X 32gf * |80.622 |83.018 |
-|RegNet Y 400mf * |74.046 |75.806 |
-|RegNet Y 800mf * |76.42 |78.838 |
-|RegNet Y 1 6gf * |77.95 |80.882 |
-|RegNet Y 3 2gf * |78.948 |81.984 |
-|RegNet Y 8gf * |80.032 |82.828 |
-|RegNet Y 16gf * |80.424 |82.89 |
-|RegNet Y 32gf * |80.878 |83.366 |
+|RegNet X 400mf |72.834 |74.864 |
+|RegNet X 800mf |75.212 |77.522 |
+|RegNet X 1 6gf |77.04 |79.668 |
+|RegNet X 3 2gf |78.364 |81.198 |
+|RegNet X 8gf |79.344 |81.682 |
+|RegNet X 16gf |80.058 |82.72 |
+|RegNet X 32gf |80.622 |83.018 |
+|RegNet Y 400mf |74.046 |75.806 |
+|RegNet Y 800mf |76.42 |78.838 |
+|RegNet Y 1 6gf |77.95 |80.882 |
+|RegNet Y 3 2gf |78.948 |81.984 |
+|RegNet Y 8gf |80.032 |82.828 |
+|RegNet Y 16gf |80.424 |82.89 |
+|RegNet Y 32gf |80.878 |83.366 |
 |ResNet50 |76.13 |80.674 |
 |ResNet101 |77.374 |81.886 |
 |ResNet152 |78.312 |82.284 |
@@ -246,6 +245,4 @@ If you are still unconvinced about giving a try to the new API, here is one more
 |Wide ResNet50 2 |78.468 |81.602 |
 |Wide ResNet101 2 |78.848 |82.51 |
 
-* At the time of writing, the RegNet refresh work is in progress, see [PR 5107](https://github.com/pytorch/vision/pull/5107).
-
 Please spare a few minutes to provide your feedback on the new API, as this is crucial for graduating it from prototype and including it in the next release. You can do this on the dedicated [Github Issue](https://github.com/pytorch/vision/issues/5088). We are looking forward to reading your comments!
