Commit d61d029
Authored Jun 30, 2019

[Optimizers] Added support for the AMSGrad optimizer. (#314)

1 parent: 994d39e

2 files changed (+94 −1)

Sources/TensorFlow/Optimizers/MomentumBased.swift (+90)
@@ -336,3 +336,93 @@ public class AdaMax<Model: Differentiable & KeyPathIterable>: Optimizer
         model.move(along: -stepSize * firstMoments ./ denominator)
     }
 }
+
+/// AMSGrad optimizer.
+///
+/// This algorithm is a modification of Adam with better convergence properties when close to local
+/// optima.
+///
+/// Reference: ["On the Convergence of Adam and Beyond"](
+/// https://openreview.net/pdf?id=ryQu7f-RZ)
+public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
+    where Model.TangentVector: VectorProtocol & PointwiseMultiplicative & ElementaryFunctions,
+          Model.TangentVector.VectorSpaceScalar == Float,
+          Model.AllDifferentiableVariables: KeyPathIterable,
+          Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
+    /// The learning rate.
+    public var learningRate: Float
+    /// A coefficient used to calculate the first and second moments of the gradients.
+    public var beta1: Float
+    /// A coefficient used to calculate the first and second moments of the gradients.
+    public var beta2: Float
+    /// A small scalar added to the denominator to improve numerical stability.
+    public var epsilon: Float
+    /// The learning rate decay.
+    public var decay: Float
+    /// The current step.
+    public var step: Int = 0
+    /// The first moments of the weights.
+    public var firstMoments: Model.TangentVector = .zero
+    /// The second moments of the weights.
+    public var secondMoments: Model.TangentVector = .zero
+    /// The maximum of the second moments of the weights.
+    public var secondMomentsMax: Model.TangentVector = .zero
+
+    public init(
+        for model: __shared Model,
+        learningRate: Float = 1e-3,
+        beta1: Float = 0.9,
+        beta2: Float = 0.999,
+        epsilon: Float = 1e-8,
+        decay: Float = 0
+    ) {
+        precondition(learningRate >= 0, "Learning rate must be non-negative")
+        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
+        precondition(decay >= 0, "Learning rate decay must be non-negative")
+
+        self.learningRate = learningRate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.decay = decay
+    }
+
+    public func update(_ model: inout Model, along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
+
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
+    public func update(
+        _ model: inout Model.AllDifferentiableVariables,
+        along direction: Model.TangentVector
+    ) {
+        self.step += 1
+        let step = Float(self.step)
+        let beta1Power = pow(beta1, step)
+        let beta2Power = pow(beta2, step)
+        let learningRate = self.learningRate * 1 / (1 + decay * step)
+        // Note: `stepSize` and `secondMoments` are split into two lines to avoid the "compiler is
+        // unable to type-check this expression in reasonable time" error.
+        var stepSize = learningRate * sqrt(1 - beta2Power)
+        stepSize = stepSize / (1 - beta1Power)
+        firstMoments = firstMoments * beta1 + direction * (1 - beta1)
+        secondMoments = secondMoments * beta2
+        secondMoments += direction .* direction * (1 - beta2)
+
+        // Update `secondMomentsMax` using a key path approach because `max(_:_:)` cannot
+        // currently be applied in a simpler manner.
+        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+            secondMomentsMax[keyPath: kp] = max(
+                secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
+        }
+        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+            secondMomentsMax[keyPath: kp] = max(
+                secondMomentsMax[keyPath: kp], secondMoments[keyPath: kp])
+        }
+
+        let denominator = Model.TangentVector.sqrt(secondMomentsMax) + epsilon
+        model.move(along: -stepSize * firstMoments ./ denominator)
+    }
+}
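For reference, here is the update the new class implements, written out as math (a sketch: $g_t$ is `direction`, $\eta$ is `learningRate`, $\lambda$ is `decay`, and the elementwise maximum is the AMSGrad modification from the cited paper):

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t \odot g_t \\
\hat{v}_t &= \max(\hat{v}_{t-1}, v_t) \\
\theta_{t+1} &= \theta_t - \alpha_t \frac{m_t}{\sqrt{\hat{v}_t} + \epsilon},
\qquad
\alpha_t = \frac{\eta}{1 + \lambda t} \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}
\end{aligned}
$$

Because $\hat{v}_t$ can only grow, the effective per-parameter step can only shrink; that monotonicity is what restores the convergence guarantee the paper shows plain Adam can violate. `secondMomentsMax` carries $\hat{v}_t$ across steps, and the two key-path loops take the elementwise maximum tensor-by-tensor (once for `Tensor<Float>` leaves, once for `Tensor<Double>`) because the `TangentVector` protocol requirements provide no elementwise `max`.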

Tests/TensorFlowTests/SequentialTests.swift (+4 −1)
@@ -33,6 +33,7 @@ final class SequentialTests: XCTestCase {
         let rmsprop = RMSProp(for: model, learningRate: 0.02)
         let adam = Adam(for: model, learningRate: 0.02)
         let adamax = AdaMax(for: model, learningRate: 0.02)
+        let amsgrad = AMSGrad(for: model, learningRate: 0.02)
         let adagrad = AdaGrad(for: model, learningRate: 0.02)
         let adadelta = AdaDelta(for: model, learningRate: 0.02)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
@@ -51,13 +52,15 @@
             adam.update(&model.allDifferentiableVariables, along: 𝛁model)
             adamax.update(&model, along: 𝛁model)
             adamax.update(&model.allDifferentiableVariables, along: 𝛁model)
+            amsgrad.update(&model, along: 𝛁model)
+            amsgrad.update(&model.allDifferentiableVariables, along: 𝛁model)
             adagrad.update(&model, along: 𝛁model)
             adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
             adadelta.update(&model, along: 𝛁model)
             adadelta.update(&model.allDifferentiableVariables, along: 𝛁model)
         }
         XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
-                       [[0.47620448], [0.47620448], [0.47620448], [0.47620448]])
+                       [[0.52508783], [0.52508783], [0.52508783], [0.52508783]])
     }
 
     static var allTests = [

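A minimal usage sketch of the new optimizer (hypothetical, not part of this commit; it assumes the swift-apis surface of this period: `Dense`, `gradient(at:)`, `meanSquaredError`, and callable-layer syntax):

import TensorFlow

// Fit a single sigmoid unit to the AND function using AMSGrad.
var model = Dense<Float>(inputSize: 2, outputSize: 1, activation: sigmoid)
let optimizer = AMSGrad(for: model, learningRate: 0.02)
let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
let y: Tensor<Float> = [[0], [0], [0], [1]]

for _ in 0..<1000 {
    // Differentiate the loss with respect to the model's parameters.
    let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
        meanSquaredError(predicted: model(x), expected: y)
    }
    // One AMSGrad step: update the moment estimates, take the running maximum
    // of the second moments, and move against the bias-corrected gradient.
    optimizer.update(&model, along: 𝛁model)
}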