-
Notifications
You must be signed in to change notification settings - Fork 10.5k
/
Copy pathsimple_math.swift
76 lines (67 loc) · 1.79 KB
/
simple_math.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
#if os(macOS)
import Darwin.C
#else
import Glibc
#endif
// Test suite that every `SimpleMathTests.test(...)` registration below is added
// to; `runAllTests()` at the end of the file executes them.
// `TestSuite` is a reference type in StdlibUnittest and this binding is never
// reassigned, so a constant binding is sufficient.
let SimpleMathTests = TestSuite("SimpleMath")
SimpleMathTests.test("Arithmetics") {
  // Each expected pair is (d/dx, d/dy) evaluated at (x, y) = (3, 4).
  func product(_ x: Float, _ y: Float) -> Float {
    return x * y
  }
  func negatedProduct(_ x: Float, _ y: Float) -> Float {
    return -x * y
  }
  func negatedSum(_ x: Float, _ y: Float) -> Float {
    return -x + y
  }

  // d(x*y) = (y, x)
  expectEqual((4, 3), gradient(at: 3, 4, in: product))
  // d(-x*y) = (-y, -x)
  expectEqual((-4, -3), gradient(at: 3, 4, in: negatedProduct))
  // d(-x+y) = (-1, 1)
  expectEqual((-1, 1), gradient(at: 3, 4, in: negatedSum))
}
SimpleMathTests.test("Fanout") {
  // Every function below uses an input more than once; the expected gradients
  // are the sum of the contributions from each use.
  func cancelling(_ x: Float) -> Float {
    return x - x
  }
  func doubled(_ x: Float) -> Float {
    return x + x
  }
  func mixed(_ x: Float, _ y: Float) -> Float {
    return x + x + x * y
  }

  // 1 - 1 = 0
  expectEqual(0, gradient(at: 100, in: cancelling))
  // 1 + 1 = 2
  expectEqual(2, gradient(at: 100, in: doubled))
  // d/dx = 1 + 1 + y = 4, d/dy = x = 3 at (3, 2)
  expectEqual((4, 3), gradient(at: 3, 2, in: mixed))
}
SimpleMathTests.test("FunctionCall") {
  // The weight on `y` is produced by calling a closure literal inside the
  // differentiated function; it evaluates to 3 * 3 = 9 and depends on neither
  // input, so the gradient is (3, 9).
  func weightedSum(_ x: Float, _ y: Float) -> Float {
    let yWeight: Float = { $0 * 3 }(3)
    return 3 * x + yWeight * y
  }
  expectEqual((3, 9), gradient(at: 3, 4, in: weightedSum))

  // Holding y at 4 leaves a one-variable function whose derivative is 3.
  expectEqual(3, gradient(at: 3, in: { x in weightedSum(x, 4) }))
}
SimpleMathTests.test("ResultSelection") {
  // `shifted` returns a pair; each check differentiates only one element of
  // it, so the partial derivative with respect to the other input is 0.
  func shifted(_ x: Float, _ y: Float) -> (Float, Float) {
    return (x + 1, y + 2)
  }

  let gradientOfFirst: (Float, Float) =
    gradient(at: 3, 3, in: { x, y in shifted(x, y).0 })
  expectEqual((1, 0), gradientOfFirst)

  let gradientOfSecond: (Float, Float) =
    gradient(at: 3, 3, in: { x, y in shifted(x, y).1 })
  expectEqual((0, 1), gradientOfSecond)
}
SimpleMathTests.test("CaptureLocal") {
  // `scaled` captures a local constant; d(scale * x)/dx is that constant.
  let scale: Float = 10
  func scaled(_ x: Float) -> Float {
    return scale * x
  }
  expectEqual(scale, gradient(at: 0, in: scaled))
}
// Module-level mutable state read and written by the "CaptureGlobal" test.
var globalVar: Float = 10
SimpleMathTests.test("CaptureGlobal") {
  // The differentiated closure below mutates `globalVar` every time it runs.
  // Reset it first so the expected gradient does not depend on whether this
  // test body (or anything else touching `globalVar`) has already executed;
  // without this, a second run would observe 50 instead of 30.
  globalVar = 10
  let foo: (Float) -> Float = { x in
    globalVar += 20
    return globalVar * x
  }
  // Assumes `gradient` evaluates `foo` once: globalVar becomes 10 + 20 = 30,
  // and d(globalVar * x)/dx == 30.
  expectEqual(30, gradient(at: 0, in: foo))
}
// Runs the registered tests. Kept as the final top-level statement so every
// `SimpleMathTests.test(...)` registration above has already executed.
runAllTests()