Skip to content

Commit b2ff918

Browse files
authored
Update 2021-10-28-FX-feature-extraction-torchvision.md
1 parent 8b4ae7a commit b2ff918

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

_posts/2021-10-28-FX-feature-extraction-torchvision.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
---
22
layout: blog_detail
3-
title: 'FX based Feature Extraction in TorchVision'
3+
title: 'Feature Extraction in TorchVision using Torch FX'
44
author: Alexander Soare and Francisco Massa
55
featured-img: 'assets/images/fx-image2.png'
66
---
@@ -98,7 +98,7 @@ model = CNN(3, 4, 10)
9898
out = model(torch.zeros(1, 3, 32, 32)) # This will be the final logits over classes
9999
```
100100

101-
Let’s say we want to get the final feature map before global average pooling. We could
101+
Let’s say we want to get the final feature map before global average pooling. We could do the following:
102102

103103
### Modify the forward method
104104

@@ -198,7 +198,7 @@ Here’s a summary of the different methods and their pros/cons:
198198

199199
Table 1: The pros (or cons) of some of the existing methods for feature extraction with PyTorch
200200

201-
In the next section of this article, let’s see how we can get greens across the board.
201+
In the next section of this article, let’s see how we can get YES across the board.
202202

203203

204204
## FX to The Rescue
@@ -241,7 +241,7 @@ Note that we call this a graph, and not just a set of steps, because it’s poss
241241
Figure 4: Graphical representation of a residual skip connection. The middle node is like the main branch of a residual block, and the final node represents the sum of the input and output of the main branch.
242242
</p>
243243

244-
Now, TorchVision’s **get_graph_node_names** function applies FX as described above, and in the process of doing so, tags each node with a human readable name. Let’s try this with our toy CNN model from the previous section:
244+
Now, TorchVision’s **[get_graph_node_names](https://pytorch.org/vision/stable/feature_extraction.html#torchvision.models.feature_extraction.get_graph_node_names)** function applies FX as described above, and in the process of doing so, tags each node with a human readable name. Let’s try this with our toy CNN model from the previous section:
245245

246246
```python
247247
model = CNN(3, 4, 10)
@@ -297,7 +297,7 @@ Here’s that table again with another row added for FX feature extraction
297297
| FX | YES | YES | YES | YES |
298298
|-------------------------------------------------------------------|:-----------------------------------------------------------------:|:--------------------------------------------------------------------------------------:|:--------------------------------------:|:--------------------:|
299299

300-
Table 2: A copy of Table 1 with an added row for FX feature extraction. FX feature extraction gets greens across the board!
300+
Table 2: A copy of Table 1 with an added row for FX feature extraction. FX feature extraction gets YES across the board!
301301

302302

303303
## Current FX Limitations

0 commit comments

Comments
 (0)