Skip to content

Commit 389194c

Browse files
committed
2 parents 295cbb6 + 4b5b7c7 commit 389194c

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tensorflow/code/tensorflow.keras.mnist.classifier.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=(2,2)))
9090
model.add(tf.keras.layers.Dropout(0.25))
9191

92-
model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), padding='Same', activation='relu'))
92+
model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), padding='Same', activation='relu'))
9393
model.add(tf.keras.layers.BatchNormalization())
9494
model.add(tf.keras.layers.Dropout(0.25))
9595

@@ -99,6 +99,7 @@
9999
model.add(tf.keras.layers.Dropout(0.25))
100100

101101
model.add(tf.keras.layers.Dense(10, activation="softmax"))
102+
model.add(tf.keras.layers.Dense(256, kernel_constraint=tf.keras.constraints.MaxNorm(2)))
102103

103104
# 打印出model 看看
104105
tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)
@@ -118,6 +119,9 @@
118119
factor=0.5,
119120
min_lr=0.00001)
120121

122+
# should add early_stopping to the model training callbacks later
123+
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True)
124+
121125
# 设置epochs和batch size
122126
epochs = 20
123127
batch_size = 128
@@ -143,12 +147,12 @@
143147
validation_data=(X_val, Y_val),
144148
verbose=2,
145149
steps_per_epoch=X_train.shape[0] // batch_size,
146-
callbacks=[learning_rate_reduction, tf.keras.callbacks.TensorBoard(log_dir='./log_dir')])
150+
callbacks=[learning_rate_reduction, early_stopping, tf.keras.callbacks.TensorBoard(log_dir='./log_dir')])
147151

148152
# 画训练集和验证集的loss和accuracy曲线。可以判断是否欠拟合或过拟合
149153
fig, ax = plt.subplots(2, 1)
150154
ax[0].plot(history.history['loss'], color='b', label="Training loss")
151-
ax[0].plot(history.history['val_loss'], color='r', label="validation loss",axes =ax[0])
155+
ax[0].plot(history.history['val_loss'], color='r', label="validation loss", axes =ax[0])
152156
legend = ax[0].legend(loc='best', shadow=True)
153157

154158
ax[1].plot(history.history['acc'], color='b', label="Training accuracy")

0 commit comments

Comments
 (0)