@@ -108,45 +108,38 @@ def parse_record(raw_record):
108
108
# Convert bytes to a vector of uint8 that is record_bytes long.
109
109
record_vector = tf .decode_raw (raw_record , tf .uint8 )
110
110
111
- # The first byte represents the label, which we convert from uint8 to int32.
111
+ # The first byte represents the label, which we convert from uint8 to int32
112
+ # and then to one-hot.
112
113
label = tf .cast (record_vector [0 ], tf .int32 )
114
+ label = tf .one_hot (label , _NUM_CLASSES )
113
115
114
116
# The remaining bytes after the label represent the image, which we reshape
115
117
# from [depth * height * width] to [depth, height, width].
116
- depth_major = tf .reshape (record_vector [ label_bytes : record_bytes ],
117
- [_DEPTH , _HEIGHT , _WIDTH ])
118
+ depth_major = tf .reshape (
119
+ record_vector [ label_bytes : record_bytes ], [_DEPTH , _HEIGHT , _WIDTH ])
118
120
119
121
# Convert from [depth, height, width] to [height, width, depth], and cast as
120
122
# float32.
121
123
image = tf .cast (tf .transpose (depth_major , [1 , 2 , 0 ]), tf .float32 )
122
124
123
- return image , tf .one_hot (label , _NUM_CLASSES )
124
-
125
-
126
- def train_preprocess_fn (image ):
127
- """Preprocess a single training image of layout [height, width, depth]."""
128
- # Resize the image to add four extra pixels on each side.
129
- image = tf .image .resize_image_with_crop_or_pad (image , _HEIGHT + 8 , _WIDTH + 8 )
130
-
131
- # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
132
- image = tf .random_crop (image , [_HEIGHT , _WIDTH , _DEPTH ])
133
-
134
- # Randomly flip the image horizontally.
135
- image = tf .image .random_flip_left_right (image )
125
+ return image , label
136
126
137
- return image
138
127
128
+ def preprocess_image (image , is_training ):
129
+ """Preprocess a single image of layout [height, width, depth]."""
130
+ if is_training :
131
+ # Resize the image to add four extra pixels on each side.
132
+ image = tf .image .resize_image_with_crop_or_pad (image , _HEIGHT + 8 , _WIDTH + 8 )
139
133
140
- def parse_and_preprocess (record , is_training ):
141
- """Parse and preprocess records in the CIFAR-10 dataset."""
142
- image , label = parse_record (record )
134
+ # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
135
+ image = tf .random_crop (image , [_HEIGHT , _WIDTH , _DEPTH ])
143
136
144
- if is_training :
145
- image = train_preprocess_fn (image )
137
+ # Randomly flip the image horizontally.
138
+ image = tf . image . random_flip_left_right (image )
146
139
147
140
# Subtract off the mean and divide by the variance of the pixels.
148
141
image = tf .image .per_image_standardization (image )
149
- return image , label
142
+ return image
150
143
151
144
152
145
def input_fn (is_training , data_dir , batch_size , num_epochs = 1 ):
@@ -168,8 +161,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
168
161
# randomness, while smaller sizes have better performance.
169
162
dataset = dataset .shuffle (buffer_size = _SHUFFLE_BUFFER )
170
163
164
+ dataset = dataset .map (parse_record )
171
165
dataset = dataset .map (
172
- lambda record : parse_and_preprocess (record , is_training ))
166
+ lambda image , label : (preprocess_image (image , is_training ), label ))
167
+
173
168
dataset = dataset .prefetch (2 * batch_size )
174
169
175
170
# We call repeat after shuffling, rather than before, to prevent separate
0 commit comments