-
Notifications
You must be signed in to change notification settings - Fork 10.5k
/
Copy pathderivative_symbols.swift
151 lines (124 loc) · 3.87 KB
/
derivative_symbols.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=all %s -O
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing
// RUN: %target-swift-frontend -emit-ir -o/dev/null -parse-as-library -module-name test -validate-tbd-against-ir=missing %s -enable-testing -O
import _Differentiation
// Free function marked `@differentiable(reverse)` with no `@derivative(of:)`
// function registered for it in this file, so its derivative must come from
// the compiler rather than from user code.
// Returns `x` unchanged; `y` is accepted but unused.
@differentiable(reverse)
public func topLevelDifferentiable(_ x: Float, _ y: Float) -> Float { x }
// Generic identity function with no `@differentiable` attribute of its own.
// It is made differentiable only by the `@derivative(of:)` registration in
// `topLevelDerivative` that follows it.
public func topLevelHasDerivative<T: Differentiable>(_ x: T) -> T {
x
}
// User-defined reverse-mode derivative (it returns a `pullback`) registered
// for the generic `topLevelHasDerivative`. The pullback maps a cotangent of
// the result to a cotangent of `x`.
// The body is `fatalError()`: this file is only lowered to IR (output sent to
// /dev/null) to compare emitted symbols against the TBD, so it never runs.
@derivative(of: topLevelHasDerivative)
public func topLevelDerivative<T: Differentiable>(_ x: T) -> (
value: T, pullback: (T.TangentVector) -> T.TangentVector
) {
fatalError()
}
// Differentiable struct whose members cover the declaration kinds that can
// carry derivative symbols: property accessors, initializers, methods, and
// subscript accessors. No `TangentVector` is declared in the body, so the
// `Differentiable` conformance is compiler-synthesized.
// Derivative bodies are `fatalError()`: this file is only lowered to IR for
// TBD validation and is never executed.
public struct Struct: Differentiable {
// Internal (non-public) storage read and written by the accessors below.
var stored: Float
// Test property: getter and setter.
// Each accessor carries its own `@differentiable(reverse)` attribute instead
// of the attribute being placed on the property declaration.
public var property: Float {
@differentiable(reverse)
get { stored }
@differentiable(reverse)
set { stored = newValue }
}
// Test initializer.
// Designated initializer; applies `squareRoot()` to `x` rather than storing
// it directly.
@differentiable(reverse)
public init(_ x: Float) {
stored = x.squareRoot()
}
// Test delegating initializer.
// Differentiable initializer that delegates to `init(_:)` instead of
// assigning `stored` itself.
@differentiable(reverse)
public init(blah x: Float) {
self.init(x)
}
// Test method.
// `method` has no `@differentiable` attribute; `jvpMethod` registers a
// forward-mode derivative for it (it returns a `differential`). The
// differential maps tangents of (self, x, y) to the tangent of the result.
public func method(_ x: Float, _ y: Float) -> Float { x }
@derivative(of: method)
public func jvpMethod(_ x: Float, _ y: Float) -> (
value: Float, differential: (TangentVector, Float, Float) -> Float
) {
fatalError()
}
// Test subscript: getter and setter.
// The getter returns the index argument `x`; the setter writes `stored`.
public subscript(_ x: Float) -> Float {
@differentiable(reverse)
get { x }
@differentiable(reverse)
set { stored = newValue }
}
// User-defined reverse-mode derivative for the subscript getter. Its pullback
// maps the result cotangent to cotangents for (self, x).
@derivative(of: subscript)
public func vjpSubscript(_ x: Float) -> (
value: Float, pullback: (Float) -> (TangentVector, Float)
) {
fatalError()
}
// User-defined reverse-mode derivative for the subscript setter. It is
// `mutating` like the setter it differentiates and produces `()` as its value.
// Its pullback updates self's tangent `inout` and returns cotangents for
// (x, newValue).
@derivative(of: subscript.set)
public mutating func vjpSubscriptSetter(_ x: Float, _ newValue: Float) -> (
value: (), pullback: (inout TangentVector) -> (Float, Float)
) {
fatalError()
}
}
// Differentiable method declared in a constrained extension of a standard
// library generic type (`Array` with `Element == Struct`). The method has no
// parameters besides `self`, so `self` is its only differentiability
// parameter. It always returns 0.
// NOTE(review): presumably this covers derivative-symbol mangling for
// constrained-extension members — confirm against the originating change.
extension Array where Element == Struct {
@differentiable(reverse)
public func sum() -> Float {
return 0
}
}
// SR-13866: Dispatch thunks and method descriptor mangling.
// Public protocol whose requirements carry `@differentiable` attributes.
// `method` and `subscript` each have two attributes with different `wrt:`
// parameter sets (`self` alone, and `self` plus `x`), so each of those
// requirements has more than one differentiability configuration. `property`
// has one. Presumably every configuration needs its own, distinctly mangled,
// dispatch-thunk and method-descriptor symbols (the subject of SR-13866).
public protocol P: Differentiable {
@differentiable(reverse, wrt: self)
@differentiable(reverse, wrt: (self, x))
func method(_ x: Float) -> Float
@differentiable(reverse, wrt: self)
var property: Float { get set }
@differentiable(reverse, wrt: self)
@differentiable(reverse, wrt: (self, x))
subscript(_ x: Float) -> Float { get set }
}
// Final-class counterpart of `Struct`. Several declarations that `Struct`
// exercises are disabled here (commented out) pending the issues named in the
// FIXME references: differentiable initializers (rdar://74380324) and the
// differentiable subscript setter plus its derivative (SR-13096).
// Derivative bodies are `fatalError()`: this file is only lowered to IR for
// TBD validation and is never executed.
public final class Class: Differentiable {
// Internal (non-public) storage assigned by the initializer.
var stored: Float
// Test initializer.
// Currently a plain (non-differentiable) initializer: its attribute is
// disabled by the FIXME below.
// FIXME(rdar://74380324)
// @differentiable(reverse)
public init(_ x: Float) {
stored = x
}
// Test delegating initializer.
// Entirely disabled (attribute and declaration) by the FIXME below.
// FIXME(rdar://74380324)
// @differentiable(reverse)
// public convenience init(blah x: Float) {
// self.init(x)
// }
// Test method.
// `method` has no `@differentiable` attribute; `jvpMethod` registers a
// forward-mode derivative for it (it returns a `differential`). The
// differential maps tangents of (self, x, y) to the tangent of the result.
public func method(_ x: Float, _ y: Float) -> Float { x }
@derivative(of: method)
public func jvpMethod(_ x: Float, _ y: Float) -> (
value: Float, differential: (TangentVector, Float, Float) -> Float
) {
fatalError()
}
// Test subscript: getter and setter.
// Only the getter is active, so the subscript is currently get-only; the
// setter is disabled by the FIXME below.
public subscript(_ x: Float) -> Float {
@differentiable(reverse)
get { x }
// FIXME(SR-13096)
// @differentiable(reverse)
// set { stored = newValue }
}
// User-defined reverse-mode derivative for the subscript getter. Its pullback
// maps the result cotangent to cotangents for (self, x).
@derivative(of: subscript)
public func vjpSubscript(_ x: Float) -> (
value: Float, pullback: (Float) -> (TangentVector, Float)
) {
fatalError()
}
// Setter derivative disabled together with the setter itself.
// FIXME(SR-13096)
// @derivative(of: subscript.set)
// public func vjpSubscriptSetter(_ x: Float, _ newValue: Float) -> (
// value: (), pullback: (inout TangentVector) -> (Float, Float)
// ) {
// fatalError()
// }
}