Skip to content

Commit 5abb580

Browse files
authoredJul 25, 2023
[AutoDiff] Fix return type of subset parameters thunk function (#67487)
The patch resolves #67402. When the original function has a tuple result type, we should append thunkedLinearMap as the last element of the tuple to match the function declaration. Before this patch, the compiler used to wrap the original result tuple and thunkedLinearMap into another tuple, and caused the verifier error. Before the patch: return %{{.*}} : $((Float, Double), @callee_guaranteed (Float) -> X.TangentVector) After the patch: return %{{.*}} : $(Float, Double, @callee_guaranteed (Float) -> X.TangentVector)

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed
 

‎lib/SILOptimizer/Differentiation/Thunk.cpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -786,10 +786,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction(
786786
// Extract all direct results.
787787
SmallVector<SILValue, 8> directResults;
788788
extractAllElements(apply, builder, directResults);
789-
auto originalDirectResults = ArrayRef<SILValue>(directResults).drop_back(1);
790-
auto originalDirectResult =
791-
joinElements(originalDirectResults, builder, apply->getLoc());
792789
auto linearMap = directResults.back();
790+
directResults.pop_back();
793791

794792
auto linearMapType = linearMap->getType().castTo<SILFunctionType>();
795793
auto linearMapTargetType = targetType->getResults()
@@ -830,8 +828,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction(
830828
0);
831829
if (origFnType->getNumResults() > 0 &&
832830
origFnType->getResults().front().isFormalDirect()) {
833-
auto result =
834-
joinElements({originalDirectResult, thunkedLinearMap}, builder, loc);
831+
directResults.push_back(thunkedLinearMap);
832+
auto result = joinElements(directResults, builder, loc);
835833
builder.createReturn(loc, result);
836834
} else {
837835
builder.createReturn(loc, thunkedLinearMap);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
2+
3+
// Verify the result type of a subset parameters thunk matches the declaration:
4+
//
5+
// CHECK: // autodiff subset parameters thunk for forward-mode derivative from f(x:)
6+
// CHECK-NEXT: sil shared [transparent] [thunk] @$s17param_thunk_tuple{{.*}} : $@convention(thin) (X)
7+
// CHECK-SAME: -> (Float, Double, @owned @callee_guaranteed (X.TangentVector) -> Float)
8+
// CHECK: return
9+
// CHECK-SAME: %{{.*}} : $(Float, Double, @callee_guaranteed (X.TangentVector) -> Float)
10+
//
11+
// CHECK: // autodiff subset parameters thunk for reverse-mode derivative from f(x:)
12+
// CHECK-NEXT: sil shared [transparent] [thunk] @$s17param_thunk_tuple{{.*}} : $@convention(thin) (X)
13+
// CHECK-SAME: -> (Float, Double, @owned @callee_guaranteed (Float) -> X.TangentVector)
14+
// CHECK: return
15+
// CHECK-SAME: %{{.*}} : $(Float, Double, @callee_guaranteed (Float) -> X.TangentVector)
16+
17+
import _Differentiation
18+
19+
struct X: Differentiable {
20+
var a: Float
21+
var b: Double
22+
}
23+
24+
@differentiable(reverse)
25+
func f(x: X) -> (Float, Double) {
26+
(x.a, x.b)
27+
}
28+
29+
@differentiable(reverse)
30+
func g1(x: X) -> Float {
31+
f(x: x).0
32+
}
33+

0 commit comments

Comments
 (0)
Please sign in to comment.