Skip to content

Commit 7411f60

Browse files
402 - remove squeeze for 1-D array
1 parent 40bcdb5 commit 7411f60

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tutorial-contents/402_RNN_classifier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
# convert test data into Variable, pick 2000 samples to speed up testing
4848
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
4949
test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)
50-
test_y = test_data.test_labels.numpy().squeeze()[:2000] # covert to numpy array
50+
test_y = test_data.test_labels.numpy()[:2000] # covert to numpy array
5151

5252

5353
class RNN(nn.Module):
@@ -94,13 +94,13 @@ def forward(self, x):
9494

9595
if step % 50 == 0:
9696
test_output = rnn(test_x) # (samples, time_step, input_size)
97-
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
97+
pred_y = torch.max(test_output, 1)[1].data.numpy()
9898
accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
9999
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
100100

101101
# print 10 predictions from test data
102102
test_output = rnn(test_x[:10].view(-1, 28, 28))
103-
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
103+
pred_y = torch.max(test_output, 1)[1].data.numpy()
104104
print(pred_y, 'prediction number')
105105
print(test_y[:10], 'real number')
106106

0 commit comments

Comments
 (0)