Skip to content

Commit 6e52c27

Browse files
committed
Separate parse_and_preprocess into two different dataset.map calls, which also keeps tests passing
1 parent 807d6bd commit 6e52c27

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

official/resnet/cifar10_main.py

+19-24
Original file line numberDiff line numberDiff line change
@@ -108,45 +108,38 @@ def parse_record(raw_record):
108108
# Convert bytes to a vector of uint8 that is record_bytes long.
109109
record_vector = tf.decode_raw(raw_record, tf.uint8)
110110

111-
# The first byte represents the label, which we convert from uint8 to int32.
111+
# The first byte represents the label, which we convert from uint8 to int32
112+
# and then to one-hot.
112113
label = tf.cast(record_vector[0], tf.int32)
114+
label = tf.one_hot(label, _NUM_CLASSES)
113115

114116
# The remaining bytes after the label represent the image, which we reshape
115117
# from [depth * height * width] to [depth, height, width].
116-
depth_major = tf.reshape(record_vector[label_bytes:record_bytes],
117-
[_DEPTH, _HEIGHT, _WIDTH])
118+
depth_major = tf.reshape(
119+
record_vector[label_bytes:record_bytes], [_DEPTH, _HEIGHT, _WIDTH])
118120

119121
# Convert from [depth, height, width] to [height, width, depth], and cast as
120122
# float32.
121123
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
122124

123-
return image, tf.one_hot(label, _NUM_CLASSES)
124-
125-
126-
def train_preprocess_fn(image):
127-
"""Preprocess a single training image of layout [height, width, depth]."""
128-
# Resize the image to add four extra pixels on each side.
129-
image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
130-
131-
# Randomly crop a [_HEIGHT, _WIDTH] section of the image.
132-
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
133-
134-
# Randomly flip the image horizontally.
135-
image = tf.image.random_flip_left_right(image)
125+
return image, label
136126

137-
return image
138127

128+
def preprocess_image(image, is_training):
129+
"""Preprocess a single image of layout [height, width, depth]."""
130+
if is_training:
131+
# Pad the image with four extra pixels on each side.
132+
image = tf.image.resize_image_with_crop_or_pad(image, _HEIGHT + 8, _WIDTH + 8)
139133

140-
def parse_and_preprocess(record, is_training):
141-
"""Parse and preprocess records in the CIFAR-10 dataset."""
142-
image, label = parse_record(record)
134+
# Randomly crop a [_HEIGHT, _WIDTH] section of the image.
135+
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _DEPTH])
143136

144-
if is_training:
145-
image = train_preprocess_fn(image)
137+
# Randomly flip the image horizontally.
138+
image = tf.image.random_flip_left_right(image)
146139

147140
# Subtract off the mean and divide by the standard deviation of the pixels.
148141
image = tf.image.per_image_standardization(image)
149-
return image, label
142+
return image
150143

151144

152145
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
@@ -168,8 +161,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
168161
# randomness, while smaller sizes have better performance.
169162
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
170163

164+
dataset = dataset.map(parse_record)
171165
dataset = dataset.map(
172-
lambda record: parse_and_preprocess(record, is_training))
166+
lambda image, label: (preprocess_image(image, is_training), label))
167+
173168
dataset = dataset.prefetch(2 * batch_size)
174169

175170
# We call repeat after shuffling, rather than before, to prevent separate

0 commit comments

Comments
 (0)