-
Notifications
You must be signed in to change notification settings - Fork 10.5k
/
Copy pathderivative_registration.swift
75 lines (65 loc) · 2.02 KB
/
derivative_registration.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// RUN: %target-run-simple-swift
// REQUIRES: executable_test
import StdlibUnittest
// Suite that every `DerivativeRegistrationTests.test(...)` case in this file
// adds itself to.
var DerivativeRegistrationTests = TestSuite("DerivativeRegistration")
/// Identity function on `Float`.
///
/// Marked with the `autodiff.opaque` semantics attribute. Presumably this keeps
/// the differentiation transform from deriving through the body, so `gradient`
/// has to use the derivative registered by `_vjpUnary`. Exercising that path is
/// the point of this test.
@_semantics("autodiff.opaque")
func unary(x: Float) -> Float {
  let result = x
  return result
}
/// Registered derivative of `unary`: its value plus a pullback.
///
/// `unary` is the identity, whose derivative is 1, so the pullback hands the
/// incoming cotangent back unchanged.
@differentiating(unary)
func _vjpUnary(x: Float) -> (value: Float, pullback: (Float) -> Float) {
  let pullback: (Float) -> Float = { seed in seed }
  return (value: x, pullback: pullback)
}
// The identity has slope 1 everywhere, so the gradient at 3 must be 1.
DerivativeRegistrationTests.test("UnaryFreeFunction") {
  let slope = gradient(at: 3.0, in: unary)
  expectEqual(1, slope)
}
/// Returns the product `x * y`.
///
/// Marked `autodiff.opaque`, presumably so that differentiation must go through
/// the registered `_vjpMultiply(_:_:)` rather than the body.
@_semantics("autodiff.opaque")
func multiply(_ x: Float, _ y: Float) -> Float {
  let product = x * y
  return product
}
/// Registered derivative of the free function `multiply(_:_:)`.
///
/// Returns the product together with a pullback that maps a cotangent `seed`
/// to `(seed * y, seed * x)`. These are the partials of `x * y` with respect to
/// `x` and `y`, each scaled by the cotangent.
@differentiating(multiply)
func _vjpMultiply(_ x: Float, _ y: Float)
  -> (value: Float, pullback: (Float) -> (Float, Float)) {
  func pullback(_ seed: Float) -> (Float, Float) {
    return (seed * y, seed * x)
  }
  return (value: x * y, pullback: pullback)
}
// At (x, y) = (2, 3): d(x*y)/dx = y = 3 and d(x*y)/dy = x = 2.
DerivativeRegistrationTests.test("BinaryFreeFunction") {
  let gradients = gradient(at: 2.0, 3.0, in: { x, y in multiply(x, y) })
  expectEqual((3.0, 2.0), gradients)
}
/// Minimal `Differentiable` wrapper around a single `Float`.
///
/// Used to test derivative registration for static and instance methods that
/// are declared in extensions. `Wrapper.TangentVector` (built elsewhere in this
/// file as `Wrapper.TangentVector(float:)`) is not declared here. It is
/// presumably synthesized by the compiler from the stored property.
struct Wrapper : Differentiable {
  var float: Float
}
extension Wrapper {
  /// Static method that returns the product `x * y`.
  ///
  /// Marked `autodiff.opaque`, presumably so that differentiation must use the
  /// registered `_vjpMultiply(_:_:)` below instead of the body.
  @_semantics("autodiff.opaque")
  static func multiply(_ x: Float, _ y: Float) -> Float {
    let product = x * y
    return product
  }

  /// Registered derivative of the static `Wrapper.multiply(_:_:)`.
  ///
  /// The partials of `x * y` are `y` and `x`. The pullback scales each one by
  /// the incoming cotangent `seed`.
  @differentiating(multiply)
  static func _vjpMultiply(_ x: Float, _ y: Float)
    -> (value: Float, pullback: (Float) -> (Float, Float)) {
    let product = x * y
    let pullback: (Float) -> (Float, Float) = { seed in
      (seed * y, seed * x)
    }
    return (product, pullback)
  }
}
// Same expectations as the free-function case, routed through the static method:
// at (2, 3) the partials of x*y are (3, 2).
DerivativeRegistrationTests.test("StaticMethod") {
  let gradients = gradient(at: 2.0, 3.0, in: { x, y in Wrapper.multiply(x, y) })
  expectEqual((3.0, 2.0), gradients)
}
extension Wrapper {
  /// Instance method that returns `float * x`.
  ///
  /// Marked `autodiff.opaque`, presumably so that differentiation must use the
  /// registered `_vjpMultiply(_:)` below instead of the body.
  @_semantics("autodiff.opaque")
  func multiply(_ x: Float) -> Float {
    let product = float * x
    return product
  }

  /// Registered derivative of the instance method `multiply(_:)`.
  ///
  /// The pullback returns the cotangent for `self` first and for `x` second.
  /// The partial with respect to `float` is `x`, and with respect to `x` it is
  /// `float`.
  @differentiating(multiply)
  func _vjpMultiply(_ x: Float)
    -> (value: Float, pullback: (Float) -> (Wrapper.TangentVector, Float)) {
    // `self` is an immutable struct value here, so copying `float` up front is
    // equivalent to reading `self.float` inside the closure.
    let scale = self.float
    return (scale * x, { seed in
      (Wrapper.TangentVector(float: seed * x), seed * scale)
    })
  }
}
// For f(w, x) = w.float * x at w.float = 3, x = 2:
// df/d(w.float) = x = 2 and df/dx = w.float = 3.
DerivativeRegistrationTests.test("InstanceMethod") {
  let wrapper = Wrapper(float: 3)
  let x: Float = 2
  let (wrapperGradient, xGradient) = wrapper.gradient(at: x) { w, input in
    w.multiply(input)
  }
  expectEqual(Wrapper.TangentVector(float: 2), wrapperGradient)
  expectEqual(3, xGradient)
}
// StdlibUnittest entry point; presumably runs every test case registered
// above on `DerivativeRegistrationTests`.
runAllTests()