Commit 474ae0e

Update 2021-10-28-FX-feature-extraction-torchvision.md
1 parent b2ff918 commit 474ae0e


_posts/2021-10-28-FX-feature-extraction-torchvision.md

Lines changed: 101 additions & 81 deletions
````diff
@@ -48,54 +48,63 @@ To illustrate these, let’s consider a simple convolutional neural network that
 ```python
 import torch
 from torch import nn
+
+
 class ConvBlock(nn.Module):
-    """
-    Applies `num_layers` 3x3 convolutions each followed by ReLU then downsamples
-    via 2x2 max pool.
-    """
-    def __init__(self, num_layers, in_channels, out_channels):
-        super().__init__()
-        self.convs = nn.ModuleList(
-            [nn.Sequential(
-                nn.Conv2d(in_channels if i==0 else out_channels, out_channels, 3, padding=1),
-                nn.ReLU()
-            )
-            for i in range(num_layers)]
-        )
-        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
-
-    def forward(self, x):
-        for conv in self.convs:
-            x = conv(x)
-        x = self.downsample(x)
-        return x
-
+    """
+    Applies `num_layers` 3x3 convolutions each followed by ReLU then downsamples
+    via 2x2 max pool.
+    """
+
+    def __init__(self, num_layers, in_channels, out_channels):
+        super().__init__()
+        self.convs = nn.ModuleList(
+            [nn.Sequential(
+                nn.Conv2d(in_channels if i==0 else out_channels, out_channels, 3, padding=1),
+                nn.ReLU()
+            )
+            for i in range(num_layers)]
+        )
+        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
+
+    def forward(self, x):
+        for conv in self.convs:
+            x = conv(x)
+        x = self.downsample(x)
+        return x
+
+
 class CNN(nn.Module):
-    """
-    Applies several ConvBlocks each doubling the number of channels, and
-    halving the feature map size, before taking a global average and classifying.
-    """
-    def __init__(self, in_channels, num_blocks, num_classes):
-        super().__init__()
-        first_channels = 64
-        self.blocks = nn.ModuleList(
-            [ConvBlock(
-                2 if i==0 else 3,
-                in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
-                out_channels=first_channels*(2**i))
-            for i in range(num_blocks)]
-        )
-        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
-        self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)
-    def forward(self, x):
-        for block in self.blocks:
-            x = block(x)
-        x = self.global_pool(x)
-        x = x.flatten(1)
-        x = self.cls(x)
-        return x
+    """
+    Applies several ConvBlocks each doubling the number of channels, and
+    halving the feature map size, before taking a global average and classifying.
+    """
+
+    def __init__(self, in_channels, num_blocks, num_classes):
+        super().__init__()
+        first_channels = 64
+        self.blocks = nn.ModuleList(
+            [ConvBlock(
+                2 if i==0 else 3,
+                in_channels=(in_channels if i == 0 else first_channels*(2**(i-1))),
+                out_channels=first_channels*(2**i))
+            for i in range(num_blocks)]
+        )
+        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.cls = nn.Linear(first_channels*(2**(num_blocks-1)), num_classes)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.global_pool(x)
+        x = x.flatten(1)
+        x = self.cls(x)
+        return x
+
+
 model = CNN(3, 4, 10)
 out = model(torch.zeros(1, 3, 32, 32)) # This will be the final logits over classes
+
 ```
````
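
As a quick sanity check of the model above (a sketch of mine, not part of this commit): the four blocks each halve the 32x32 input and double the channel count from 64, leaving a 512-channel 2x2 map before pooling, so the classifier emits 10 logits.

```python
model = CNN(3, 4, 10)
out = model(torch.zeros(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 10])
```
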
Let’s say we want to get the final feature map before global average pooling. We could do the following:
````diff
@@ -104,26 +113,26 @@ Let’s say we want to get the final feature map before global average pooling.
 ```python
 def forward(self, x):
-    for block in self.blocks:
-        x = block(x)
-    self.final_feature_map = x
-    x = self.global_pool(x)
-    x = x.flatten(1)
-    x = self.cls(x)
-    return x
+    for block in self.blocks:
+        x = block(x)
+    self.final_feature_map = x
+    x = self.global_pool(x)
+    x = x.flatten(1)
+    x = self.cls(x)
+    return x
 ```
````

Or return it directly:

````diff
 ```python
 def forward(self, x):
-    for block in self.blocks:
-        x = block(x)
-    self.final_feature_map = x
-    x = self.global_pool(x)
-    x = x.flatten(1)
-    x = self.cls(x)
-    return x
+    for block in self.blocks:
+        x = block(x)
+    final_feature_map = x
+    x = self.global_pool(x)
+    x = x.flatten(1)
+    x = self.cls(x)
+    return x, final_feature_map
 ```
````

That looks pretty easy. But there are some downsides here, which all stem from the same underlying issue: modifying the source code is not ideal:

````diff
@@ -140,15 +149,18 @@ Following on the example from above, say we want to get a feature map from each
 ```python
 class CNNFeatures(nn.Module):
-    def __init__(self, backbone):
-        super().__init__()
-        self.blocks = backbone.blocks
-    def forward(self, x):
-        feature_maps = []
-        for block in self.blocks:
-            x = block(x)
-            feature_maps.append(x)
-        return feature_maps
+    def __init__(self, backbone):
+        super().__init__()
+        self.blocks = backbone.blocks
+
+    def forward(self, x):
+        feature_maps = []
+        for block in self.blocks:
+            x = block(x)
+            feature_maps.append(x)
+        return feature_maps
+
+
 backbone = CNN(3, 4, 10)
 model = CNNFeatures(backbone)
 out = model(torch.zeros(1, 3, 32, 32)) # This is now a list of Tensors, each representing a feature map
````
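
Continuing the sanity check (again a sketch, not part of this commit): each ConvBlock halves the spatial size while doubling the channels, so the four returned feature maps step from 64x16x16 up to 512x2x2.

```python
for fmap in out:
    print(fmap.shape)
# torch.Size([1, 64, 16, 16])
# torch.Size([1, 128, 8, 8])
# torch.Size([1, 256, 4, 4])
# torch.Size([1, 512, 2, 2])
```
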
````diff
@@ -171,10 +183,13 @@ Hooks move us away from the paradigm of writing source code, towards one of spec
 ```python
 model = CNN(3, 4, 10)
 feature_maps = [] # This will be a list of Tensors, each representing a feature map
+
 def hook_feat_map(mod, inp, out):
     feature_maps.append(out)
+
 for block in model.blocks:
     block.register_forward_hook(hook_feat_map)
+
 out = model(torch.zeros(1, 3, 32, 32)) # This will be the final logits over classes
 ```
````
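
For comparison, the FX-based API this post builds towards can grab the same feature maps without rewriting the model or registering hooks. A minimal sketch, not part of this commit; the node names 'blocks.0' through 'blocks.3' are assumptions based on torchvision's qualified-name convention and can be verified with get_graph_node_names:

```python
from torchvision.models.feature_extraction import create_feature_extractor

model = CNN(3, 4, 10)
# Map assumed node names to friendlier output keys
return_nodes = {f'blocks.{i}': f'feat{i}' for i in range(4)}
feature_extractor = create_feature_extractor(model, return_nodes=return_nodes)
feature_maps = feature_extractor(torch.zeros(1, 3, 32, 32))  # dict of name -> Tensor
```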

````diff
@@ -207,12 +222,13 @@ The natural question for some new-starters in Python and coding at this point mi
 ```python
 class MyModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.param = torch.nn.Parameter(torch.rand(3, 4))
-        self.submodule = MySubModule()
-    def forward(self, x):
-        return self.submodule(x + self.param).clamp(min=0.0, max=1.0)
+    def __init__(self):
+        super().__init__()
+        self.param = torch.nn.Parameter(torch.rand(3, 4))
+        self.submodule = MySubModule()
+
+    def forward(self, x):
+        return self.submodule(x + self.param).clamp(min=0.0, max=1.0)
 ```

The forward method has a single line of code which we can unravel as:

````diff
@@ -337,10 +353,12 @@ We could do something similar with functions. For example, Python’s inbuilt `l
 ```python
 torch.fx.wrap('len')
+
 class MyModule(nn.Module):
-    def forward(self, x):
-        x += 1
-        len(x)
+    def forward(self, x):
+        x += 1
+        len(x)
+
 model = MyModule()
 feature_extractor = create_feature_extractor(model, return_nodes=['add'])
 ```
````
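
If it's unclear which node names (such as 'add' above) are available to request, torchvision provides a companion helper for listing them. A minimal sketch, not part of this commit:

```python
from torchvision.models.feature_extraction import get_graph_node_names

# One list of node names per mode, since control flow may differ between them.
train_nodes, eval_nodes = get_graph_node_names(MyModule())
print(train_nodes)  # expect names like 'x', 'add' and 'len'
```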

````diff
@@ -350,14 +368,16 @@ For functions you define, you may instead use another keyword argument to `creat
 ```python
 def myfunc(x):
-    return len(x)
+    return len(x)
+
 class MyModule(nn.Module):
-    def forward(self, x):
-        x += 1
-        myfunc(x)
+    def forward(self, x):
+        x += 1
+        myfunc(x)
+
 model = MyModule()
 feature_extractor = create_feature_extractor(
-    model, return_nodes=['add'], tracer_kwargs={'autowrap_functions': [myfunc]})
+    model, return_nodes=['add'], tracer_kwargs={'autowrap_functions': [myfunc]})
 ```
````

Notice that none of the fixes above involved modifying source code.
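
To close the loop (a sketch, not part of this commit): with return_nodes=['add'], the extractor returns a dict keyed by node name, so the intermediate value produced by `x += 1` can be read off directly.

```python
out = feature_extractor(torch.rand(3))
print(out['add'])  # the Tensor produced by `x += 1` inside forward
```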
