2121import collections
2222import gzip
2323import os
24+ import urllib
2425
2526import numpy
26- from six.moves import urllib
27- from six.moves import xrange # pylint: disable=redefined-builtin
28-
29- from tensorflow.python.framework import dtypes
30- from tensorflow.python.framework import random_seed
27+ from tensorflow .python .framework import dtypes , random_seed
3128from tensorflow .python .platform import gfile
3229from tensorflow .python .util .deprecation import deprecated
3330
@@ -46,16 +43,16 @@ def _read32(bytestream):
4643def _extract_images (f ):
4744 """Extract the images into a 4D uint8 numpy array [index, y, x, depth].
4845
49- Args:
50- f: A file object that can be passed into a gzip reader.
46+ Args:
47+ f: A file object that can be passed into a gzip reader.
5148
52- Returns:
53- data: A 4D uint8 numpy array [index, y, x, depth].
49+ Returns:
50+ data: A 4D uint8 numpy array [index, y, x, depth].
5451
55- Raises:
56- ValueError: If the bytestream does not start with 2051.
52+ Raises:
53+ ValueError: If the bytestream does not start with 2051.
5754
58- """
55+ """
5956 print ("Extracting" , f .name )
6057 with gzip .GzipFile (fileobj = f ) as bytestream :
6158 magic = _read32 (bytestream )
@@ -86,17 +83,17 @@ def _dense_to_one_hot(labels_dense, num_classes):
8683def _extract_labels (f , one_hot = False , num_classes = 10 ):
8784 """Extract the labels into a 1D uint8 numpy array [index].
8885
89- Args:
90- f: A file object that can be passed into a gzip reader.
91- one_hot: Does one hot encoding for the result.
92- num_classes: Number of classes for the one hot encoding.
86+ Args:
87+ f: A file object that can be passed into a gzip reader.
88+ one_hot: Does one hot encoding for the result.
89+ num_classes: Number of classes for the one hot encoding.
9390
94- Returns:
95- labels: a 1D uint8 numpy array.
91+ Returns:
92+ labels: a 1D uint8 numpy array.
9693
97- Raises:
98- ValueError: If the bystream doesn't start with 2049.
99- """
94+ Raises:
95+ ValueError: If the bystream doesn't start with 2049.
96+ """
10097 print ("Extracting" , f .name )
10198 with gzip .GzipFile (fileobj = f ) as bytestream :
10299 magic = _read32 (bytestream )
@@ -115,8 +112,8 @@ def _extract_labels(f, one_hot=False, num_classes=10):
115112class _DataSet :
116113 """Container class for a _DataSet (deprecated).
117114
118- THIS CLASS IS DEPRECATED.
119- """
115+ THIS CLASS IS DEPRECATED.
116+ """
120117
121118 @deprecated (
122119 None ,
@@ -135,21 +132,21 @@ def __init__(
135132 ):
136133 """Construct a _DataSet.
137134
138- one_hot arg is used only if fake_data is true. `dtype` can be either
139- `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
140- `[0, 1]`. Seed arg provides for convenient deterministic testing.
141-
142- Args:
143- images: The images
144- labels: The labels
145- fake_data: Ignore inages and labels, use fake data.
146- one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
147- False).
148- dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
149- range [0,255]. float32 output has range [0,1].
150- reshape: Bool. If True returned images are returned flattened to vectors.
151- seed: The random seed to use.
152- """
135+ one_hot arg is used only if fake_data is true. `dtype` can be either
136+ `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
137+ `[0, 1]`. Seed arg provides for convenient deterministic testing.
138+
139+ Args:
140+ images: The images
141+ labels: The labels
142+ fake_data: Ignore inages and labels, use fake data.
143+ one_hot: Bool, return the labels as one hot vectors (if True) or ints (if
144+ False).
145+ dtype: Output image dtype. One of [uint8, float32]. `uint8` output has
146+ range [0,255]. float32 output has range [0,1].
147+ reshape: Bool. If True returned images are returned flattened to vectors.
148+ seed: The random seed to use.
149+ """
153150 seed1 , seed2 = random_seed .get_seed (seed )
154151 # If op level seed is not set, use whatever graph level seed is returned
155152 numpy .random .seed (seed1 if seed is None else seed2 )
@@ -206,8 +203,8 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
206203 else :
207204 fake_label = 0
208205 return (
209- [fake_image for _ in xrange (batch_size)],
210- [fake_label for _ in xrange (batch_size)],
206+ [fake_image for _ in range (batch_size )],
207+ [fake_label for _ in range (batch_size )],
211208 )
212209 start = self ._index_in_epoch
213210 # Shuffle for the first epoch
@@ -250,19 +247,19 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
250247def _maybe_download (filename , work_directory , source_url ):
251248 """Download the data from source url, unless it's already here.
252249
253- Args:
254- filename: string, name of the file in the directory.
255- work_directory: string, path to working directory.
256- source_url: url to download from if file doesn't exist.
250+ Args:
251+ filename: string, name of the file in the directory.
252+ work_directory: string, path to working directory.
253+ source_url: url to download from if file doesn't exist.
257254
258- Returns:
259- Path to resulting file.
260- """
255+ Returns:
256+ Path to resulting file.
257+ """
261258 if not gfile .Exists (work_directory ):
262259 gfile .MakeDirs (work_directory )
263260 filepath = os .path .join (work_directory , filename )
264261 if not gfile .Exists (filepath ):
265- urllib.request.urlretrieve(source_url, filepath)
262+ urllib .request .urlretrieve (source_url , filepath ) # noqa: S310
266263 with gfile .GFile (filepath ) as f :
267264 size = f .size ()
268265 print ("Successfully downloaded" , filename , size , "bytes." )
@@ -328,15 +325,16 @@ def fake():
328325
329326 if not 0 <= validation_size <= len (train_images ):
330327 raise ValueError (
331- f"Validation size should be between 0 and {len(train_images)}. Received: {validation_size}."
328+ f"Validation size should be between 0 and { len (train_images )} . "
329+ f"Received: { validation_size } ."
332330 )
333331
334332 validation_images = train_images [:validation_size ]
335333 validation_labels = train_labels [:validation_size ]
336334 train_images = train_images [validation_size :]
337335 train_labels = train_labels [validation_size :]
338336
339- options = dict( dtype= dtype, reshape= reshape, seed= seed)
337+ options = { " dtype" : dtype , " reshape" : reshape , " seed" : seed }
340338
341339 train = _DataSet (train_images , train_labels , ** options )
342340 validation = _DataSet (validation_images , validation_labels , ** options )
0 commit comments