-
Notifications
You must be signed in to change notification settings - Fork 10.5k
/
Copy pathautodiff_function_inst.sil
57 lines (49 loc) · 3.72 KB
/
autodiff_function_inst.sil
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
// RUN: %target-sil-opt %s | %FileCheck %s
// RUN: %empty-directory(%t)
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name autodiff_function
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name autodiff_function
// RUN: %target-sil-opt %t/tmp.2.sib -module-name autodiff_function | %FileCheck %s
sil_stage raw
import Swift
import Builtin
// The adjoint function emitted by the compiler. Its parameters are a vector,
// as in vector-Jacobian products, followed by pullback values. The function is
// partially applied to the pullback values (here the original argument and
// result; in general a pullback struct) to form a pullback, which takes a
// vector and returns the vector-Jacobian product evaluated at the original
// argument.
sil hidden @foo_adj : $@convention(thin) (Float, Float, Float) -> Float {
// %0: the incoming vector. It is the only parameter left unbound after the
//     partial_apply in @foo_vjp, which binds the trailing two parameters.
// %1, %2: the original argument and original result captured by @foo_vjp.
bb0(%0 : $Float, %1 : $Float, %2 : $Float):
// Placeholder body: returns the captured original result instead of computing
// a real vector-Jacobian product. NOTE(review): presumably acceptable because
// this file only exercises SIL parsing, printing and serialization.
return %2 : $Float
}
// The original function. Its `[differentiable source 0 wrt 0]` attribute marks
// it as differentiable with respect to parameter 0; the attribute itself does
// not name a pullback.
sil hidden [differentiable source 0 wrt 0] @foo : $@convention(thin) (Float) -> Float {
// Identity function: returns its single argument unchanged. The
// [differentiable ...] attribute declares differentiability with respect to
// parameter 0.
bb0(%0 : $Float):
return %0 : $Float
}
// The vector-Jacobian product function, which returns the original result and a pullback.
sil hidden @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
// Returns a tuple of (original result, pullback closure).
bb0(%0 : $Float):
// Compute the original result by calling @foo on the argument.
%1 = function_ref @foo : $@convention(thin) (Float) -> Float
%2 = apply %1(%0) : $@convention(thin) (Float) -> Float
// Build the pullback: partial_apply binds the argument %0 and the result %2 to
// the trailing two parameters of @foo_adj, leaving a (Float) -> Float closure
// that takes the incoming vector.
%3 = function_ref @foo_adj : $@convention(thin) (Float, Float, Float) -> Float
%4 = partial_apply [callee_guaranteed] %3(%0, %2) : $@convention(thin) (Float, Float, Float) -> Float
%5 = tuple (%2 : $Float, %4 : $@callee_guaranteed (Float) -> Float)
return %5 : $(Float, @callee_guaranteed (Float) -> Float)
}
sil @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float {
// Exercises autodiff_function (with and without an associated-function list)
// and autodiff_function_extract ([vjp] and [original]). The extracted values
// are intentionally unused.
bb0:
%orig = function_ref @foo : $@convention(thin) (Float) -> Float
// Form a @differentiable function from @foo without supplying any associated
// derivative functions.
%undiffedFunc = autodiff_function [wrt 0] [order 1] %orig : $@convention(thin) (Float) -> Float
%vjp = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// Form a @differentiable function with an explicit associated-function list.
// The first entry is undef (presumably the JVP slot) and the second is
// @foo_vjp.
%diffFunc = autodiff_function [wrt 0] [order 1] %orig : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), %vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
// Extract the order-1 VJP and the original function back out of %diffFunc.
%extractedVJP = autodiff_function_extract [vjp] [order 1] %diffFunc : $@differentiable @convention(thin) (Float) -> Float
%extractedOriginal = autodiff_function_extract [original] %diffFunc : $@differentiable @convention(thin) (Float) -> Float
// Only the function built without associated functions is returned.
return %undiffedFunc : $@differentiable @convention(thin) (Float) -> Float
}
// CHECK-LABEL: @make_diff_func : $@convention(thin) () -> @differentiable @convention(thin) (Float) -> Float
// CHECK: [[FOO:%.*]] = function_ref @foo : $@convention(thin) (Float) -> Float
// CHECK: [[UNDIFFED_FOO:%.*]] = autodiff_function [wrt 0] [order 1] [[FOO]] : $@convention(thin) (Float) -> Float
// CHECK: [[FOO_VJP:%.*]] = function_ref @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK: [[DIFFED_FOO:%.*]] = autodiff_function [wrt 0] [order 1] [[FOO]] : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[FOO_VJP]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
// CHECK: [[EXTRACTED_VJP:%.*]] = autodiff_function_extract [vjp] [order 1] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float
// CHECK: [[EXTRACTED_ORIG:%.*]] = autodiff_function_extract [original] [[DIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float
// CHECK: return [[UNDIFFED_FOO]] : $@differentiable @convention(thin) (Float) -> Float