Skip to content

Commit adc92e9

Browse files
committed
IRGen: Fix preservation of error result in async dispatch thunks
1 parent 662b553 commit adc92e9

File tree

5 files changed

+56
-5
lines changed

5 files changed

+56
-5
lines changed

lib/IRGen/GenThunk.cpp

+16-3
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,10 @@ void IRGenThunk::prepareArguments() {
165165
}
166166

167167
for (unsigned i = 0, e = asyncLayout->getArgumentCount(); i < e; ++i) {
168-
Address addr = asyncLayout->getArgumentLayout(i).project(
169-
IGF, context, llvm::None);
170-
params.add(IGF.Builder.CreateLoad(addr));
168+
auto layout = asyncLayout->getArgumentLayout(i);
169+
Address addr = layout.project(IGF, context, llvm::None);
170+
auto &ti = cast<LoadableTypeInfo>(layout.getType());
171+
ti.loadAsTake(IGF, addr, params);
171172
}
172173

173174
if (asyncLayout->hasBindings()) {
@@ -329,8 +330,20 @@ void IRGenThunk::emit() {
329330
emission->emitToExplosion(result, /*isOutlined=*/false);
330331
}
331332

333+
llvm::Value *errorValue = nullptr;
334+
335+
if (isAsync && origTy->hasErrorResult()) {
336+
SILType errorType = conv.getSILErrorType(expansionContext);
337+
Address calleeErrorSlot = emission->getCalleeErrorSlot(errorType);
338+
errorValue = IGF.Builder.CreateLoad(calleeErrorSlot);
339+
}
340+
332341
emission->end();
333342

343+
if (isAsync && errorValue) {
344+
IGF.Builder.CreateStore(errorValue, IGF.getCallerErrorResultSlot());
345+
}
346+
334347
if (isAsync) {
335348
emitAsyncReturn(IGF, *asyncLayout, origTy);
336349
IGF.emitCoroutineOrAsyncExit();

test/Concurrency/Runtime/Inputs/resilient_class.swift

+9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
public enum MyError : Error {
2+
case bad
3+
}
4+
15
open class BaseClass<T> {
26
let value: T
37

@@ -9,4 +13,9 @@ open class BaseClass<T> {
913
open func wait() async -> T {
1014
return value
1115
}
16+
open func wait(orThrow: Bool) async throws {
17+
if orThrow {
18+
throw MyError.bad
19+
}
20+
}
1221
}

test/Concurrency/Runtime/Inputs/resilient_protocol.swift

+1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ public protocol Awaitable {
22
associatedtype Result
33
func waitForNothing() async
44
func wait() async -> Result
5+
func wait(orThrow: Bool) async throws
56
}

test/Concurrency/Runtime/class_resilience.swift

+13-2
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,22 @@ class MyDerived : BaseClass<Int> {
2424
override func wait() async -> Int {
2525
return await super.wait() * 2
2626
}
27+
28+
override func wait(orThrow: Bool) async throws {
29+
return try await super.wait(orThrow: orThrow)
30+
}
2731
}
2832

2933
func virtualWaitForNothing<T>(_ c: BaseClass<T>) async {
3034
await c.waitForNothing()
3135
}
3236

33-
func virtualWait<T>(_ t: BaseClass<T>) async -> T {
34-
return await t.wait()
37+
func virtualWait<T>(_ c: BaseClass<T>) async -> T {
38+
return await c.wait()
39+
}
40+
41+
func virtualWait<T>(orThrow: Bool, _ c: BaseClass<T>) async throws {
42+
return try await c.wait(orThrow: orThrow)
3543
}
3644

3745
var AsyncVTableMethodSuite = TestSuite("ResilientClass")
@@ -43,6 +51,9 @@ AsyncVTableMethodSuite.test("AsyncVTableMethod") {
4351
await virtualWaitForNothing(x)
4452

4553
expectEqual(642, await virtualWait(x))
54+
55+
expectNil(try? await virtualWait(orThrow: true, x))
56+
try! await virtualWait(orThrow: false, x)
4657
}
4758
}
4859

test/Concurrency/Runtime/protocol_resilience.swift

+17
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,22 @@
1616
import StdlibUnittest
1717
import resilient_protocol
1818

19+
enum MyError : Error {
20+
case bad
21+
}
22+
1923
struct IntAwaitable : Awaitable {
2024
func waitForNothing() async {}
2125

2226
func wait() async -> Int {
2327
return 123
2428
}
29+
30+
func wait(orThrow: Bool) async throws {
31+
if (orThrow) {
32+
throw MyError.bad
33+
}
34+
}
2535
}
2636

2737
func genericWaitForNothing<T : Awaitable>(_ t: T) async {
@@ -32,6 +42,10 @@ func genericWait<T : Awaitable>(_ t: T) async -> T.Result {
3242
return await t.wait()
3343
}
3444

45+
func genericWait<T : Awaitable>(orThrow: Bool, _ t: T) async throws {
46+
return try await t.wait(orThrow: orThrow)
47+
}
48+
3549
var AsyncProtocolRequirementSuite = TestSuite("ResilientProtocol")
3650

3751
AsyncProtocolRequirementSuite.test("AsyncProtocolRequirement") {
@@ -41,6 +55,9 @@ AsyncProtocolRequirementSuite.test("AsyncProtocolRequirement") {
4155
await genericWaitForNothing(x)
4256

4357
expectEqual(123, await genericWait(x))
58+
59+
expectNil(try? await genericWait(orThrow: true, x))
60+
try! await genericWait(orThrow: false, x)
4461
}
4562
}
4663

0 commit comments

Comments
 (0)