File tree 1 file changed +11
-2
lines changed
1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change 7
7
torchvision
8
8
matplotlib
9
9
"""
10
+ # library
11
+ # standard library
12
+ import os
13
+
14
+ # third-party library
10
15
import torch
11
16
import torch .nn as nn
12
17
from torch .autograd import Variable
20
25
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
21
26
BATCH_SIZE = 50
22
27
LR = 0.001 # learning rate
23
- DOWNLOAD_MNIST = True # set to False if you have downloaded
28
+ DOWNLOAD_MNIST = False
24
29
25
30
26
31
# Mnist digits dataset
32
+ if not (os .path .exists ('./mnist/' )) or not os .listdir ('./mnist/' ):
33
+ # not mnist dir or mnist is empyt dir
34
+ DOWNLOAD_MNIST = True
35
+
27
36
train_data = torchvision .datasets .MNIST (
28
37
root = './mnist/' ,
29
38
train = True , # this is training data
30
39
transform = torchvision .transforms .ToTensor (), # Converts a PIL.Image or numpy.ndarray to
31
40
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
32
- download = DOWNLOAD_MNIST , # download it if you don't have it
41
+ download = DOWNLOAD_MNIST ,
33
42
)
34
43
35
44
# plot one example
You can’t perform that action at this time.
0 commit comments