@@ -99,30 +99,17 @@ def forward(self, x):
99
99
f , axs = plt .subplots (4 , N_HIDDEN + 1 , figsize = (10 , 5 ))
100
100
plt .ion () # something about plotting
101
101
plt .show ()
102
-
103
102
def plot_histogram (l_in , l_in_bn , pre_ac , pre_ac_bn ):
104
103
for i , (ax_pa , ax_pa_bn , ax , ax_bn ) in enumerate (zip (axs [0 , :], axs [1 , :], axs [2 , :], axs [3 , :])):
105
104
[a .clear () for a in [ax_pa , ax_pa_bn , ax , ax_bn ]]
106
- if i == 0 :
107
- p_range = (- 7 , 10 )
108
- the_range = (- 7 , 10 )
109
- else :
110
- p_range = (- 4 , 4 )
111
- the_range = (- 1 , 1 )
105
+ if i == 0 : p_range = (- 7 , 10 );the_range = (- 7 , 10 )
106
+ else :p_range = (- 4 , 4 );the_range = (- 1 , 1 )
112
107
ax_pa .set_title ('L' + str (i ))
113
- ax_pa .hist (pre_ac [i ].data .numpy ().ravel (), bins = 10 , range = p_range , color = '#FF9359' , alpha = 0.5 )
114
- ax_pa_bn .hist (pre_ac_bn [i ].data .numpy ().ravel (), bins = 10 , range = p_range , color = '#74BCFF' , alpha = 0.5 )
115
- ax .hist (l_in [i ].data .numpy ().ravel (), bins = 10 , range = the_range , color = '#FF9359' )
116
- ax_bn .hist (l_in_bn [i ].data .numpy ().ravel (), bins = 10 , range = the_range , color = '#74BCFF' )
117
- for a in [ax_pa , ax , ax_pa_bn , ax_bn ]:
118
- a .set_yticks (())
119
- a .set_xticks (())
120
- ax_pa_bn .set_xticks (p_range )
121
- ax_bn .set_xticks (the_range )
122
- axs [0 , 0 ].set_ylabel ('PreAct' )
123
- axs [1 , 0 ].set_ylabel ('BN PreAct' )
124
- axs [2 , 0 ].set_ylabel ('Act' )
125
- axs [3 , 0 ].set_ylabel ('BN Act' )
108
+ ax_pa .hist (pre_ac [i ].data .numpy ().ravel (), bins = 10 , range = p_range , color = '#FF9359' , alpha = 0.5 );ax_pa_bn .hist (pre_ac_bn [i ].data .numpy ().ravel (), bins = 10 , range = p_range , color = '#74BCFF' , alpha = 0.5 )
109
+ ax .hist (l_in [i ].data .numpy ().ravel (), bins = 10 , range = the_range , color = '#FF9359' );ax_bn .hist (l_in_bn [i ].data .numpy ().ravel (), bins = 10 , range = the_range , color = '#74BCFF' )
110
+ for a in [ax_pa , ax , ax_pa_bn , ax_bn ]: a .set_yticks (());a .set_xticks (())
111
+ ax_pa_bn .set_xticks (p_range );ax_bn .set_xticks (the_range )
112
+ axs [0 , 0 ].set_ylabel ('PreAct' );axs [1 , 0 ].set_ylabel ('BN PreAct' );axs [2 , 0 ].set_ylabel ('Act' );axs [3 , 0 ].set_ylabel ('BN Act' )
126
113
plt .pause (0.01 )
127
114
128
115
# training
@@ -155,10 +142,7 @@ def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
155
142
plt .figure (2 )
156
143
plt .plot (losses [0 ], c = '#FF9359' , lw = 3 , label = 'Original' )
157
144
plt .plot (losses [1 ], c = '#74BCFF' , lw = 3 , label = 'Batch Normalization' )
158
- plt .xlabel ('step' )
159
- plt .ylabel ('test loss' )
160
- plt .ylim ((0 , 2000 ))
161
- plt .legend (loc = 'best' )
145
+ plt .xlabel ('step' );plt .ylabel ('test loss' );plt .ylim ((0 , 2000 ));plt .legend (loc = 'best' )
162
146
163
147
# evaluation
164
148
# set net to eval mode to freeze the parameters in batch normalization layers
0 commit comments