
Commit 2eeca01

Add files via upload
1 parent b005361 commit 2eeca01

File tree

1 file changed: +154 −0 lines changed

@@ -0,0 +1,154 @@
import os

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Filter out corrupted images that do not contain the "JFIF" header
num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        try:
            fobj = open(fpath, "rb")
            is_jfif = tf.compat.as_bytes("JFIF") in fobj.peek(10)
        finally:
            fobj.close()

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print("Deleted %d images" % num_skipped)

image_size = (180, 180)
batch_size = 32

# Generate training and validation datasets from the image folders
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
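
# Optional sanity check (a minimal sketch, not part of the original commit):
# datasets returned by image_dataset_from_directory expose `class_names`,
# and inspecting one batch confirms the expected shapes.
print(train_ds.class_names)  # e.g. ['Cat', 'Dog']
for images, labels in train_ds.take(1):
    print(images.shape)  # (32, 180, 180, 3)
    print(labels.shape)  # (32,)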

# Visualize the data
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")
plt.show()

# Use image data augmentation
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
    ]
)

# Visualize a few augmented versions of the first image in the batch
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")
plt.show()

# Option: apply data augmentation to the dataset itself (preferred when
# training on CPU). The model below instead applies augmentation as its
# first layer, the better option when training on GPU, so
# `augmented_train_ds` is not used further in this script.
augmented_train_ds = train_ds.map(lambda x, y: (data_augmentation(x, training=True), y))

# Configure the dataset for performance: prefetch batches so data loading
# overlaps with training
train_ds = train_ds.prefetch(buffer_size=32)
val_ds = val_ds.prefetch(buffer_size=32)
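
# Alternative (a sketch, not in the original commit): let tf.data pick the
# prefetch buffer size automatically and cache decoded images in memory,
# assuming the dataset fits in RAM.
# train_ds = train_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
# val_ds = val_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)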

# Build a model
def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    # Image augmentation block
    x = data_augmentation(inputs)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [128, 256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes

    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)
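
# Dependency-free alternative to plot_model (a sketch, not in the original
# commit): plot_model needs pydot and graphviz installed, whereas
# model.summary() prints the layer-by-layer architecture directly.
model.summary()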

# Train the model
epochs = 50

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
)
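
# Run inference on a single new image (a sketch, not part of the original
# commit; the file path below is a hypothetical example). Note that the
# augmentation layers are inactive at inference time.
img = keras.preprocessing.image.load_img(
    "PetImages/Cat/6779.jpg", target_size=image_size
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create a batch of size 1

predictions = model.predict(img_array)
score = float(predictions[0])  # sigmoid output: probability of class "Dog"
print("This image is %.2f%% cat and %.2f%% dog." % (100 * (1 - score), 100 * score))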
