Skip to content

Commit 895ad9a

Browse files
committed
mnist dataset download setting
1 parent 2ae0c0f commit 895ad9a

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

tutorial-contents/401_CNN.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
torchvision
88
matplotlib
99
"""
10+
# library
11+
# standard library
12+
import os
13+
14+
# third-party library
1015
import torch
1116
import torch.nn as nn
1217
from torch.autograd import Variable
@@ -20,16 +25,20 @@
2025
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
2126
BATCH_SIZE = 50
2227
LR = 0.001 # learning rate
23-
DOWNLOAD_MNIST = True # set to False if you have downloaded
28+
DOWNLOAD_MNIST = False
2429

2530

2631
# Mnist digits dataset
32+
if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):
33+
# not mnist dir or mnist is empyt dir
34+
DOWNLOAD_MNIST = True
35+
2736
train_data = torchvision.datasets.MNIST(
2837
root='./mnist/',
2938
train=True, # this is training data
3039
transform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
3140
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
32-
download=DOWNLOAD_MNIST, # download it if you don't have it
41+
download=DOWNLOAD_MNIST,
3342
)
3443

3544
# plot one example

0 commit comments

Comments
 (0)