-
Notifications
You must be signed in to change notification settings - Fork 10.4k
/
Copy pathderivative_attr_parse.swift
134 lines (106 loc) · 3.78 KB
/
derivative_attr_parse.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// RUN: %target-swift-frontend -parse -verify %s
// MARK: - Well-formed '@derivative' attributes: each must parse with no diagnostics.
// A `wrt:` clause naming a single parameter without parentheses.
@derivative(of: sin, wrt: x) // ok
func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) {
  (x, { v in v })
}
// A `wrt:` clause listing several parameters inside parentheses.
@derivative(of: add, wrt: (x, y)) // ok
func vjpAdd(x: Float, y: Float)
  -> (value: Float, pullback: (Float) -> (Float, Float)) {
  let sum = x + y
  return (sum, { v in (v, v) })
}
extension AdditiveArithmetic where Self: Differentiable {
  // An operator (`+`) as the original name, with no `wrt:` clause, on a
  // static method inside a constrained protocol extension.
  @derivative(of: +) // ok
  static func vjpAdd(x: Self, y: Self)
    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
    (x + y, { ($0, $0) })
  }
}
// No `wrt:` clause; the result tuple names a `differential:` rather than a `pullback:`.
@derivative(of: foo) // ok
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
  (x, { dx in dx })
}
// Getter of `property`, selected with the `.get` accessor suffix. The RUN line
// uses `-parse`, so this body-less signature is never checked against the original.
@derivative(of: property.get) // ok
func dPropertyGetter() -> Void
// Getter of a subscript referenced without argument labels.
@derivative(of: subscript.get) // ok
func dSubscriptGetter() -> Void
// Getter of a subscript referenced with its argument labels, `subscript(_:label:)`.
@derivative(of: subscript(_:label:).get) // ok
func dLabeledSubscriptGetter() -> Void
// Setter of `property`, selected with the `.set` accessor suffix.
@derivative(of: property.set) // ok
func dPropertySetter() -> Void
// Setter of a subscript referenced without argument labels.
@derivative(of: subscript.set) // ok
func dSubscriptSetter() -> Void
// Setter of a subscript referenced with its argument labels, `subscript(_:label:)`.
@derivative(of: subscript(_:label:).set) // ok
func dLabeledSubscriptSetter() -> Void
// A dotted (qualified) original name, `nestedType.name`, with no accessor suffix.
@derivative(of: nestedType.name) // ok
func dNestedTypeFunc() -> Void
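// MARK: - Malformed '@derivative' attributes: each case lists, in line-relative
// directive comments, the parser diagnostics it must produce.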
// An integer literal in the `of:` position is not a valid original-function name.
// The directives below are line-relative (`@+N`); keep them directly above the
// attribute they target.
// expected-error @+2 {{expected an original function name}}
// expected-error @+1 {{expected declaration}}
@derivative(of: 3)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// `wrt` sits in the `of:` position and the following argument `foo` carries no
// label, so the parser reports a missing `wrt:` label.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: wrt, foo)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// The `wrt` label appears without its trailing colon.
// expected-error @+1 {{expected a colon ':' after 'wrt'}}
@derivative(of: foo, wrt)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// An unlabeled extra argument (`blah`) sits between `of:` and `wrt:`.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: foo, blah, wrt: x)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// An extra argument (`blah`) follows the `wrt:` clause; the attribute must close
// right after `wrt:`.
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo, wrt: x, blah)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// A trailing comma after the `of:` argument with nothing following it.
// expected-error @+1 {{unexpected ',' separator}}
@derivative(of: foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// A trailing comma after the `wrt:` clause.
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo, wrt: x,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// A second unlabeled argument (then a trailing comma) appears where `wrt:` is required.
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
@derivative(of: foo, foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// A trailing comma after the `of:` argument with nothing following it.
// NOTE(review): this case is textually identical to the earlier
// `@derivative(of: foo,)` case in this file and adds no extra coverage; it may be
// an accidental duplicate. Confirm before removing or replacing it.
// expected-error @+1 {{unexpected ',' separator}}
@derivative(of: foo,)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// TF-1168: missing comma before `wrt:`.
// Regression case: the `of:` argument and the `wrt:` clause are not separated
// by a comma, and the parser must ask for the `,` separator.
// expected-error @+2 {{expected ',' separator}}
// expected-error @+1 {{expected declaration}}
@derivative(of: foo wrt: x)
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}
// `@derivative` on a function declared inside another function's body is
// rejected; the attribute is only accepted in a non-local scope.
func testLocalDerivativeRegistration() {
// expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
@derivative(of: sin)
func dsin()
}
// `@derivative` on a function declared inside another function's body is
// rejected; the attribute is only accepted in a non-local scope.
// NOTE(review): this function duplicates the preceding
// `testLocalDerivativeRegistration` exactly (same name, same body). The RUN line
// uses `-parse`, so no redeclaration check is expected to run. Confirm whether a
// distinct local-scope case (e.g. a different nesting) was intended here.
func testLocalDerivativeRegistration() {
// expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
@derivative(of: sin)
func dsin()
}
// An accessor suffix (`.set`) after a multi-component dotted name
// (`nestedType.name`) is currently rejected; the directives below pin the
// resulting diagnostics.
// expected-error @+2 {{expected ',' separator}}
// expected-error @+1 {{expected declaration}}
@derivative(of: nestedType.name.set)
func dNestedTypePropertySetter() -> ()