12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
/// A numerical optimizer.
///
/// Optimizers apply an optimization algorithm to update the differentiable models.
public protocol Optimizer {
    /// The type of the model whose parameters are optimized.
    associatedtype Model: Differentiable
    /// The scalar parameter type, used for hyperparameters such as the learning rate.
    associatedtype Scalar: FloatingPoint
    /// The learning rate.
    var learningRate: Scalar { get set }
    /// Updates the specified differentiable variables along the specified direction.
    mutating func update(_ variables: inout Model, along direction: Model.TangentVector)
}
30
28
@@ -38,16 +36,15 @@ fileprivate extension Tensor where Scalar: Numeric {
38
36
///
39
37
/// Reference: ["Adam - A Method for Stochastic Optimization"](
40
38
/// https://arxiv.org/abs/1412.6980v8)
41
- public class Adam < Model: Layer > : Optimizer
42
- where Model. AllDifferentiableVariables == Model . TangentVector {
39
+ public class Adam < Model: Differentiable > : Optimizer
40
+ where Model. TangentVector: VectorProtocol & PointwiseMultiplicative & ElementaryFunctions ,
41
+ Model. TangentVector. VectorSpaceScalar == Float {
43
42
public typealias Model = Model
44
43
/// The learning rate.
45
44
public var learningRate : Float
46
- /// A coefficient used to calculate the first and second moments of
47
- /// gradients.
45
+ /// A coefficient used to calculate the first and second moments of the gradients.
48
46
public var beta1 : Float
49
- /// A coefficient used to calculate the first and second moments of
50
- /// gradients.
47
+ /// A coefficient used to calculate the first and second moments of the gradients.
51
48
public var beta2 : Float
52
49
/// A small scalar added to the denominator to improve numerical stability.
53
50
public var epsilon : Float
@@ -56,9 +53,9 @@ public class Adam<Model: Layer>: Optimizer
56
53
/// The current step.
57
54
public var step : Int = 0
58
55
/// The first moments of the weights.
59
- public var firstMoments : Model . AllDifferentiableVariables
56
+ public var firstMoments : Model . TangentVector = . zero
60
57
/// The second moments of the weights.
61
- public var secondMoments : Model . AllDifferentiableVariables
58
+ public var secondMoments : Model . TangentVector = . zero
62
59
63
60
public init (
64
61
for model: __shared Model,
@@ -78,57 +75,28 @@ public class Adam<Model: Layer>: Optimizer
78
75
self . beta2 = beta2
79
76
self . epsilon = epsilon
80
77
self . decay = decay
78
+ }
81
79
82
- // Initialize first & second moments to be zeros of the same shape.
83
- // We can't use `Model.AllDifferentiableVariables.zero` due to the
84
- // interaction between Key Paths and Differentiable Arrays.
85
- firstMoments = model. allDifferentiableVariables
86
- secondMoments = model. allDifferentiableVariables
87
- for kp in firstMoments. recursivelyAllWritableKeyPaths ( to: Tensor< Float> . self ) {
88
- firstMoments [ keyPath: kp] . resetToZero ( )
89
- secondMoments [ keyPath: kp] . resetToZero ( )
90
- }
91
- for kp in firstMoments. recursivelyAllWritableKeyPaths ( to: Tensor< Double> . self ) {
92
- firstMoments [ keyPath: kp] . resetToZero ( )
93
- secondMoments [ keyPath: kp] . resetToZero ( )
94
- }
80
    /// Updates the given model along the given direction.
    ///
    /// Delegates to the `Model.AllDifferentiableVariables`-based `update(_:along:)`
    /// overload, which performs the actual Adam step.
    public func update(_ model: inout Model, along direction: Model.TangentVector) {
        update(&model.allDifferentiableVariables, along: direction)
    }
96
83
97
84
    // TODO: Deprecate this when `Differentiable.AllDifferentiableVariables` is removed.
    /// Updates the model's differentiable variables along `direction` using the Adam rule:
    /// decays the learning rate over time, maintains exponential moving averages of the
    /// gradient (`firstMoments`) and of the element-wise squared gradient (`secondMoments`),
    /// and moves the model by the bias-corrected, normalized step.
    public func update(
        _ model: inout Model.AllDifferentiableVariables,
        along direction: Model.TangentVector
    ) {
        step += 1
        // Time-based learning rate decay: effective lr = lr / (1 + decay * step).
        let learningRate = self.learningRate * 1 / (1 + decay * Float(step))
        // Note: `stepSize` and `secondMoments` are split into two lines to avoid the "compiler is
        // unable to type-check this expression in reasonable time" error.
        // Bias correction: divide by (1 - beta^step) for each moment estimate.
        var stepSize = learningRate * sqrt(1 - pow(beta2, Float(step)))
        stepSize = stepSize / (1 - pow(beta1, Float(step)))
        // Exponential moving averages of the gradient and its element-wise square.
        firstMoments = firstMoments * beta1 + direction * (1 - beta1)
        secondMoments = secondMoments * beta2
        secondMoments += direction .* direction * (1 - beta2)
        // `epsilon` is added before dividing to improve numerical stability.
        let denominator = Model.TangentVector.sqrt(secondMoments) + epsilon
        model.move(along: -stepSize * firstMoments ./ denominator)
    }
133
101
}
134
102
0 commit comments