-
Notifications
You must be signed in to change notification settings - Fork 10.4k
/
Copy pathgenerics.swift
44 lines (37 loc) · 1.77 KB
/
generics.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
// RUN: %target-swift-frontend -emit-sil -verify %s
/// Minimal tensor-like fixture shared by the differentiation test cases in
/// this file.
///
/// `Scalar` appears only as a generic parameter; the stored payload is a
/// single `Float`. No `VectorNumeric`/`Differentiable` members are written
/// out here, so those conformances are presumably compiler-synthesized.
struct Tensor<Scalar : FloatingPoint & Differentiable> : VectorNumeric, Differentiable {
  // NOTE: `value` must have type with known size (e.g. `Float`, not `Scalar`)
  // until differentiation has indirect passing support.
  var value: Float

  /// Stores `value` as the tensor's only element.
  init(_ value: Float) {
    self.value = value
  }
}
/// Generic function over `Tensor<T>` whose body reads only the non-generic
/// `Float` payload, returning `x.value + x.value`.
/// Its exact body is the input to the differentiation pass under test.
func generic<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Float {
return x.value + x.value
}
// Differentiate the generic function at a concrete `Tensor<Float>` argument.
// This line has no `expected-*` annotation, so under `-verify` it must compile
// without diagnostics.
_ = gradient(at: Tensor<Float>(1), in: generic)
// Test case where the associated derivative function's requirements are unmet.

/// Identity function on `Tensor<T>`.
///
/// Its registered VJP, `vjpWeirdExtraRequirements`, applies only when
/// `T : CaseIterable` and `T.AllCases : ExpressibleByStringLiteral`. Those
/// requirements go beyond the `FloatingPoint & Differentiable` bound that
/// `weird` itself declares.
@differentiable(vjp: vjpWeirdExtraRequirements where T : CaseIterable, T.AllCases : ExpressibleByStringLiteral)
func weird<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Tensor<T> {
return x
}
/// VJP registered for `weird`.
///
/// Returns the original result unchanged, paired with an identity pullback.
/// Its `where` clause repeats the extra requirements named in the
/// `@differentiable(vjp:)` attribute on `weird`.
func vjpWeirdExtraRequirements<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) where T : CaseIterable, T.AllCases : ExpressibleByStringLiteral {
  // Pullback of the identity: the incoming cotangent passes through as-is.
  let identityPullback: (Tensor<T>) -> Tensor<T> = { seed in seed }
  return (x, identityPullback)
}
/// Calls `weird` from a context where `T` is only known to be
/// `FloatingPoint & Differentiable`.
///
/// The extra requirements on `weird`'s registered VJP are therefore not met
/// here, and `-verify` expects the note below on this call.
func weirdWrapper<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Tensor<T> {
return weird(x) // expected-note {{function call is not differentiable because generic requirements are not met}}
}
// Differentiating `weirdWrapper` is expected to fail, with the error reported
// at this first `pullback` call.
_ = pullback(at: Tensor<Float>(1), in: weirdWrapper) // expected-error {{function is not differentiable}}
// No `expected-*` annotation on this second call. Under `-verify`, any
// diagnostic here would fail the test.
// NOTE(review): presumably this checks that the failure is diagnosed only
// once per function; confirm.
_ = pullback(at: Tensor<Float>(3), in: weirdWrapper)
// Test case where the associated derivative function's requirements are met.

// NOTE(review): `Scalar : Numeric` is already implied by the `FloatingPoint`
// bound on `Tensor`. It is presumably kept to exercise constrained-extension
// handling; confirm before simplifying.
extension Tensor where Scalar : Numeric {
/// Returns `self` unchanged.
/// Differentiable with respect to `self` when
/// `Scalar : Differentiable & FloatingPoint`.
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
func mean() -> Tensor {
return self
}
/// Returns `mean()`.
/// The `where` clause matches the one on `mean()`, so the call below is
/// expected to differentiate without diagnostics (marked `// ok`).
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
func variance() -> Tensor {
return mean() // ok
}
}
// Differentiate through `variance()`, which calls `mean()`. This line has no
// `expected-*` annotation, so it must compile without diagnostics.
_ = pullback(at: Tensor<Float>(1), in: { $0.variance() })
// TODO: add more tests.