@@ -336,3 +336,93 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
336
336
model. move ( along: - stepSize * firstMoments ./ denominator)
337
337
}
338
338
}
339
+
340
/// AMSGrad optimizer.
///
/// This algorithm is a modification of Adam with better convergence properties when close to local
/// optima. In addition to Adam's running first and second moment estimates, it keeps the
/// element-wise maximum of every second-moment estimate seen so far and normalizes the update by
/// that maximum instead of by the current second-moment estimate.
///
/// Reference: ["On the Convergence of Adam and Beyond"](
/// https://openreview.net/pdf?id=ryQu7f-RZ)
public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
    where Model.TangentVector: VectorProtocol & PointwiseMultiplicative & ElementaryFunctions,
          Model.TangentVector.VectorSpaceScalar == Float,
          Model.AllDifferentiableVariables: KeyPathIterable,
          Model.AllDifferentiableVariables == Model.TangentVector {
    public typealias Model = Model
    /// The learning rate.
    public var learningRate: Float
    /// The exponential decay coefficient used to calculate the first moments of the gradients.
    public var beta1: Float
    /// The exponential decay coefficient used to calculate the second moments of the gradients.
    public var beta2: Float
    /// A small scalar added to the denominator to improve numerical stability.
    public var epsilon: Float
    /// The learning rate decay.
    public var decay: Float
    /// The current step, i.e. the number of updates applied so far.
    public var step: Int = 0
    /// The first moments of the weights.
    public var firstMoments: Model.TangentVector = .zero
    /// The second moments of the weights.
    public var secondMoments: Model.TangentVector = .zero
    /// The maximum of the second moments of the weights.
    public var secondMomentsMax: Model.TangentVector = .zero

    /// Creates an AMSGrad optimizer.
    ///
    /// - Parameters:
    ///   - model: The model to optimize. It is not read by this initializer; all optimizer state
    ///     starts at zero.
    ///   - learningRate: The learning rate. Must be non-negative.
    ///   - beta1: The decay coefficient for the first moments. Must be in `0...1`. Note that a
    ///     value of exactly 1 makes the bias-correction divisor `1 - beta1^step` zero.
    ///   - beta2: The decay coefficient for the second moments. Must be in `0...1`.
    ///   - epsilon: A small scalar added to the denominator to improve numerical stability.
    ///   - decay: The learning rate decay. Must be non-negative.
    public init(
        for model: __shared Model,
        learningRate: Float = 1e-3,
        beta1: Float = 0.9,
        beta2: Float = 0.999,
        epsilon: Float = 1e-8,
        decay: Float = 0
    ) {
        precondition(learningRate >= 0, "Learning rate must be non-negative")
        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
        precondition(decay >= 0, "Learning rate decay must be non-negative")

        self.learningRate = learningRate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.decay = decay
    }

    /// Updates `model` in place by taking one AMSGrad step along `direction`.
    ///
    /// - Parameters:
    ///   - model: The model whose differentiable variables are updated.
    ///   - direction: The gradient with respect to the model's differentiable variables.
    public func update(_ model: inout Model, along direction: Model.TangentVector) {
        update(&model.allDifferentiableVariables, along: direction)
    }

    /// Updates the differentiable variables in place by taking one AMSGrad step along
    /// `direction`.
    ///
    /// Increments `step`, refreshes `firstMoments`, `secondMoments` and `secondMomentsMax`, and
    /// then moves `model` along `-stepSize * firstMoments ./ (sqrt(secondMomentsMax) + epsilon)`.
    ///
    /// - Parameters:
    ///   - model: The differentiable variables to update.
    ///   - direction: The gradient with respect to `model`.
    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
    public func update(
        _ model: inout Model.AllDifferentiableVariables,
        along direction: Model.TangentVector
    ) {
        self.step += 1
        let step = Float(self.step)
        // These are already `beta^step`; they must not be raised to the power `step` again when
        // computing the bias correction below.
        let beta1Power = pow(beta1, step)
        let beta2Power = pow(beta2, step)
        let learningRate = self.learningRate / (1 + decay * step)
        // Note: `stepSize` and `secondMoments` are split into two lines to avoid the "compiler is
        // unable to type-check this expression in reasonable time" error.
        // Bias-corrected step size: learningRate * sqrt(1 - beta2^step) / (1 - beta1^step).
        var stepSize = learningRate * sqrt(1 - beta2Power)
        stepSize = stepSize / (1 - beta1Power)
        firstMoments = firstMoments * beta1 + direction * (1 - beta1)
        secondMoments = secondMoments * beta2
        secondMoments += direction .* direction * (1 - beta2)

        // Update `secondMomentsMax` using a key path approach because `max(_:_:)` cannot be
        // currently applied in a simpler manner.
        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
            secondMomentsMax[keyPath: kp] = max(
                secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
        }
        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
            secondMomentsMax[keyPath: kp] = max(
                secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
        }

        let denominator = Model.TangentVector.sqrt(secondMomentsMax) + epsilon
        model.move(along: -stepSize * firstMoments ./ denominator)
    }
}
0 commit comments