This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b6229d8
[Optimizers] Simplified the Adam optimizer. (#306)
1 parent 36769e3
3 files changed: +43, -59 lines

Sources/TensorFlow/Operators/Math.swift (+18, -2)
@@ -12,13 +12,29 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-infix operator .>: ComparisonPrecedence
-infix operator .==: ComparisonPrecedence
+infix operator .> : ComparisonPrecedence
+infix operator .== : ComparisonPrecedence
 
 // TODO:
 // - Consider explicit broadcasting for elementwise binary ops when
 //   scalarization and rank getter are implemented.
 
+// TODO: Remove the following extension once `./` and `./=` are defined for
+// `PointwiseMultiplicative`.
+
+infix operator ./ : MultiplicationPrecedence
+infix operator ./= : AssignmentPrecedence
+
+public extension PointwiseMultiplicative {
+    static func ./ (lhs: Self, rhs: Self) -> Self {
+        lhs .* rhs.reciprocal
+    }
+
+    static func ./= (lhs: inout Self, rhs: Self) {
+        lhs = lhs ./ rhs
+    }
+}
+
 //===------------------------------------------------------------------------------------------===//
 // Generic elementary functions
 //===------------------------------------------------------------------------------------------===//
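For illustration only (not part of the diff): a minimal sketch of the new operators, assuming `Tensor`'s existing `PointwiseMultiplicative` conformance in this version of the library; the values are made up.

import TensorFlow

// `Tensor` with numeric scalars conforms to `PointwiseMultiplicative`, so the new
// `./` operator performs elementwise division via multiplication by the reciprocal.
let numerator = Tensor<Float>([2, 4, 8])
var ratio = numerator ./ Tensor<Float>([1, 2, 4])  // [2.0, 2.0, 2.0]
ratio ./= Tensor<Float>([2, 2, 2])                 // [1.0, 1.0, 1.0]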

Sources/TensorFlow/Optimizer.swift (+24, -56)
@@ -12,19 +12,17 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-/// A machine learning optimizer.
+/// A numerical optimizer.
 ///
-/// Optimizers apply an optimization algorithm to update the differentiable variables of a machine
-/// learning model.
+/// Optimizers apply an optimization algorithm to update the differentiable models.
 public protocol Optimizer {
     /// The type of the model whose parameters are optimized.
     associatedtype Model: Differentiable
     /// The scalar parameter type.
     associatedtype Scalar: FloatingPoint
     /// The learning rate.
     var learningRate: Scalar { get set }
-    /// Updates the specified differentiable variables along the specified
-    /// direction.
+    /// Updates the specified differentiable variables along the specified direction.
     mutating func update(_ variables: inout Model, along direction: Model.TangentVector)
 }
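As a rough sketch of what the protocol asks of a conforming type (not part of the commit): the `NaiveSGD` class below is hypothetical, and assumes only the `VectorProtocol` requirement `scaled(by:)` and `Differentiable.move(along:)`.

import TensorFlow

// A minimal, illustrative `Optimizer` conformance: plain gradient descent.
// `Scalar` is inferred as `Float` from the `learningRate` property.
public class NaiveSGD<Model: Differentiable>: Optimizer
    where Model.TangentVector: VectorProtocol,
          Model.TangentVector.VectorSpaceScalar == Float {
    public var learningRate: Float

    public init(for model: __shared Model, learningRate: Float = 0.01) {
        self.learningRate = learningRate
    }

    public func update(_ model: inout Model, along direction: Model.TangentVector) {
        // Step against the supplied direction, scaled by the learning rate.
        model.move(along: direction.scaled(by: -learningRate))
    }
}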

@@ -38,16 +36,15 @@ fileprivate extension Tensor where Scalar: Numeric {
 ///
 /// Reference: ["Adam - A Method for Stochastic Optimization"](
 /// https://arxiv.org/abs/1412.6980v8)
-public class Adam<Model: Layer>: Optimizer
-    where Model.AllDifferentiableVariables == Model.TangentVector {
+public class Adam<Model: Differentiable>: Optimizer
+    where Model.TangentVector: VectorProtocol & PointwiseMultiplicative & ElementaryFunctions,
+          Model.TangentVector.VectorSpaceScalar == Float {
     public typealias Model = Model
     /// The learning rate.
     public var learningRate: Float
-    /// A coefficient used to calculate the first and second moments of
-    /// gradients.
+    /// A coefficient used to calculate the first and second moments of the gradients.
     public var beta1: Float
-    /// A coefficient used to calculate the first and second moments of
-    /// gradients.
+    /// A coefficient used to calculate the first and second moments of the gradients.
     public var beta2: Float
     /// A small scalar added to the denominator to improve numerical stability.
     public var epsilon: Float
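For orientation, a hedged usage sketch (not part of the diff): with the relaxed signature, `Adam` accepts any `Differentiable` model whose tangent vector has the required conformances, which the library's layer types such as `Dense` are expected to satisfy. The layer, loss, and hyperparameters below are illustrative assumptions, and the call goes through the `update(_:along:)` entry point shown further down.

import TensorFlow

// Hypothetical setup: a single dense layer optimized with the simplified Adam.
var model = Dense<Float>(inputSize: 2, outputSize: 1, activation: relu)
let optimizer = Adam(for: model, learningRate: 1e-3)

// A toy differentiable loss over the layer's parameters, used only to obtain a
// tangent vector of the right type.
let 𝛁model = gradient(at: model) { model -> Tensor<Float> in
    (model.weight * model.weight).sum() + (model.bias * model.bias).sum()
}
optimizer.update(&model, along: 𝛁model)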
@@ -56,9 +53,9 @@ public class Adam<Model: Layer>: Optimizer
     /// The current step.
     public var step: Int = 0
     /// The first moments of the weights.
-    public var firstMoments: Model.AllDifferentiableVariables
+    public var firstMoments: Model.TangentVector = .zero
     /// The second moments of the weights.
-    public var secondMoments: Model.AllDifferentiableVariables
+    public var secondMoments: Model.TangentVector = .zero
 
     public init(
         for model: __shared Model,
@@ -78,57 +75,28 @@ public class Adam<Model: Layer>: Optimizer
         self.beta2 = beta2
         self.epsilon = epsilon
         self.decay = decay
+    }
 
-        // Initialize first & second moments to be zeros of the same shape.
-        // We can't use `Model.AllDifferentiableVariables.zero` due to the
-        // interaction between Key Paths and Differentiable Arrays.
-        firstMoments = model.allDifferentiableVariables
-        secondMoments = model.allDifferentiableVariables
-        for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-            firstMoments[keyPath: kp].resetToZero()
-            secondMoments[keyPath: kp].resetToZero()
-        }
-        for kp in firstMoments.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-            firstMoments[keyPath: kp].resetToZero()
-            secondMoments[keyPath: kp].resetToZero()
-        }
+    public func update(_ model: inout Model, along direction: Model.TangentVector) {
+        update(&model.allDifferentiableVariables, along: direction)
     }
 
     // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
-    public func update(_ model: inout Model.AllDifferentiableVariables,
-                       along direction: Model.AllDifferentiableVariables) {
+    public func update(
+        _ model: inout Model.AllDifferentiableVariables,
+        along direction: Model.TangentVector
+    ) {
         step += 1
         let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
-        // Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check
-        // this expression in reasonable time" error.
+        // Note: `stepSize` and `secondMoments` are split into two lines to avoid the "compiler is
+        // unable to type-check this expression in reasonable time" error.
         var stepSize = learningRate * sqrt(1 - pow(beta2, Float(step)))
         stepSize = stepSize / (1 - pow(beta1, Float(step)))
-        // Update Float & Double Tensor variables.
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-            firstMoments[keyPath: kp] =
-                firstMoments[keyPath: kp] * beta1 + (1 - beta1) * direction[keyPath: kp]
-            secondMoments[keyPath: kp] =
-                secondMoments[keyPath: kp] * beta2 + (1 - beta2) *
-                direction[keyPath: kp] * direction[keyPath: kp]
-            model[keyPath: kp] -=
-                stepSize * firstMoments[keyPath: kp] / (sqrt(secondMoments[keyPath: kp]) + epsilon)
-        }
-        for kp in model.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-            firstMoments[keyPath: kp] =
-                firstMoments[keyPath: kp] * Double(beta1) +
-                Double((1 - beta1)) * direction[keyPath: kp]
-            secondMoments[keyPath: kp] =
-                secondMoments[keyPath: kp] * Double(beta2) + Double(1 - beta2) *
-                direction[keyPath: kp] * direction[keyPath: kp]
-            model[keyPath: kp] -=
-                Double(stepSize) * firstMoments[keyPath: kp] /
-                sqrt(secondMoments[keyPath: kp]) + Double(epsilon)
-        }
-    }
-
-    public func update(_ model: inout Model,
-                       along direction: Model.TangentVector) {
-        update(&model.allDifferentiableVariables, along: direction)
+        firstMoments = firstMoments * beta1 + direction * (1 - beta1)
+        secondMoments = secondMoments * beta2
+        secondMoments += direction .* direction * (1 - beta2)
+        let denominator = Model.TangentVector.sqrt(secondMoments) + epsilon
+        model.move(along: -stepSize * firstMoments ./ denominator)
     }
 }
 
Tests/TensorFlowTests/SequentialTests.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ final class SequentialTests: XCTestCase {
5757
adadelta.update(&model.allDifferentiableVariables, along: 𝛁model)
5858
}
5959
XCTAssertEqual(model.inferring(from: [[0, 0], [0, 1], [1, 0], [1, 1]]),
60-
[[0.47683996], [0.47683996], [0.47683996], [0.47683996]])
60+
[[0.47620767], [0.47620767], [0.47620767], [0.47620767]])
6161
}
6262

6363
static var allTests = [
