-
Notifications
You must be signed in to change notification settings - Fork 10.4k
/
Copy pathcustom_derivatives.swift
73 lines (64 loc) · 1.98 KB
/
custom_derivatives.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
#if os(macOS)
import Darwin.C
#else
import Glibc
#endif
// Suite that every `CustomDerivativesTests.test(...)` case in this file is registered on.
var CustomDerivativesTests = TestSuite("CustomDerivatives")
// Specify non-differentiable functions.
// These will be wrapped in `differentiableFunction` and tested.
/// Returns twice `x`.
///
/// The result is produced by mutating a local copy in place rather than by a
/// single expression; the file groups this helper under "non-differentiable
/// functions", so the mutation-based form is kept deliberately.
func unary(_ x: Float) -> Float {
  var doubled = x
  doubled *= 2
  return doubled
}
/// Returns the product of `x` and `y`.
///
/// The result is produced by mutating a local copy in place rather than by a
/// single expression; the file groups this helper under "non-differentiable
/// functions", so the mutation-based form is kept deliberately.
func binary(_ x: Float, _ y: Float) -> Float {
  var product = x
  product *= y
  return product
}
CustomDerivativesTests.test("differentiableFunction-unary") {
  // Pair `unary` with a hand-written pullback.
  // NOTE(review): `unary` computes `2 * x`, whose derivative is the constant 2,
  // while this pullback scales by `2 * x` (the derivative of `x * x`). The
  // expectations below pin the custom pullback rather than the analytic
  // derivative of `unary` -- confirm the mismatch is intentional.
  let wrappedUnary = differentiableFunction { input in
    (value: unary(input), pullback: { upstream in upstream * input * 2 })
  }

  // Gradient of the wrapped function itself: 1 * 10 * 2.
  expectEqual(20, gradient(at: 10, in: wrappedUnary))

  // Gradient of closures that call the wrapped function.
  expectEqual(20, gradient(at: 10, in: { wrappedUnary($0) }))
  // Scaling the result by 2 doubles the value fed into the pullback.
  expectEqual(40, gradient(at: 10, in: { wrappedUnary($0) * 2 }))
}
CustomDerivativesTests.test("differentiableFunction-binary") {
  // Pair `binary` (a product) with a hand-written pullback: each argument's
  // component is the incoming value times the other argument.
  let wrappedBinary = differentiableFunction { (lhs, rhs) in
    (value: binary(lhs, rhs), pullback: { upstream in (upstream * rhs, upstream * lhs) })
  }

  // Gradient of the wrapped function itself at (5, 10): (1 * 10, 1 * 5).
  expectEqual((10, 5), gradient(at: 5, 10, in: wrappedBinary))

  // Gradient of closures that call the wrapped function.
  expectEqual((10, 5), gradient(at: 5, 10, in: { wrappedBinary($0, $1) }))
  // Scaling the result by 2 doubles both components.
  expectEqual((20, 10), gradient(at: 5, 10, in: { wrappedBinary($0, $1) * 2 }))
}
CustomDerivativesTests.test("Checkpointing") {
  // Number of times `cube` has been evaluated so far.
  var evaluations = 0

  // Cubes its argument and records the call in `evaluations`.
  func cube(_ x: Float) -> Float {
    evaluations += 1
    return x * x * x
  }

  // Free-function form of the API.
  // The differentiated closure computes 3 * x * x^3 = 3 * x^4, so its gradient
  // at x = 3 is 12 * 3^3 = 324.
  expectEqual(324, gradient(at: 3) { (x: Float) -> Float in
    expectEqual(0, evaluations)
    let cubed = withRecomputationInPullbacks(cube)(x)
    // Exactly one evaluation has happened once the wrapped call returns.
    expectEqual(1, evaluations)
    return cubed * 3 * x
  })
  // By the time `gradient` returns, `cube` has been evaluated a second time.
  expectEqual(2, evaluations)

  // Start counting from zero again and repeat with the method form of the API.
  evaluations = 0
  expectEqual(324, gradient(at: 3) { (x: Float) -> Float in
    expectEqual(0, evaluations)
    let cubed = x.withRecomputationInPullbacks(cube)
    expectEqual(1, evaluations)
    return cubed * 3 * x
  })
  expectEqual(2, evaluations)
}
// Script entry point: run the tests registered above.
runAllTests()