Skip to content

Commit dcc4e7c

Browse files
MorvanZhou (Morvan Zhou)
authored and
Morvan Zhou
committed
update
1 parent 3fc6f4e commit dcc4e7c

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ You can watch my [Youtube channel](https://www.youtube.com/channel/UCdyjiB5H8Pu7
5656
<img class="course-image" src="https://morvanzhou.github.io/static/results/torch/1-1-3.gif">
5757
</a>
5858

59+
### [CNN](https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/401_CNN.py)
60+
<a href="https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/401_CNN.py">
61+
<img class="course-image" src="https://morvanzhou.github.io/static/results/torch/4-1-2.gif" >
62+
</a>
63+
5964
### [RNN](https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/403_RNN_regressor.py)
6065

6166
<a href="https://github.com/MorvanZhou/PyTorch-Tutorial/blob/master/tutorial-contents/403_RNN_regressor.py">

tutorial-contents/401_CNN.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def forward(self, x):
7474
x = self.conv2(x)
7575
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
7676
output = self.out(x)
77-
return output
77+
return output, x # return x for visualization
7878

7979

8080
cnn = CNN()
@@ -83,24 +83,53 @@ def forward(self, x):
8383
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
8484
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
8585

86+
# following function (plot_with_labels) is for visualization, can be ignored if not interested
87+
from matplotlib import cm
88+
try:
89+
from sklearn.manifold import TSNE
90+
HAS_SK = True
91+
except:
92+
HAS_SK = False
93+
print('Please install sklearn for layer visualization')
94+
def plot_with_labels(lowDWeights, labels):
95+
plt.cla()
96+
X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
97+
for x, y, s in zip(X, Y, labels):
98+
c = cm.rainbow(int(255 * s / 9))
99+
plt.text(x, y, s, backgroundcolor=c, fontsize=9)
100+
plt.xlim(X.min(), X.max())
101+
plt.ylim(Y.min(), Y.max())
102+
plt.title('Visualize last layer')
103+
plt.show()
104+
plt.pause(0.01)
105+
106+
plt.ion()
107+
86108
# training and testing
87109
for epoch in range(EPOCH):
88110
for step, (x, y) in enumerate(train_loader): # gives batch data, normalize x when iterate train_loader
89111
b_x = Variable(x) # batch x
90112
b_y = Variable(y) # batch y
91113

92-
output = cnn(b_x) # cnn output
114+
output = cnn(b_x)[0] # cnn output
93115
loss = loss_func(output, b_y) # cross entropy loss
94116
optimizer.zero_grad() # clear gradients for this training step
95117
loss.backward() # backpropagation, compute gradients
96118
optimizer.step() # apply gradients
97119

98120
if step % 50 == 0:
99-
test_output = cnn(test_x)
121+
test_output, last_layer = cnn(test_x)
100122
pred_y = torch.max(test_output, 1)[1].data.squeeze()
101123
accuracy = sum(pred_y == test_y) / float(test_y.size(0))
102124
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data[0], '| test accuracy: %.2f' % accuracy)
103-
125+
if HAS_SK:
126+
# Visualization of trained flatten layer (T-SNE)
127+
tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
128+
plot_only = 500
129+
low_dim_embs = tsne.fit_transform(last_layer.data.numpy()[:plot_only, :])
130+
labels = test_y.numpy()[:plot_only]
131+
plot_with_labels(low_dim_embs, labels)
132+
plt.ioff()
104133

105134
# print 10 predictions from test data
106135
test_output = cnn(test_x[:10])

0 commit comments

Comments (0)