@@ -50,10 +50,15 @@ def main():
5050 W2 = np .random .randn (M , K ) / np .sqrt (M )
5151 b2 = np .zeros (K )
5252
53+ # save initial weights
54+ W1_0 = W1 .copy ()
55+ b1_0 = b1 .copy ()
56+ W2_0 = W2 .copy ()
57+ b2_0 = b2 .copy ()
58+
5359 # 1. batch
54- # cost = -16
55- LL_batch = []
56- CR_batch = []
60+ losses_batch = []
61+ errors_batch = []
5762 for i in range (max_iter ):
5863 for j in range (n_batches ):
5964 Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
@@ -68,26 +73,25 @@ def main():
6873 b1 -= lr * (derivative_b1 (Z , Ybatch , pYbatch , W2 ) + reg * b1 )
6974
7075 if j % print_period == 0 :
71- # calculate just for LL
7276 pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
73- ll = cost (pY , Ytest_ind )
74- LL_batch .append (ll )
75- print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , ll ))
77+ l = cost (pY , Ytest_ind )
78+ losses_batch .append (l )
79+ print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , l ))
7680
77- err = error_rate (pY , Ytest )
78- CR_batch .append (err )
79- print ("Error rate:" , err )
81+ e = error_rate (pY , Ytest )
82+ errors_batch .append (e )
83+ print ("Error rate:" , e )
8084
8185 pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
8286 print ("Final error rate:" , error_rate (pY , Ytest ))
8387
8488 # 2. batch with momentum
85- W1 = np . random . randn ( D , M ) / np . sqrt ( D )
86- b1 = np . zeros ( M )
87- W2 = np . random . randn ( M , K ) / np . sqrt ( M )
88- b2 = np . zeros ( K )
89- LL_momentum = []
90- CR_momentum = []
89+ W1 = W1_0 . copy ( )
90+ b1 = b1_0 . copy ( )
91+ W2 = W2_0 . copy ( )
92+ b2 = b2_0 . copy ( )
93+ losses_momentum = []
94+ errors_momentum = []
9195 mu = 0.9
9296 dW2 = 0
9397 db2 = 0
@@ -99,100 +103,92 @@ def main():
99103 Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
100104 pYbatch , Z = forward (Xbatch , W1 , b1 , W2 , b2 )
101105
106+ # gradients
107+ gW2 = derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2
108+ gb2 = derivative_b2 (Ybatch , pYbatch ) + reg * b2
109+ gW1 = derivative_w1 (Xbatch , Z , Ybatch , pYbatch , W2 ) + reg * W1
110+ gb1 = derivative_b1 (Z , Ybatch , pYbatch , W2 ) + reg * b1
111+
112+ # update velocities
113+ dW2 = mu * dW2 - lr * gW2
114+ db2 = mu * db2 - lr * gb2
115+ dW1 = mu * dW1 - lr * gW1
116+ db1 = mu * db1 - lr * gb1
117+
102118 # updates
103- dW2 = mu * dW2 - lr * (derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2 )
104119 W2 += dW2
105- db2 = mu * db2 - lr * (derivative_b2 (Ybatch , pYbatch ) + reg * b2 )
106120 b2 += db2
107- dW1 = mu * dW1 - lr * (derivative_w1 (Xbatch , Z , Ybatch , pYbatch , W2 ) + reg * W1 )
108121 W1 += dW1
109- db1 = mu * db1 - lr * (derivative_b1 (Z , Ybatch , pYbatch , W2 ) + reg * b1 )
110122 b1 += db1
111123
112124 if j % print_period == 0 :
113- # calculate just for LL
114125 pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
115- # print "pY:", pY
116- ll = cost (pY , Ytest_ind )
117- LL_momentum .append (ll )
118- print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , ll ))
119-
120- err = error_rate (pY , Ytest )
121- CR_momentum .append (err )
122- print ("Error rate:" , err )
126+ l = cost (pY , Ytest_ind )
127+ losses_momentum .append (l )
128+ print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , l ))
129+
130+ e = error_rate (pY , Ytest )
131+ errors_momentum .append (e )
132+ print ("Error rate:" , e )
123133 pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
124134 print ("Final error rate:" , error_rate (pY , Ytest ))
125135
126136
127137 # 3. batch with Nesterov momentum
128- W1 = np .random .randn (D , M ) / np .sqrt (D )
129- b1 = np .zeros (M )
130- W2 = np .random .randn (M , K ) / np .sqrt (M )
131- b2 = np .zeros (K )
132- LL_nest = []
133- CR_nest = []
138+ W1 = W1_0 .copy ()
139+ b1 = b1_0 .copy ()
140+ W2 = W2_0 .copy ()
141+ b2 = b2_0 .copy ()
142+
143+ losses_nesterov = []
144+ errors_nesterov = []
145+
134146 mu = 0.9
135- # alternate version uses dW
136- # dW2 = 0
137- # db2 = 0
138- # dW1 = 0
139- # db1 = 0
140147 vW2 = 0
141148 vb2 = 0
142149 vW1 = 0
143150 vb1 = 0
144151 for i in range (max_iter ):
145152 for j in range (n_batches ):
146- # because we want g(t) = grad(f(W(t-1) - lr*mu*dW(t-1)))
147- # dW(t) = mu*dW(t-1) + g(t)
148- # W(t) = W(t-1) - mu*dW(t)
149- W1_tmp = W1 - lr * mu * vW1
150- b1_tmp = b1 - lr * mu * vb1
151- W2_tmp = W2 - lr * mu * vW2
152- b2_tmp = b2 - lr * mu * vb2
153-
154153 Xbatch = Xtrain [j * batch_sz :(j * batch_sz + batch_sz ),]
155154 Ybatch = Ytrain_ind [j * batch_sz :(j * batch_sz + batch_sz ),]
156- # pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
157- pYbatch , Z = forward (Xbatch , W1_tmp , b1_tmp , W2_tmp , b2_tmp )
155+ pYbatch , Z = forward (Xbatch , W1 , b1 , W2 , b2 )
158156
159157 # updates
160- # dW2 = mu*mu*dW2 - (1 + mu)*lr*( derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
161- # W2 += dW2
162- # db2 = mu*mu*db2 - (1 + mu)*lr*(derivative_b2( Ybatch, pYbatch) + reg*b2)
163- # b2 += db2
164- # dW1 = mu*mu*dW1 - (1 + mu)*lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
165- # W1 += dW1
166- # db1 = mu*mu*db1 - (1 + mu)* lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
167- # b1 += db1
168- vW2 = mu * vW2 + derivative_w2 ( Z , Ybatch , pYbatch ) + reg * W2_tmp
169- W2 -= lr * vW2
170- vb2 = mu * vb2 + derivative_b2 ( Ybatch , pYbatch ) + reg * b2_tmp
171- b2 -= lr * vb2
172- vW1 = mu * vW1 + derivative_w1 ( Xbatch , Z , Ybatch , pYbatch , W2_tmp ) + reg * W1_tmp
173- W1 -= lr * vW1
174- vb1 = mu * vb1 + derivative_b1 ( Z , Ybatch , pYbatch , W2_tmp ) + reg * b1_tmp
175- b1 -= lr * vb1
158+ gW2 = derivative_w2 (Z , Ybatch , pYbatch ) + reg * W2
159+ gb2 = derivative_b2 ( Ybatch , pYbatch ) + reg * b2
160+ gW1 = derivative_w1 ( Xbatch , Z , Ybatch , pYbatch , W2 ) + reg * W1
161+ gb1 = derivative_b1 ( Z , Ybatch , pYbatch , W2 ) + reg * b1
162+
163+ # v update
164+ vW2 = mu * vW2 - lr * gW2
165+ vb2 = mu * vb2 - lr * gb2
166+ vW1 = mu * vW1 - lr * gW1
167+ vb1 = mu * vb1 - lr * gb1
168+
169+ # param update
170+ W2 + = mu * vW2 - lr * gW2
171+ b2 += mu * vb2 - lr * gb2
172+ W1 + = mu * vW1 - lr * gW1
173+ b1 += mu * vb1 - lr * gb1
176174
177175 if j % print_period == 0 :
178- # calculate just for LL
179176 pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
180- # print "pY:", pY
181- ll = cost (pY , Ytest_ind )
182- LL_nest .append (ll )
183- print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , ll ))
184-
185- err = error_rate (pY , Ytest )
186- CR_nest .append (err )
187- print ("Error rate:" , err )
177+ l = cost (pY , Ytest_ind )
178+ losses_nesterov .append (l )
179+ print ("Cost at iteration i=%d, j=%d: %.6f" % (i , j , l ))
180+
181+ e = error_rate (pY , Ytest )
182+ errors_nesterov .append (e )
183+ print ("Error rate:" , e )
188184 pY , _ = forward (Xtest , W1 , b1 , W2 , b2 )
189185 print ("Final error rate:" , error_rate (pY , Ytest ))
190186
191187
192188
193- plt .plot (LL_batch , label = "batch" )
194- plt .plot (LL_momentum , label = "momentum" )
195- plt .plot (LL_nest , label = "nesterov" )
189+ plt .plot (losses_batch , label = "batch" )
190+ plt .plot (losses_momentum , label = "momentum" )
191+ plt .plot (losses_nesterov , label = "nesterov" )
196192 plt .legend ()
197193 plt .show ()
198194
0 commit comments