@@ -19,18 +19,42 @@ class OptimizerTests: XCTestCase {
19
19
/// A minimal single-unit dense model used as the fixture for optimizer
/// convergence tests: one weight, one bias, identity activation.
// TODO: Consider replacing users with `Dense`.
struct Model: Layer {
    // Single 1x1 dense layer; weight and bias both start at 0.8 so every
    // optimizer test begins from the same known state.
    var dense = Dense<Float>(weight: [[0.8]], bias: [0.8], activation: identity)

    /// Applies the model to `input`.
    /// - Returns: `input * weight + bias` (identity activation).
    @differentiable
    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        return dense(input)
    }
}
29
29
30
/// Check expected weight and bias after updating `model` with `optimizer` `stepCount` times.
///
/// A fixed synthetic gradient (weight 0.1, bias 0.2) is applied on every step,
/// so the final parameters depend only on the optimizer's update rule.
///
/// - Note: optimizer correctness reference implementations exist at
///   `Utilities/ReferenceImplementations/optimizers.py`.
func testCorrectness<Opt: Optimizer>(
    optimizer: Opt,
    model: Model,
    expectedWeight: Tensor<Float>,
    expectedBias: Tensor<Float>,
    stepCount: Int = 1000,
    file: StaticString = #file,
    line: UInt = #line
) where Opt.Model == Model {
    // Mutable copies: the parameters are passed by value.
    var opt = optimizer
    var current = model
    // Constant gradient applied at every step.
    let fixedGradient = Model.TangentVector(dense: .init(weight: [[0.1]], bias: [0.2]))
    for _ in 0..<stepCount {
        opt.update(&current, along: fixedGradient)
    }
    // Report failures at the caller's location via `file`/`line`.
    XCTAssertEqual(current.dense.weight, expectedWeight, file: file, line: line)
    XCTAssertEqual(current.dense.bias, expectedBias, file: file, line: line)
}
52
+
30
53
/// Check that `model` converges after updating it with `optimizer` `stepCount` times.
31
54
func testConvergence< Opt: Optimizer > (
32
55
optimizer: Opt ,
33
56
model: Model ,
57
+ stepCount: Int = 1000 ,
34
58
file: StaticString = #file,
35
59
line: UInt = #line
36
60
) where Opt. Model == Model {
@@ -40,7 +64,7 @@ class OptimizerTests: XCTestCase {
40
64
. reshaped ( to: [ - 1 , 1 ] )
41
65
let y : Tensor < Float > = x + 1
42
66
43
- for _ in 0 ..< 1000 {
67
+ for _ in 0 ..< stepCount {
44
68
let grad = gradient ( at: model) { model -> Tensor < Float > in
45
69
let yy = model ( x)
46
70
return meanSquaredError ( predicted: yy, expected: y)
@@ -102,7 +126,7 @@ class OptimizerTests: XCTestCase {
102
126
/// Verifies that RAdam drives the test model to convergence.
/// RAdam's rectified warmup makes early steps conservative, so it needs a
/// larger step budget (1400) than the default used by the other optimizers.
func testRAdam() {
    let model = Model()
    testConvergence(optimizer: RAdam(for: model), model: model, stepCount: 1400)
}
107
131
108
132
/// A `Tensor<Float>` wrapper for testing optimizer numerical correctness.
@@ -207,10 +231,9 @@ class OptimizerTests: XCTestCase {
207
231
let optimizer = RAdam ( for: values, learningRate: 1e-3 , epsilon: 1e-7 )
208
232
// FIXME(TF-759): Investigate large differences with Python reference implementation results:
209
233
// `[ 0.46914074, -0.44463935, -0.44513944]`.
210
- // Pending fix: https://github.com/tensorflow/swift-apis/pull/700
211
234
testNumericalCorrectness (
212
235
optimizer: optimizer, startingValues: values,
213
- expectedValues: [ 443.81192 , - 443.80478 , - 443.85016 ] )
236
+ expectedValues: [ 0.44664007 , - 0.44463903 , - 0.45914108 ] )
214
237
}
215
238
216
239
static var allTests = [
0 commit comments