@@ -74,7 +74,7 @@ def forward(self, x):
74
74
x = self .conv2 (x )
75
75
x = x .view (x .size (0 ), - 1 ) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
76
76
output = self .out (x )
77
- return output
77
+ return output , x # return x for visualization
78
78
79
79
80
80
cnn = CNN ()
@@ -83,24 +83,53 @@ def forward(self, x):
83
83
optimizer = torch .optim .Adam (cnn .parameters (), lr = LR ) # optimize all cnn parameters
84
84
loss_func = nn .CrossEntropyLoss () # the target label is not one-hotted
85
85
86
+ # following function (plot_with_labels) is for visualization, can be ignored if not interested
87
+ from matplotlib import cm
88
+ try :
89
+ from sklearn .manifold import TSNE
90
+ HAS_SK = True
91
+ except :
92
+ HAS_SK = False
93
+ print ('Please install sklearn for layer visualization' )
94
+ def plot_with_labels (lowDWeights , labels ):
95
+ plt .cla ()
96
+ X , Y = lowDWeights [:, 0 ], lowDWeights [:, 1 ]
97
+ for x , y , s in zip (X , Y , labels ):
98
+ c = cm .rainbow (int (255 * s / 9 ))
99
+ plt .text (x , y , s , backgroundcolor = c , fontsize = 9 )
100
+ plt .xlim (X .min (), X .max ())
101
+ plt .ylim (Y .min (), Y .max ())
102
+ plt .title ('Visualize last layer' )
103
+ plt .show ()
104
+ plt .pause (0.01 )
105
+
106
+ plt .ion ()
107
+
86
108
# training and testing
87
109
for epoch in range (EPOCH ):
88
110
for step , (x , y ) in enumerate (train_loader ): # gives batch data, normalize x when iterate train_loader
89
111
b_x = Variable (x ) # batch x
90
112
b_y = Variable (y ) # batch y
91
113
92
- output = cnn (b_x ) # cnn output
114
+ output = cnn (b_x )[ 0 ] # cnn output
93
115
loss = loss_func (output , b_y ) # cross entropy loss
94
116
optimizer .zero_grad () # clear gradients for this training step
95
117
loss .backward () # backpropagation, compute gradients
96
118
optimizer .step () # apply gradients
97
119
98
120
if step % 50 == 0 :
99
- test_output = cnn (test_x )
121
+ test_output , last_layer = cnn (test_x )
100
122
pred_y = torch .max (test_output , 1 )[1 ].data .squeeze ()
101
123
accuracy = sum (pred_y == test_y ) / float (test_y .size (0 ))
102
124
print ('Epoch: ' , epoch , '| train loss: %.4f' % loss .data [0 ], '| test accuracy: %.2f' % accuracy )
103
-
125
+ if HAS_SK :
126
+ # Visualization of trained flatten layer (T-SNE)
127
+ tsne = TSNE (perplexity = 30 , n_components = 2 , init = 'pca' , n_iter = 5000 )
128
+ plot_only = 500
129
+ low_dim_embs = tsne .fit_transform (last_layer .data .numpy ()[:plot_only , :])
130
+ labels = test_y .numpy ()[:plot_only ]
131
+ plot_with_labels (low_dim_embs , labels )
132
+ plt .ioff ()
104
133
105
134
# print 10 predictions from test data
106
135
test_output = cnn (test_x [:10 ])
0 commit comments