// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest
import TensorFlow

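// Convergence tests for the standard optimizers: each test fits a
// one-parameter linear model to the target function y = x + 1 and asserts
// that the optimizer drives the model's predictions to the targets.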
class OptimizerTests: XCTestCase {
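    // A minimal trainable model: a single dense layer with one scalar weight
    // and one scalar bias, initialized at 0.8, away from the solution
    // (weight = 1, bias = 1) for the target y = x + 1.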
    struct Model: Layer {
        var dense1 = Dense<Float>(weight: [[0.8]], bias: [0.8], activation: identity)

        @differentiable
        func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
            dense1(input)
        }
    }
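
    // Runs up to 1000 optimization steps on the mean squared error between
    // the model's predictions and the targets, asserting that the model
    // converges to y = x + 1. `file` and `line` are forwarded so that a
    // failure is reported at the caller's call site.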
    func convergenceTest<Opt: Optimizer>(
        optimizer: Opt,
        model: Model,
        file: StaticString = #file,
        line: UInt = #line
    ) where Opt.Model == Model {
        var optimizer = optimizer
        var model = model
        // Inputs sampled over [-1, 1) at a stride of 0.01, shaped as a
        // column vector; the targets lie on the line y = x + 1.
        let x: Tensor<Float> = Tensor(rangeFrom: -1, to: 1, stride: 0.01)
            .reshaped(to: [-1, 1])
        let y: Tensor<Float> = x + 1
        for _ in 0..<1000 {
            let grad = gradient(at: model) { model -> Tensor<Float> in
                let yy = model(x)
                return meanSquaredError(predicted: yy, expected: y)
            }
            optimizer.update(&model, along: grad)
            // Stop early once the predictions are close enough to the targets.
            if model(x).isAlmostEqual(to: y) {
                break
            }
        }
        XCTAssertTrue(model(x).isAlmostEqual(to: y), file: file, line: line)
    }

    func testSGD() {
        let model = Model()
        let optimizer = SGD(for: model)
        convergenceTest(optimizer: optimizer, model: model)
    }

    func testRMSProp() {
        let model = Model()
        let optimizer = RMSProp(for: model)
        convergenceTest(optimizer: optimizer, model: model)
    }

    func testAdaGrad() {
        let model = Model()
        let optimizer = AdaGrad(for: model, learningRate: 0.01)
        convergenceTest(optimizer: optimizer, model: model)
    }

    func testAdaDelta() {
        let model = Model()
        let optimizer = AdaDelta(for: model)
        convergenceTest(optimizer: optimizer, model: model)
    }

    func testAdam() {
        let model = Model()
        let optimizer = Adam(for: model)
        convergenceTest(optimizer: optimizer, model: model)
    }

    func testAdaMax() {
        let model = Model()
        let optimizer = AdaMax(for: model)
        convergenceTest(optimizer: optimizer, model: model)
    }

    func testAMSGrad() {
        let model = Model()
        let optimizer = AMSGrad(for: model)
        convergenceTest(optimizer: optimizer, model: model)
    }

    func testRAdam() {
        let model = Model()
        let optimizer = RAdam(for: model)
        convergenceTest(optimizer: optimizer, model: model)
    }

    static var allTests = [
        ("testSGD", testSGD),
        ("testRMSProp", testRMSProp),
        ("testAdaGrad", testAdaGrad),
        ("testAdaDelta", testAdaDelta),
        ("testAdam", testAdam),
        ("testAdaMax", testAdaMax),
        ("testAMSGrad", testAMSGrad),
        ("testRAdam", testRAdam),
    ]
}
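
// Note: `allTests` exists so the suite can run on Linux, where XCTest does
// not discover tests at runtime. A minimal sketch of the corresponding entry
// point is shown below (hypothetical file and module names; the actual
// LinuxMain.swift lives elsewhere in the package):
//
//     import XCTest
//     @testable import TensorFlowTests  // assumed test module name
//
//     XCTMain([testCase(OptimizerTests.allTests)])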