|
89 | 89 | model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=(2,2)))
|
90 | 90 | model.add(tf.keras.layers.Dropout(0.25))
|
91 | 91 |
|
92 |
| -model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), padding='Same', activation='relu')) |
| 92 | +model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(3,3), padding='Same', activation='relu')) |
93 | 93 | model.add(tf.keras.layers.BatchNormalization())
|
94 | 94 | model.add(tf.keras.layers.Dropout(0.25))
|
95 | 95 |
|
|
99 | 99 | model.add(tf.keras.layers.Dropout(0.25))
|
100 | 100 |
|
101 | 101 | model.add(tf.keras.layers.Dense(10, activation="softmax"))
|
| 102 | +model.add(tf.keras.layers.Dense(256, kernel_constraint=tf.keras.constraints.MaxNorm(2))) |
102 | 103 |
|
103 | 104 | # 打印出model 看看
|
104 | 105 | tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)
|
|
118 | 119 | factor=0.5,
|
119 | 120 | min_lr=0.00001)
|
120 | 121 |
|
| 122 | +# should add early_stopping to the model training callbacks later |
| 123 | +early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True) |
| 124 | + |
121 | 125 | # 设置epochs和batch size
|
122 | 126 | epochs = 20
|
123 | 127 | batch_size = 128
|
|
143 | 147 | validation_data=(X_val, Y_val),
|
144 | 148 | verbose=2,
|
145 | 149 | steps_per_epoch=X_train.shape[0] // batch_size,
|
146 |
| - callbacks=[learning_rate_reduction, tf.keras.callbacks.TensorBoard(log_dir='./log_dir')]) |
| 150 | + callbacks=[learning_rate_reduction, early_stopping, tf.keras.callbacks.TensorBoard(log_dir='./log_dir')]) |
147 | 151 |
|
148 | 152 | # 画训练集和验证集的loss和accuracy曲线。可以判断是否欠拟合或过拟合
|
149 | 153 | fig, ax = plt.subplots(2, 1)
|
150 | 154 | ax[0].plot(history.history['loss'], color='b', label="Training loss")
|
151 |
| -ax[0].plot(history.history['val_loss'], color='r', label="validation loss",axes =ax[0]) |
| 155 | +ax[0].plot(history.history['val_loss'], color='r', label="validation loss", axes =ax[0]) |
152 | 156 | legend = ax[0].legend(loc='best', shadow=True)
|
153 | 157 |
|
154 | 158 | ax[1].plot(history.history['acc'], color='b', label="Training accuracy")
|
|
0 commit comments