-
Notifications
You must be signed in to change notification settings - Fork 10.5k
/
Copy pathwitness_method_autodiff.sil
48 lines (39 loc) · 2.59 KB
/
witness_method_autodiff.sil
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// RUN: %target-sil-opt -differentiation -assume-parsing-unqualified-ownership-sil %s | %FileCheck %s
// This test feeds the SIL below through sil-opt with the differentiation pass
// enabled and pipes the printed result into FileCheck, which matches it
// against the pattern comments placed after each function.
sil_stage raw
import Builtin
import Swift
import SwiftShims
// Protocol with a single requirement `f`, taking and returning `Float`.
// The functions in this file look `f` up through `witness_method`.
protocol DiffReq {
// Marks the requirement as differentiable; `wrt: (.0)` presumably selects the
// first parameter (`x`), matching the `[wrt 0]` used by the SIL functions in
// this file.
@differentiable(wrt: (.0))
func f(_ x: Float) -> Float
}
// Test input: forms an `autodiff_function` directly from an unapplied
// `witness_method` reference to `DiffReq.f`. The patterns that follow this
// function expect the pass to add matching `jvp` and `vjp` `witness_method`
// references and attach them to the `autodiff_function` instruction.
sil @differentiateWitnessMethod : $@convention(thin) <T where T : DiffReq> (@in_guaranteed T) -> () {
bb0(%0 : $*T):
// Look up the witness for requirement `f` on the generic type T.
%1 = witness_method $T, #DiffReq.f!1 : <Self where Self : DiffReq> (Self) -> (Float) -> Float : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float
// Wrap the witness reference in an order-1 `autodiff_function` with respect to
// parameter 0. The result %2 is not used again in this function.
%2 = autodiff_function [wrt 0] [order 1] %1 : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float
// The function returns the empty tuple.
%ret = tuple ()
return %ret : $()
}
// CHECK-LABEL: sil @differentiateWitnessMethod
// CHECK: [[ORIG_REF:%.*]] = witness_method $T, #DiffReq.f!1
// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!1.jvp.1.SU
// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!1.vjp.1.SU
// CHECK: autodiff_function [wrt 0] [order 1] [[ORIG_REF]] : {{.*}} with {[[JVP_REF]] : {{.*}}, [[VJP_REF]] : {{.*}}}
// CHECK: } // end sil function 'differentiateWitnessMethod'
// Test input: same as the unapplied case, except the witness is first
// partially applied to the `self` argument before being differentiated. The
// patterns that follow this function expect the pass to emit `jvp` and `vjp`
// `witness_method` references, each partially applied to the same `%0`, and to
// attach those closures to the `autodiff_function` instruction.
sil @differentiatePartiallyAppliedWitnessMethod : $@convention(thin) <T where T : DiffReq> (@in_guaranteed T) -> () {
bb0(%0 : $*T):
// Look up the witness for requirement `f` on the generic type T.
%1 = witness_method $T, #DiffReq.f!1 : <Self where Self : DiffReq> (Self) -> (Float) -> Float : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float
// Bind the `@in_guaranteed T` argument %0, producing a
// `@callee_guaranteed (Float) -> Float` closure.
%2 = partial_apply [callee_guaranteed] %1<T>(%0) : $@convention(witness_method: DiffReq) <τ_0_0 where τ_0_0 : DiffReq> (Float, @in_guaranteed τ_0_0) -> Float
// Differentiate the closure (order 1) with respect to its parameter 0. The
// result %3 is not used again in this function.
%3 = autodiff_function [wrt 0] [order 1] %2 : $@callee_guaranteed (Float) -> Float
// The function returns the empty tuple.
%ret = tuple ()
return %ret : $()
}
// CHECK-LABEL: sil @differentiatePartiallyAppliedWitnessMethod
// CHECK: [[ORIG_REF:%.*]] = witness_method $T, #DiffReq.f!1
// CHECK: [[ORIG_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[ORIG_REF]]<T>(%0)
// CHECK: [[JVP_REF:%.*]] = witness_method $T, #DiffReq.f!1.jvp.1.SU
// CHECK: [[JVP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[JVP_REF]]<T>(%0)
// CHECK: [[VJP_REF:%.*]] = witness_method $T, #DiffReq.f!1.vjp.1.SU
// CHECK: [[VJP_REF_PARTIALLY_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[VJP_REF]]<T>(%0)
// CHECK: = autodiff_function [wrt 0] [order 1] [[ORIG_REF_PARTIALLY_APPLIED]] : {{.*}} with {[[JVP_REF_PARTIALLY_APPLIED]] : {{.*}}, [[VJP_REF_PARTIALLY_APPLIED]] : {{.*}}}
// CHECK: } // end sil function 'differentiatePartiallyAppliedWitnessMethod'