@@ -167,96 +167,3 @@ def step(self, closure=None):
                     p.data.add_(-group['lr'] * group['weight_decay'], p.data)

         return loss
-
-
-
-class Lamb(Optimizer):
-    """ Implements the LAMB algorithm (Layer-wise Adaptive Moments optimizer for Batch training).
-
-    Adapted from the huggingface/transformers ADAM optimizer
-    Inspired from the Google Research implementation available in ALBERT: https://github.com/google-research/google-research/blob/master/albert/lamb_optimizer.py
-    Inspired from cybertronai's PyTorch LAMB implementation: https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
-
-
-    Parameters:
-        lr (float): learning rate. Default 1e-3.
-        betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
-        eps (float): Adams epsilon. Default: 1e-6
-        weight_decay (float): Weight decay. Default: 0.0
-    """
-
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
-        if lr < 0.0:
-            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if not 0.0 <= betas[0] < 1.0:
-            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
-        if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
-        if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
-                        correct_bias=correct_bias)
-        super(Lamb, self).__init__(params, defaults)
-
-    def step(self, closure=None):
-        """Performs a single optimization step.
-
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
-        loss = None
-        if closure is not None:
-            loss = closure()
-
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad.data
-                if grad.is_sparse:
-                    raise RuntimeError('LAMB does not support sparse gradients.')
-
-                state = self.state[p]
-
-                # State initialization
-                if len(state) == 0:
-                    state['step'] = 0
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
-
-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                beta1, beta2 = group['betas']
-
-                state['step'] += 1
-
-                # Decay the first and second moment running average coefficient
-                # In-place operations to update the averages at the same time
-                exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
-                denom = exp_avg_sq.sqrt().add_(group['eps'])
-
-
-                # Inspired from cybertronai's PyTorch LAMB implementation: https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
-                step_size = group['lr']
-                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
-
-                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
-                if group['weight_decay'] != 0:
-                    adam_step.add_(group['weight_decay'], p.data)
-
-                adam_norm = adam_step.pow(2).sum().sqrt()
-                if weight_norm == 0 or adam_norm == 0:
-                    trust_ratio = 1
-                else:
-                    trust_ratio = weight_norm / adam_norm
-
-
-                state['weight_norm'] = weight_norm
-                state['adam_norm'] = adam_norm
-                state['trust_ratio'] = trust_ratio
-
-                p.data.add_(-step_size * trust_ratio, adam_step)
-        return loss
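
Note on the removed step(): it builds an Adam-style update, then rescales that update per parameter tensor by a trust ratio ||w|| / ||adam_step||, clamping ||w|| to [0, 10] and falling back to 1 when either norm is zero. The sketch below restates that inner update against the current PyTorch tensor API (the removed code relies on the deprecated add_(Number, Tensor) and addcmul_(Number, Tensor, Tensor) call forms). lamb_update and its argument names are illustrative, not part of this repository.

import torch

def lamb_update(p, grad, exp_avg, exp_avg_sq, lr=1e-3, beta1=0.9, beta2=0.999,
                eps=1e-6, weight_decay=0.0):
    # Exponential moving averages of the gradient and its square (Adam moments).
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    # Adam-style step; L2 weight decay is folded in before the norm is taken,
    # mirroring the removed implementation.
    adam_step = exp_avg / exp_avg_sq.sqrt().add(eps)
    if weight_decay != 0.0:
        adam_step.add_(p, alpha=weight_decay)

    # Layer-wise trust ratio ||w|| / ||adam_step||, with the same clamp and
    # zero-norm fallback as the removed code.
    weight_norm = p.pow(2).sum().sqrt().clamp(0, 10)
    adam_norm = adam_step.pow(2).sum().sqrt()
    if weight_norm == 0 or adam_norm == 0:
        trust_ratio = 1.0
    else:
        trust_ratio = (weight_norm / adam_norm).item()

    # Apply the rescaled update in place.
    p.add_(adam_step, alpha=-lr * trust_ratio)
    return trust_ratio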
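For reference, the sketch above can be exercised on a single tensor like so (shapes and values are arbitrary, and no autograd is involved):

import torch

p = torch.randn(4, 4)             # stands in for a parameter tensor
grad = torch.randn(4, 4)          # stands in for its gradient
exp_avg = torch.zeros_like(p)     # first-moment state
exp_avg_sq = torch.zeros_like(p)  # second-moment state

lamb_update(p, grad, exp_avg, exp_avg_sq, lr=1e-3, weight_decay=0.01)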