Skip to content

Commit f2ebd8a

Browse files
committed
update accuracy function
1 parent 79380fb commit f2ebd8a

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

tutorial-contents/302_classification.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def forward(self, x):
6262
pred_y = prediction.data.numpy().squeeze()
6363
target_y = y.data.numpy()
6464
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
65-
accuracy = sum(pred_y == target_y)/200.
65+
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
6666
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
6767
plt.pause(0.1)
6868

tutorial-contents/401_CNN.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def plot_with_labels(lowDWeights, labels):
115115

116116
if step % 50 == 0:
117117
test_output, last_layer = cnn(test_x)
118-
pred_y = torch.max(test_output, 1)[1].data.squeeze()
119-
accuracy = float(sum(pred_y == test_y)) / float(test_y.size(0))
118+
pred_y = torch.max(test_output, 1)[1].data.squeeze().numpy()
119+
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
120120
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
121121
if HAS_SK:
122122
# Visualization of trained flatten layer (T-SNE)

tutorial-contents/402_RNN_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def forward(self, x):
9595
if step % 50 == 0:
9696
test_output = rnn(test_x) # (samples, time_step, input_size)
9797
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
98-
accuracy = float(sum(pred_y == test_y)) / float(test_y.size)
98+
accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
9999
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
100100

101101
# print 10 predictions from test data

0 commit comments

Comments
 (0)