Skip to content

Commit 4cd242b

Browse files
committed
add early_stopping
1 parent 0fa3d09 commit 4cd242b

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

Diff for: tensorflow/code/tensorflow.keras.mnist.classifier.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
verbose=1,
118118
factor=0.5,
119119
min_lr=0.00001)
120+
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True)
120121

121122
# 设置epochs和batch size
122123
epochs = 20
@@ -143,12 +144,12 @@
143144
validation_data=(X_val, Y_val),
144145
verbose=2,
145146
steps_per_epoch=X_train.shape[0] // batch_size,
146-
callbacks=[learning_rate_reduction, tf.keras.callbacks.TensorBoard(log_dir='./log_dir')])
147+
callbacks=[learning_rate_reduction, early_stopping, tf.keras.callbacks.TensorBoard(log_dir='./log_dir')])
147148

148149
# 画训练集和验证集的loss和accuracy曲线。可以判断是否欠拟合或过拟合
149150
fig, ax = plt.subplots(2, 1)
150151
ax[0].plot(history.history['loss'], color='b', label="Training loss")
151-
ax[0].plot(history.history['val_loss'], color='r', label="validation loss",axes =ax[0])
152+
ax[0].plot(history.history['val_loss'], color='r', label="validation loss", axes =ax[0])
152153
legend = ax[0].legend(loc='best', shadow=True)
153154

154155
ax[1].plot(history.history['acc'], color='b', label="Training accuracy")

0 commit comments

Comments
 (0)