Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 3e23041

Browse files
vballoli, t-ae, and dan-zheng
authored
Fix RAdam stepSize (#700)
`RAdam` step size was missing a factor of `learningRate / (1 - beta1Power)`. Update `RAdam` numerical correctness test. Co-authored-by: t.ae <t-ae@users.noreply.github.com> Co-authored-by: Dan Zheng <danielzheng@google.com>
1 parent a89e745 commit 3e23041

File tree

2 files changed

+36
-13
lines changed

2 files changed

+36
-13
lines changed

Sources/TensorFlow/Optimizers/MomentumBased.swift

+7-7
Original file line numberDiff line numberDiff line change
@@ -592,18 +592,18 @@ where
592592
let N_sma_inf = 2 / (1 - beta2) - 1
593593
let N_sma_t = N_sma_inf - 2 * step * beta2Power / (1 - beta2Power)
594594

595-
if N_sma_t > 5 {
595+
if N_sma_t >= 5 {
596596
// Compute bias-corrected second moments, rectification and adapted momentum.
597597
let secondMoments_h = Model.TangentVector.sqrt(secondMoments).adding(epsilon)
598-
let stepSize = sqrtf(
599-
(N_sma_t - 4) * (N_sma_t - 2) * N_sma_inf
600-
/ ((N_sma_inf - 4) * (N_sma_inf - 2) * (N_sma_t)))
598+
let stepSize =
599+
sqrtf(
600+
(N_sma_t - 4) * (N_sma_t - 2) * N_sma_inf
601+
/ ((N_sma_inf - 4) * (N_sma_inf - 2) * (N_sma_t))) * learningRate / (1 - beta1Power)
601602
model.move(
602-
along: (firstMoments ./ secondMoments_h).scaled(
603-
by: -stepSize * sqrtf(1 - beta2Power)))
603+
along: (firstMoments ./ secondMoments_h).scaled(by: -stepSize * sqrtf(1 - beta2Power)))
604604
} else {
605605
// Update with un-adapted momentum.
606-
let stepSize = self.learningRate * step / (1 - beta1Power)
606+
let stepSize = learningRate / (1 - beta1Power)
607607
model.move(along: firstMoments.scaled(by: -stepSize))
608608
}
609609
}

Tests/TensorFlowTests/OptimizerTests.swift

+29-6
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,42 @@ class OptimizerTests: XCTestCase {
1919
/// A dense layer for testing optimizer convergence.
2020
// TODO: Consider replacing users with `Dense`.
2121
struct Model: Layer {
22-
var dense1 = Dense<Float>(weight: [[0.8]], bias: [0.8], activation: identity)
22+
var dense = Dense<Float>(weight: [[0.8]], bias: [0.8], activation: identity)
2323

2424
@differentiable
2525
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
26-
dense1(input)
26+
dense(input)
2727
}
2828
}
2929

30+
/// Check expected weight and bias after updating `model` with `optimizer` `stepCount` times.
31+
///
32+
/// - Note: optimizer correctness reference implementations exist at
33+
/// `Utilities/ReferenceImplementations/optimizers.py`.
34+
func testCorrectness<Opt: Optimizer>(
35+
optimizer: Opt,
36+
model: Model,
37+
expectedWeight: Tensor<Float>,
38+
expectedBias: Tensor<Float>,
39+
stepCount: Int = 1000,
40+
file: StaticString = #file,
41+
line: UInt = #line
42+
) where Opt.Model == Model {
43+
var optimizer = optimizer
44+
var model = model
45+
let grad = Model.TangentVector(dense: .init(weight: [[0.1]], bias: [0.2]))
46+
for _ in 0..<stepCount {
47+
optimizer.update(&model, along: grad)
48+
}
49+
XCTAssertEqual(model.dense.weight, expectedWeight, file: file, line: line)
50+
XCTAssertEqual(model.dense.bias, expectedBias, file: file, line: line)
51+
}
52+
3053
/// Check that `model` converges after updating it with `optimizer` `stepCount` times.
3154
func testConvergence<Opt: Optimizer>(
3255
optimizer: Opt,
3356
model: Model,
57+
stepCount: Int = 1000,
3458
file: StaticString = #file,
3559
line: UInt = #line
3660
) where Opt.Model == Model {
@@ -40,7 +64,7 @@ class OptimizerTests: XCTestCase {
4064
.reshaped(to: [-1, 1])
4165
let y: Tensor<Float> = x + 1
4266

43-
for _ in 0..<1000 {
67+
for _ in 0..<stepCount {
4468
let grad = gradient(at: model) { model -> Tensor<Float> in
4569
let yy = model(x)
4670
return meanSquaredError(predicted: yy, expected: y)
@@ -102,7 +126,7 @@ class OptimizerTests: XCTestCase {
102126
func testRAdam() {
103127
let model = Model()
104128
let optimizer = RAdam(for: model)
105-
testConvergence(optimizer: optimizer, model: model)
129+
testConvergence(optimizer: optimizer, model: model, stepCount: 1400)
106130
}
107131

108132
/// A `Tensor<Float>` wrapper for testing optimizer numerical correctness.
@@ -207,10 +231,9 @@ class OptimizerTests: XCTestCase {
207231
let optimizer = RAdam(for: values, learningRate: 1e-3, epsilon: 1e-7)
208232
// FIXME(TF-759): Investigate large differences with Python reference implementation results:
209233
// `[ 0.46914074, -0.44463935, -0.44513944]`.
210-
// Pending fix: https://github.com/tensorflow/swift-apis/pull/700
211234
testNumericalCorrectness(
212235
optimizer: optimizer, startingValues: values,
213-
expectedValues: [ 443.81192, -443.80478, -443.85016])
236+
expectedValues: [ 0.44664007, -0.44463903, -0.45914108])
214237
}
215238

216239
static var allTests = [

0 commit comments

Comments (0)