
Commit 36769e3

jon-tow authored and saeta committed
Add AdaMax optimizer (#304)
1 parent 472b29f commit 36769e3

File tree

2 files changed: +103 −0 lines changed

Sources/TensorFlow/Optimizer.swift (+100 lines)
@@ -132,6 +132,106 @@ public class Adam<Model: Layer>: Optimizer
     }
 }

+/// AdaMax optimizer.
+///
+/// A variant of Adam based on the infinity-norm.
+///
+/// Reference: Section 7 of ["Adam - A Method for Stochastic Optimization"](
+/// https://arxiv.org/abs/1412.6980v8)
+public class AdaMax<Model: Layer>: Optimizer
+    where Model.AllDifferentiableVariables == Model.TangentVector {
+    public typealias Model = Model
+    /// The learning rate.
+    public var learningRate: Float
+    /// Decay rate used to estimate the first moment (mean) of gradients.
+    public var beta1: Float
+    /// Decay rate used to estimate the exponentially weighted infinity norm.
+    public var beta2: Float
+    /// A small scalar added to the denominator to improve numerical stability.
+    public var epsilon: Float
+    /// The learning rate decay.
+    public var decay: Float
+    /// The step count.
+    public var step: Int = 0
+    /// The first moments of the weights.
+    public var firstMoments: Model.TangentVector
+    /// The exponentially weighted infinity norm of the weights.
+    public var infinityNorm: Model.TangentVector
+
+    /// Note: The default parameters follow those provided in the paper.
+    public init(
+        for model: __shared Model,
+        learningRate: Float = 0.002,
+        beta1: Float = 0.9,
+        beta2: Float = 0.999,
+        epsilon: Float = 1e-8,
+        decay: Float = 0
+    ) {
+        precondition(learningRate >= 0, "Learning rate must be non-negative.")
+        precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1.")
+        precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1.")
+        precondition(decay >= 0, "Learning rate decay must be non-negative.")
+
+        self.learningRate = learningRate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.decay = decay
+
+        // Initialize first moments and infinity norm to be zeros of the same shape.
+        // We can't use `Model.AllDifferentiableVariables.zero` due to the
+        // interaction between Key Paths and Differentiable Arrays.
+        firstMoments = model.allDifferentiableVariables
+        infinityNorm = model.allDifferentiableVariables
+        for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+            firstMoments[keyPath: kp].resetToZero()
+            infinityNorm[keyPath: kp].resetToZero()
+        }
+        for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+            firstMoments[keyPath: kp].resetToZero()
+            infinityNorm[keyPath: kp].resetToZero()
+        }
+    }
+
+    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
+    public func update(_ model: inout Model.AllDifferentiableVariables,
+                       along direction: Model.AllDifferentiableVariables) {
+        step += 1
+        let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
+        // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check
+        // this expression in reasonable time" error.
+        var stepSize = learningRate * sqrt(1 - pow(beta2, Float(step)))
+        stepSize = stepSize / (1 - pow(beta1, Float(step)))
+        // Update `Tensor<Float>` & `Tensor<Double>` variables.
+        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
+            firstMoments[keyPath: kp] =
+                (beta1 * firstMoments[keyPath: kp]) + (1 - beta1) * direction[keyPath: kp]
+            infinityNorm[keyPath: kp] =
+                max(beta2 * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
+            let biasCorrection = stepSize / (1 - pow(beta1, Float(step)))
+            model[keyPath: kp] -=
+                biasCorrection * firstMoments[keyPath: kp]
+                / (infinityNorm[keyPath: kp] + Float(self.epsilon))
+        }
+        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
+            firstMoments[keyPath: kp] =
+                Double(beta1) * firstMoments[keyPath: kp]
+                + Double(1 - beta2) * direction[keyPath: kp]
+            infinityNorm[keyPath: kp] =
+                max(Double(beta2) * infinityNorm[keyPath: kp], abs(direction[keyPath: kp]))
+            let biasCorrection = Double(stepSize) / Double(1 - pow(beta1, Float(step)))
+            model[keyPath: kp] -=
+                biasCorrection * firstMoments[keyPath: kp]
+                / (infinityNorm[keyPath: kp] + Double(self.epsilon))
+        }
+    }
+
+    public func update(_ model: inout Model,
+                       along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
+    }
+}
+
 /// RMSProp optimizer.
 ///
 /// It is recommended to leave the parameters of this optimizer at their default values (except the
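
For readers comparing the key-path loops above against the paper, here is a minimal, self-contained sketch of the AdaMax update rule from Section 7.1, written over plain [Float] arrays rather than the Tensor key-path machinery. The function and variable names (adaMaxStep, m, u) are illustrative only and are not part of this commit or the library.

import Foundation

/// Applies one AdaMax step to `parameters` in place: `m` tracks the first moment of the
/// gradients and `u` the exponentially weighted infinity norm, per Section 7.1 of the paper.
func adaMaxStep(
    parameters: inout [Float], gradients: [Float],
    m: inout [Float], u: inout [Float], step: Int,
    learningRate: Float = 0.002, beta1: Float = 0.9, beta2: Float = 0.999
) {
    // AdaMax only needs the beta1 bias correction; there is no sqrt(1 - beta2^t) term.
    let stepSize = learningRate / (1 - Float(pow(Double(beta1), Double(step))))
    for i in parameters.indices {
        m[i] = beta1 * m[i] + (1 - beta1) * gradients[i]   // first-moment estimate
        u[i] = max(beta2 * u[i], abs(gradients[i]))        // infinity-norm estimate
        parameters[i] -= stepSize * m[i] / u[i]
    }
}

The class above additionally adds `epsilon` to the denominator and applies a learning-rate `decay` factor, neither of which appears in the paper's pseudocode.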

Tests/TensorFlowTests/SequentialTests.swift (+3 lines)
@@ -32,6 +32,7 @@ final class SequentialTests: XCTestCase {
         let sgd = SGD(for: model, learningRate: 0.02)
         let rmsprop = RMSProp(for: model, learningRate: 0.02)
         let adam = Adam(for: model, learningRate: 0.02)
+        let adamax = AdaMax(for: model, learningRate: 0.02)
         let adagrad = AdaGrad(for: model, learningRate: 0.02)
         let adadelta = AdaDelta(for: model, learningRate: 0.02)
         let x: Tensor<Float> = [[0, 0], [0, 1], [1, 0], [1, 1]]
@@ -48,6 +49,8 @@
             rmsprop.update(&model.allDifferentiableVariables, along: 𝛁model)
             adam.update(&model, along: 𝛁model)
             adam.update(&model.allDifferentiableVariables, along: 𝛁model)
+            adamax.update(&model, along: 𝛁model)
+            adamax.update(&model.allDifferentiableVariables, along: 𝛁model)
             adagrad.update(&model, along: 𝛁model)
             adagrad.update(&model.allDifferentiableVariables, along: 𝛁model)
             adadelta.update(&model, along: 𝛁model)
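
Beyond the assertions above, a sketch of how the new optimizer is typically driven in a training loop may help; `trainingGradient(of:)` is a hypothetical placeholder for however the gradient 𝛁model is computed (the test differentiates a loss with respect to the model), and only the `AdaMax` construction and `update` calls come from this commit.

// Hypothetical sketch: `model` is some `Layer`-conforming model and `trainingGradient(of:)`
// stands in for differentiating a loss with respect to it; neither is defined by this commit.
let adamax = AdaMax(for: model, learningRate: 0.02)
for _ in 0..<1000 {
    let 𝛁model = trainingGradient(of: model)
    // Mutates the model's parameters in place using the AdaMax rule.
    adamax.update(&model, along: 𝛁model)
}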
