You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
out = model(torch.zeros(1, 3, 32, 32)) # This will be the final logits over classes
107
+
99
108
```
100
109
101
110
Let’s say we want to get the final feature map before global average pooling. We could do the following:
@@ -104,26 +113,26 @@ Let’s say we want to get the final feature map before global average pooling.
104
113
105
114
```python
106
115
defforward(self, x):
107
-
for block inself.blocks:
108
-
x = block(x)
109
-
self.final_feature_map = x
110
-
x =self.global_pool(x)
111
-
x = x.flatten(1)
112
-
x =self.cls(x)
113
-
return x
116
+
for block inself.blocks:
117
+
x = block(x)
118
+
self.final_feature_map = x
119
+
x =self.global_pool(x)
120
+
x = x.flatten(1)
121
+
x =self.cls(x)
122
+
return x
114
123
```
115
124
116
125
Or return it directly:
117
126
118
127
```python
119
128
defforward(self, x):
120
-
for block inself.blocks:
121
-
x = block(x)
122
-
self.final_feature_map = x
123
-
x =self.global_pool(x)
124
-
x = x.flatten(1)
125
-
x =self.cls(x)
126
-
return x
129
+
for block inself.blocks:
130
+
x = block(x)
131
+
final_feature_map = x
132
+
x =self.global_pool(x)
133
+
x = x.flatten(1)
134
+
x =self.cls(x)
135
+
return x, final_feature_map
127
136
```
128
137
That looks pretty easy. But there are some downsides here which all stem from the same underlying issue: that is, modifying the source code is not ideal:
129
138
@@ -140,15 +149,18 @@ Following on the example from above, say we want to get a feature map from each
140
149
141
150
```python
142
151
classCNNFeatures(nn.Module):
143
-
def__init__(self, backbone):
144
-
super().__init__()
145
-
self.blocks = backbone.blocks
146
-
defforward(self, x):
147
-
feature_maps = []
148
-
for block inself.blocks:
149
-
x = block(x)
150
-
feature_maps.append(x)
151
-
return feature_maps
152
+
def__init__(self, backbone):
153
+
super().__init__()
154
+
self.blocks = backbone.blocks
155
+
156
+
defforward(self, x):
157
+
feature_maps = []
158
+
for block inself.blocks:
159
+
x = block(x)
160
+
feature_maps.append(x)
161
+
return feature_maps
162
+
163
+
152
164
backbone = CNN(3, 4, 10)
153
165
model = CNNFeatures(backbone)
154
166
out = model(torch.zeros(1, 3, 32, 32)) # This is now a list of Tensors, each representing a feature map
@@ -171,10 +183,13 @@ Hooks move us away from the paradigm of writing source code, towards one of spec
171
183
```python
172
184
model = CNN(3, 4, 10)
173
185
feature_maps = [] # This will be a list of Tensors, each representing a feature map
186
+
174
187
defhook_feat_map(mod, inp, out):
175
188
feature_maps.append(out)
189
+
176
190
for block in model.blocks:
177
191
block.register_forward_hook(hook_feat_map)
192
+
178
193
out = model(torch.zeros(1, 3, 32, 32)) # This will be the final logits over classes
179
194
```
180
195
@@ -207,12 +222,13 @@ The natural question for some new-starters in Python and coding at this point mi
0 commit comments