Skip to content

Commit 6ebb0ff

Browse files
committed
Replace AsyncIteratorProtocol.nextElement() with isolated next(_:)
Use an optional isolated parameter to this new `next(_:)` overload to keep it on the same actor as the caller, and pass `#isolation` when desugaring the async for..in loop. This keeps async iteration loops on the same actor, allowing non-Sendable values to be used with many async sequences.
1 parent 3558237 commit 6ebb0ff

29 files changed

+140
-116
lines changed

Diff for: include/swift/AST/ASTContext.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -650,8 +650,8 @@ class ASTContext final {
650650
/// Get AsyncIteratorProtocol.next().
651651
FuncDecl *getAsyncIteratorNext() const;
652652

653-
/// Get AsyncIteratorProtocol.nextElement().
654-
FuncDecl *getAsyncIteratorNextElement() const;
653+
/// Get AsyncIteratorProtocol.next(actor).
654+
FuncDecl *getAsyncIteratorNextIsolated() const;
655655

656656
/// Check whether the standard library provides all the correct
657657
/// intrinsic support for Optional<T>.

Diff for: include/swift/AST/KnownIdentifiers.def

-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ IDENTIFIER(load)
123123
IDENTIFIER(main)
124124
IDENTIFIER_WITH_NAME(MainEntryPoint, "$main")
125125
IDENTIFIER(next)
126-
IDENTIFIER(nextElement)
127126
IDENTIFIER_(nsErrorDomain)
128127
IDENTIFIER(objectAtIndexedSubscript)
129128
IDENTIFIER(objectForKeyedSubscript)

Diff for: lib/AST/ASTContext.cpp

+34-23
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,9 @@ struct ASTContext::Implementation {
287287
/// The declaration of 'AsyncIteratorProtocol.next()'.
288288
FuncDecl *AsyncIteratorNext = nullptr;
289289

290-
/// The declaration of 'AsyncIteratorProtocol.nextElement()'.
291-
FuncDecl *AsyncIteratorNextElement = nullptr;
290+
/// The declaration of 'AsyncIteratorProtocol.next(_:)' that takes
291+
/// an actor isolation.
292+
FuncDecl *AsyncIteratorNextIsolated = nullptr;
292293

293294
/// The declaration of Swift.Optional<T>.Some.
294295
EnumElementDecl *OptionalSomeDecl = nullptr;
@@ -951,38 +952,48 @@ FuncDecl *ASTContext::getIteratorNext() const {
951952
return nullptr;
952953
}
953954

954-
FuncDecl *ASTContext::getAsyncIteratorNext() const {
955-
if (getImpl().AsyncIteratorNext) {
956-
return getImpl().AsyncIteratorNext;
957-
}
958-
959-
auto proto = getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
955+
static std::pair<FuncDecl *, FuncDecl *>
956+
getAsyncIteratorNextRequirements(const ASTContext &ctx) {
957+
auto proto = ctx.getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
960958
if (!proto)
961-
return nullptr;
959+
return { nullptr, nullptr };
962960

963-
if (auto *func = lookupRequirement(proto, Id_next)) {
964-
getImpl().AsyncIteratorNext = func;
965-
return func;
961+
FuncDecl *next = nullptr;
962+
FuncDecl *nextThrowing = nullptr;
963+
for (auto result : proto->lookupDirect(ctx.Id_next)) {
964+
if (result->getDeclContext() != proto)
965+
continue;
966+
967+
if (auto func = dyn_cast<FuncDecl>(result)) {
968+
switch (func->getParameters()->size()) {
969+
case 0: next = func; break;
970+
case 1: nextThrowing = func; break;
971+
default: break;
972+
}
973+
}
966974
}
967975

968-
return nullptr;
976+
return { next, nextThrowing };
969977
}
970978

971-
FuncDecl *ASTContext::getAsyncIteratorNextElement() const {
972-
if (getImpl().AsyncIteratorNextElement) {
973-
return getImpl().AsyncIteratorNextElement;
979+
FuncDecl *ASTContext::getAsyncIteratorNext() const {
980+
if (getImpl().AsyncIteratorNext) {
981+
return getImpl().AsyncIteratorNext;
974982
}
975983

976-
auto proto = getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
977-
if (!proto)
978-
return nullptr;
984+
auto next = getAsyncIteratorNextRequirements(*this).first;
985+
getImpl().AsyncIteratorNext = next;
986+
return next;
987+
}
979988

980-
if (auto *func = lookupRequirement(proto, Id_nextElement)) {
981-
getImpl().AsyncIteratorNextElement = func;
982-
return func;
989+
FuncDecl *ASTContext::getAsyncIteratorNextIsolated() const {
990+
if (getImpl().AsyncIteratorNextIsolated) {
991+
return getImpl().AsyncIteratorNextIsolated;
983992
}
984993

985-
return nullptr;
994+
auto nextThrowing = getAsyncIteratorNextRequirements(*this).second;
995+
getImpl().AsyncIteratorNextIsolated = nextThrowing;
996+
return nextThrowing;
986997
}
987998

988999
#define KNOWN_STDLIB_TYPE_DECL(NAME, DECL_CLASS, NUM_GENERIC_PARAMS) \

Diff for: lib/Sema/CSGen.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -4603,7 +4603,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
46034603
nextId, /*labels=*/ArrayRef<Identifier>());
46044604
nextRef->setFunctionRefKind(FunctionRefKind::SingleApply);
46054605

4606-
Expr *nextCall = CallExpr::createImplicitEmpty(ctx, nextRef);
4606+
ArgumentList *nextArgs;
4607+
if (nextFn && nextFn->getParameters()->size() == 1) {
4608+
auto isolationArg =
4609+
new (ctx) CurrentContextIsolationExpr(stmt->getForLoc(), Type());
4610+
nextArgs = ArgumentList::forImplicitUnlabeled(ctx, { isolationArg });
4611+
} else {
4612+
nextArgs = ArgumentList::createImplicit(ctx, {});
4613+
}
4614+
Expr *nextCall = CallExpr::createImplicit(ctx, nextRef, nextArgs);
46074615

46084616
// `next` is always async but witness might not be throwing
46094617
if (isAsync) {

Diff for: lib/Sema/CSSimplify.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -10705,7 +10705,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
1070510705
// Handle `next` reference.
1070610706
if (getContextualTypePurpose(baseExpr) == CTP_ForEachSequence &&
1070710707
(isRefTo(memberRef, ctx.Id_next, /*labels=*/{}) ||
10708-
isRefTo(memberRef, ctx.Id_nextElement, /*labels=*/{}))) {
10708+
isRefTo(memberRef, ctx.Id_next, /*labels=*/{StringRef()}))) {
1070910709
auto *iteratorProto = cast<ProtocolDecl>(
1071010710
getContextualType(baseExpr, /*forConstraint=*/false)
1071110711
->getAnyNominal());
@@ -10931,8 +10931,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyMemberConstraint(
1093110931
if (auto *base = dyn_cast<DeclRefExpr>(UDE->getBase())) {
1093210932
if (auto var = dyn_cast_or_null<VarDecl>(base->getDecl())) {
1093310933
if (var->getNameStr().contains("$generator") &&
10934-
(UDE->getName().getBaseIdentifier() == Context.Id_next ||
10935-
UDE->getName().getBaseIdentifier() == Context.Id_nextElement))
10934+
(UDE->getName().getBaseIdentifier() == Context.Id_next))
1093610935
return success();
1093710936
}
1093810937
}

Diff for: lib/Sema/TypeCheckEffects.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1335,7 +1335,7 @@ class ApplyClassifier {
13351335
isRethrowLikeTypedThrows(fnRef.getFunction())) {
13361336
// If we are in a rethrowing context and the function we're referring
13371337
// to is a rethrow-like function using typed throws or we are
1338-
// calling the next() or nextElement() of an async iterator,
1338+
// calling the next() or next(_:) of an async iterator,
13391339
// then look at all of the closure arguments.
13401340
LLVM_FALLTHROUGH;
13411341
} else {

Diff for: lib/Sema/TypeCheckStmt.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -3237,17 +3237,17 @@ FuncDecl *TypeChecker::getForEachIteratorNextFunction(
32373237
if (!isAsync)
32383238
return ctx.getIteratorNext();
32393239

3240-
// If AsyncIteratorProtocol.nextElement() isn't available at all,
3240+
// If AsyncIteratorProtocol.next(_:) isn't available at all,
32413241
// we're stuck using AsyncIteratorProtocol.next().
3242-
auto nextElement = ctx.getAsyncIteratorNextElement();
3242+
auto nextElement = ctx.getAsyncIteratorNextIsolated();
32433243
if (!nextElement)
32443244
return ctx.getAsyncIteratorNext();
32453245

3246-
// If availability checking is disabled, use nextElement().
3246+
// If availability checking is disabled, use next(_:).
32473247
if (ctx.LangOpts.DisableAvailabilityChecking || loc.isInvalid())
32483248
return nextElement;
32493249

3250-
// We can only call nextElement() if we are in an availability context
3250+
// We can only call next(_:) if we are in an availability context
32513251
// that supports typed throws.
32523252
auto availability = overApproximateAvailabilityAtLocation(loc, dc);
32533253
if (availability.isContainedIn(ctx.getTypedThrowsAvailability()))

Diff for: stdlib/public/Concurrency/AsyncCompactMapSequence.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ extension AsyncCompactMapSequence: AsyncSequence {
140140
/// that transforms to a non-`nil` value.
141141
@available(SwiftStdlib 5.11, *)
142142
@inlinable
143-
public mutating func nextElement() async throws(Failure) -> ElementOfResult? {
143+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> ElementOfResult? {
144144
while true {
145-
guard let element = try await baseIterator.nextElement() else {
145+
guard let element = try await baseIterator.next(actor) else {
146146
return nil
147147
}
148148

Diff for: stdlib/public/Concurrency/AsyncDropFirstSequence.swift

+5-5
Original file line numberDiff line numberDiff line change
@@ -118,24 +118,24 @@ extension AsyncDropFirstSequence: AsyncSequence {
118118
/// Produces the next element in the drop-first sequence.
119119
///
120120
/// Until reaching the number of elements to drop, this iterator calls
121-
/// `nextElement()` on its base iterator and discards the result. If the
121+
/// `next(_:)` on its base iterator and discards the result. If the
122122
/// base iterator returns `nil`, indicating the end of the sequence, this
123123
/// iterator returns `nil`. After reaching the number of elements to drop,
124-
/// this iterator passes along the result of calling `nextElement()` on the
124+
/// this iterator passes along the result of calling `next(_:)` on the
125125
/// base iterator.
126126
@available(SwiftStdlib 5.11, *)
127127
@inlinable
128-
public mutating func nextElement() async throws(Failure) -> Base.Element? {
128+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> Base.Element? {
129129
var remainingToDrop = count
130130
while remainingToDrop > 0 {
131-
guard try await baseIterator.nextElement() != nil else {
131+
guard try await baseIterator.next(actor) != nil else {
132132
count = 0
133133
return nil
134134
}
135135
remainingToDrop -= 1
136136
}
137137
count = 0
138-
return try await baseIterator.nextElement()
138+
return try await baseIterator.next(actor)
139139
}
140140
}
141141

Diff for: stdlib/public/Concurrency/AsyncDropWhileSequence.swift

+4-4
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ extension AsyncDropWhileSequence: AsyncSequence {
129129

130130
/// Produces the next element in the drop-while sequence.
131131
///
132-
/// This iterator calls `nextElement()` on its base iterator and evaluates
132+
/// This iterator calls `next(_:)` on its base iterator and evaluates
133133
/// the result with the `predicate` closure. As long as the predicate
134134
/// returns `true`, this method returns `nil`. After the predicate returns
135135
/// `false`, for a value received from the base iterator, this method
@@ -138,17 +138,17 @@ extension AsyncDropWhileSequence: AsyncSequence {
138138
/// again.
139139
@available(SwiftStdlib 5.11, *)
140140
@inlinable
141-
public mutating func nextElement() async throws(Failure) -> Base.Element? {
141+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> Base.Element? {
142142
while let predicate = self.predicate {
143-
guard let element = try await baseIterator.nextElement() else {
143+
guard let element = try await baseIterator.next(actor) else {
144144
return nil
145145
}
146146
if await predicate(element) == false {
147147
self.predicate = nil
148148
return element
149149
}
150150
}
151-
return try await baseIterator.nextElement()
151+
return try await baseIterator.next(actor)
152152
}
153153
}
154154

Diff for: stdlib/public/Concurrency/AsyncFilterSequence.swift

+5-5
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,16 @@ extension AsyncFilterSequence: AsyncSequence {
116116

117117
/// Produces the next element in the filter sequence.
118118
///
119-
/// This iterator calls `nextelement()` on its base iterator; if this call
120-
/// returns `nil`, `nextElement()` returns nil. Otherwise, `nextElement()`
119+
/// This iterator calls `next()` on its base iterator; if this call
120+
/// returns `nil`, `next()` returns nil. Otherwise, `next()`
121121
/// evaluates the result with the `predicate` closure. If the closure
122-
/// returns `true`, `nextElement()` returns the received element; otherwise
122+
/// returns `true`, `next()` returns the received element; otherwise
123123
/// it awaits the next element from the base iterator.
124124
@available(SwiftStdlib 5.11, *)
125125
@inlinable
126-
public mutating func nextElement() async throws(Failure) -> Base.Element? {
126+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> Base.Element? {
127127
while true {
128-
guard let element = try await baseIterator.nextElement() else {
128+
guard let element = try await baseIterator.next(actor) else {
129129
return nil
130130
}
131131
if await isIncluded(element) {

Diff for: stdlib/public/Concurrency/AsyncFlatMapSequence.swift

+8-8
Original file line numberDiff line numberDiff line change
@@ -270,20 +270,20 @@ extension AsyncFlatMapSequence: AsyncSequence {
270270

271271
/// Produces the next element in the flat map sequence.
272272
///
273-
/// This iterator calls `nextElement()` on its base iterator; if this call
274-
/// returns `nil`, `nextElement()` returns `nil`. Otherwise, `nextElement()`
273+
/// This iterator calls `next()` on its base iterator; if this call
274+
/// returns `nil`, `next()` returns `nil`. Otherwise, `next()`
275275
/// calls the transforming closure on the received element, takes the
276276
/// resulting asynchronous sequence, and creates an asynchronous iterator
277-
/// from it. `nextElement()` then consumes values from this iterator until
278-
/// it terminates. At this point, `nextElement()` is ready to receive the
277+
/// from it. `next()` then consumes values from this iterator until
278+
/// it terminates. At this point, `next()` is ready to receive the
279279
/// next value from the base sequence.
280280
@available(SwiftStdlib 5.11, *)
281281
@inlinable
282-
public mutating func nextElement() async throws(Failure) -> SegmentOfResult.Element? {
282+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> SegmentOfResult.Element? {
283283
while !finished {
284284
if var iterator = currentIterator {
285285
do throws(any Error) {
286-
let optElement = try await iterator.nextElement()
286+
let optElement = try await iterator.next(actor)
287287
guard let element = optElement else {
288288
currentIterator = nil
289289
continue
@@ -296,15 +296,15 @@ extension AsyncFlatMapSequence: AsyncSequence {
296296
throw error as! Failure
297297
}
298298
} else {
299-
let optItem = try await baseIterator.nextElement()
299+
let optItem = try await baseIterator.next(actor)
300300
guard let item = optItem else {
301301
finished = true
302302
return nil
303303
}
304304
do throws(any Error) {
305305
let segment = await transform(item)
306306
var iterator = segment.makeAsyncIterator()
307-
let optElement = try await iterator.nextElement()
307+
let optElement = try await iterator.next(actor)
308308
guard let element = optElement else {
309309
currentIterator = nil
310310
continue

Diff for: stdlib/public/Concurrency/AsyncIteratorProtocol.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,16 @@ public protocol AsyncIteratorProtocol<Element, Failure> {
106106
/// - Returns: The next element, if it exists, or `nil` to signal the end of
107107
/// the sequence.
108108
@available(SwiftStdlib 5.11, *)
109-
mutating func nextElement() async throws(Failure) -> Element?
109+
mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> Element?
110110
}
111111

112112
@available(SwiftStdlib 5.1, *)
113113
extension AsyncIteratorProtocol {
114-
/// Default implementation of `nextElement()` in terms of `next()`, which is
114+
/// Default implementation of `next()` in terms of `next()`, which is
115115
/// required to maintain backward compatibility with existing async iterators.
116116
@available(SwiftStdlib 5.11, *)
117117
@inlinable
118-
public mutating func nextElement() async throws(Failure) -> Element? {
118+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> Element? {
119119
do {
120120
return try await next()
121121
} catch {

Diff for: stdlib/public/Concurrency/AsyncMapSequence.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ extension AsyncMapSequence: AsyncSequence {
125125
/// calling the transforming closure on the received element.
126126
@available(SwiftStdlib 5.11, *)
127127
@inlinable
128-
public mutating func nextElement() async throws(Failure) -> Transformed? {
129-
guard let element = try await baseIterator.nextElement() else {
128+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> Transformed? {
129+
guard let element = try await baseIterator.next(actor) else {
130130
return nil
131131
}
132132
return await transform(element)

Diff for: stdlib/public/Concurrency/AsyncPrefixSequence.swift

+4-4
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,15 @@ extension AsyncPrefixSequence: AsyncSequence {
112112
/// Produces the next element in the prefix sequence.
113113
///
114114
/// Until reaching the number of elements to include, this iterator calls
115-
/// `nextElement()` on its base iterator and passes through the
115+
/// `next()` on its base iterator and passes through the
116116
/// result. After reaching the maximum number of elements, subsequent calls
117-
/// to `nextElement()` return `nil`.
117+
/// to `next()` return `nil`.
118118
@available(SwiftStdlib 5.11, *)
119119
@inlinable
120-
public mutating func nextElement() async throws(Failure) -> Base.Element? {
120+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> Base.Element? {
121121
if remaining != 0 {
122122
remaining &-= 1
123-
return try await baseIterator.nextElement()
123+
return try await baseIterator.next(actor)
124124
} else {
125125
return nil
126126
}

Diff for: stdlib/public/Concurrency/AsyncPrefixWhileSequence.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ extension AsyncPrefixWhileSequence: AsyncSequence {
129129
/// `nil`, ending the sequence.
130130
@available(SwiftStdlib 5.11, *)
131131
@inlinable
132-
public mutating func nextElement() async throws(Failure) -> Base.Element? {
133-
if !predicateHasFailed, let nextElement = try await baseIterator.nextElement() {
132+
public mutating func next(_ actor: isolated (any Actor)?) async throws(Failure) -> Base.Element? {
133+
if !predicateHasFailed, let nextElement = try await baseIterator.next(actor) {
134134
if await predicate(nextElement) {
135135
return nextElement
136136
} else {

0 commit comments

Comments
 (0)