Skip to content

Commit 345c221

Browse files
authored
[Concurrency] Distributed actor's unownedExecutor should be optional (#64499)
1 parent 230dfcc commit 345c221

25 files changed

+575
-72
lines changed

include/swift/AST/Decl.h

+1
Original file line numberDiff line numberDiff line change
@@ -4533,6 +4533,7 @@ class ClassDecl final : public NominalTypeDecl {
45334533

45344534
/// Fetch this class's unownedExecutor property, if it has one.
45354535
const VarDecl *getUnownedExecutorProperty() const;
4536+
const VarDecl *getLocalUnownedExecutorProperty() const;
45364537

45374538
/// Is this the NSObject class type?
45384539
bool isNSObject() const;

include/swift/AST/KnownIdentifiers.def

+2
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ IDENTIFIER(className)
231231
IDENTIFIER(_defaultActorInitialize)
232232
IDENTIFIER(_defaultActorDestroy)
233233
IDENTIFIER(unownedExecutor)
234+
IDENTIFIER(localUnownedExecutor)
235+
IDENTIFIER(_unwrapLocalUnownedExecutor)
234236

235237
IDENTIFIER_(ErrorType)
236238
IDENTIFIER(Code)

include/swift/AST/KnownSDKDecls.def

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#endif
2121

2222
KNOWN_SDK_FUNC_DECL(Distributed, IsRemoteDistributedActor, "__isRemoteActor")
23+
KNOWN_SDK_FUNC_DECL(Distributed, IsLocalDistributedActor, "__isLocalActor")
2324

2425
#undef KNOWN_SDK_FUNC_DECL
2526

include/swift/SIL/SILFunction.h

+1
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,7 @@ class SILFunction
12731273
const SILBasicBlock *getEntryBlock() const { return &front(); }
12741274

12751275
SILBasicBlock *createBasicBlock();
1276+
SILBasicBlock *createBasicBlock(llvm::StringRef debugName);
12761277
SILBasicBlock *createBasicBlockAfter(SILBasicBlock *afterBB);
12771278
SILBasicBlock *createBasicBlockBefore(SILBasicBlock *beforeBB);
12781279

lib/AST/Decl.cpp

+23
Original file line numberDiff line numberDiff line change
@@ -9625,6 +9625,29 @@ const VarDecl *ClassDecl::getUnownedExecutorProperty() const {
96259625
return nullptr;
96269626
}
96279627

9628+
const VarDecl *ClassDecl::getLocalUnownedExecutorProperty() const {
9629+
auto &C = getASTContext();
9630+
9631+
if (!isDistributedActor())
9632+
return nullptr;
9633+
9634+
llvm::SmallVector<ValueDecl *, 2> results;
9635+
this->lookupQualified(getSelfNominalTypeDecl(),
9636+
DeclNameRef(C.Id_localUnownedExecutor),
9637+
NL_ProtocolMembers,
9638+
results);
9639+
9640+
for (auto candidate: results) {
9641+
if (isa<ProtocolDecl>(candidate->getDeclContext()))
9642+
continue;
9643+
9644+
if (VarDecl *var = dyn_cast<VarDecl>(candidate))
9645+
return var;
9646+
}
9647+
9648+
return nullptr;
9649+
}
9650+
96289651
bool ClassDecl::isRootDefaultActor() const {
96299652
return isRootDefaultActor(getModuleContext(), ResilienceExpansion::Maximal);
96309653
}

lib/SIL/IR/SILFunction.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,13 @@ SILBasicBlock *SILFunction::createBasicBlock() {
533533
return newBlock;
534534
}
535535

536+
SILBasicBlock *SILFunction::createBasicBlock(llvm::StringRef debugName) {
537+
SILBasicBlock *newBlock = new (getModule()) SILBasicBlock(this);
538+
newBlock->setDebugName(debugName);
539+
BlockList.push_back(newBlock);
540+
return newBlock;
541+
}
542+
536543
SILBasicBlock *SILFunction::createBasicBlockAfter(SILBasicBlock *afterBB) {
537544
SILBasicBlock *newBlock = new (getModule()) SILBasicBlock(this);
538545
BlockList.insertAfter(afterBB->getIterator(), newBlock);

lib/SILOptimizer/Mandatory/LowerHopToActor.cpp

+51-4
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,28 @@ static AccessorDecl *getUnownedExecutorGetter(ASTContext &ctx,
154154
return nullptr;
155155
}
156156

157+
static AccessorDecl *getUnwrapLocalUnownedExecutorGetter(ASTContext &ctx,
158+
ProtocolDecl *actorProtocol) {
159+
for (auto member: actorProtocol->getAllMembers()) { // FIXME: remove this, just go to the extension
160+
if (auto var = dyn_cast<VarDecl>(member)) {
161+
if (var->getName() == ctx.Id__unwrapLocalUnownedExecutor)
162+
return var->getAccessor(AccessorKind::Get);
163+
}
164+
}
165+
166+
for (auto extension: actorProtocol->getExtensions()) {
167+
for (auto member: extension->getAllMembers()) {
168+
if (auto var = dyn_cast<VarDecl>(member)) {
169+
if (var->getName() == ctx.Id__unwrapLocalUnownedExecutor) {
170+
return var->getAccessor(AccessorKind::Get);
171+
}
172+
}
173+
}
174+
}
175+
176+
return nullptr;
177+
}
178+
157179
SILValue LowerHopToActor::emitGetExecutor(SILBuilderWithScope &B,
158180
SILLocation loc, SILValue actor,
159181
bool makeOptional) {
@@ -186,11 +208,36 @@ SILValue LowerHopToActor::emitGetExecutor(SILBuilderWithScope &B,
186208
unmarkedExecutor =
187209
B.createBuiltin(loc, builtinName, resultType, subs, {actor});
188210

189-
// Otherwise, go through Actor.unownedExecutor.
211+
// Otherwise, go through Actor.unownedExecutor.
212+
} else if (actorType->isDistributedActor()) {
213+
auto actorKind = KnownProtocolKind::DistributedActor;
214+
auto actorProtocol = ctx.getProtocol(actorKind);
215+
auto req = getUnwrapLocalUnownedExecutorGetter(ctx, actorProtocol);
216+
assert(req && "Distributed library broken");
217+
SILDeclRef fn(req, SILDeclRef::Kind::Func);
218+
219+
auto actorConf = module->lookupConformance(actorType, actorProtocol);
220+
assert(actorConf &&
221+
"hop_to_executor with distributed actor that doesn't conform to DistributedActor");
222+
223+
auto subs = SubstitutionMap::get(req->getGenericSignature(),
224+
{actorType}, {actorConf});
225+
auto fnType = F->getModule().Types.getConstantFunctionType(*F, fn);
226+
227+
auto witness =
228+
B.createWitnessMethod(loc, actorType, actorConf, fn,
229+
SILType::getPrimitiveObjectType(fnType));
230+
auto witnessCall = B.createApply(loc, witness, subs, {actor});
231+
232+
// The protocol requirement returns an Optional<UnownedSerialExecutor>;
233+
// extract the Builtin.Executor from it.
234+
auto executorDecl = ctx.getUnownedSerialExecutorDecl();
235+
auto executorProps = executorDecl->getStoredProperties();
236+
assert(executorProps.size() == 1);
237+
unmarkedExecutor =
238+
B.createStructExtract(loc, witnessCall, executorProps[0]);
190239
} else {
191-
auto actorKind = actorType->isDistributedActor() ?
192-
KnownProtocolKind::DistributedActor :
193-
KnownProtocolKind::Actor;
240+
auto actorKind = KnownProtocolKind::Actor;
194241
auto actorProtocol = ctx.getProtocol(actorKind);
195242
auto req = getUnownedExecutorGetter(ctx, actorProtocol);
196243
assert(req && "Concurrency library broken");

lib/Sema/DerivedConformanceDistributedActor.cpp

+62-27
Original file line numberDiff line numberDiff line change
@@ -621,10 +621,13 @@ static Expr *constructDistributedUnownedSerialExecutor(ASTContext &ctx,
621621
}
622622

623623
static std::pair<BraceStmt *, bool>
624-
deriveBodyDistributedActor_unownedExecutor(AbstractFunctionDecl *getter, void *) {
625-
// var unownedExecutor: UnownedSerialExecutor {
624+
deriveBodyDistributedActor_localUnownedExecutor(AbstractFunctionDecl *getter, void *) {
625+
// var localUnownedExecutor: UnownedSerialExecutor? {
626626
// get {
627-
// return Builtin.buildDefaultActorExecutorRef(self)
627+
// guard __isLocalActor(self) else {
628+
// return nil
629+
// }
630+
// return Optional(Builtin.buildDefaultActorExecutorRef(self))
628631
// }
629632
// }
630633
ASTContext &ctx = getter->getASTContext();
@@ -641,33 +644,53 @@ deriveBodyDistributedActor_unownedExecutor(AbstractFunctionDecl *getter, void *)
641644
Expr *selfArg = DerivedConformance::createSelfDeclRef(getter);
642645
selfArg->setType(selfType);
643646

647+
// Prepare the builtin call, we'll use it after the guard, but want to take the type
648+
// of its return type earlier, so we prepare it here.
649+
644650
// The builtin call gives us a Builtin.Executor.
645651
auto builtinCall =
646652
DerivedConformance::createBuiltinCall(ctx,
647653
BuiltinValueKind::BuildDefaultActorExecutorRef,
648654
{selfType}, {}, {selfArg});
649-
650655
// Turn that into an UnownedSerialExecutor.
651656
auto initCall = constructDistributedUnownedSerialExecutor(ctx, builtinCall);
652657
if (!initCall) return failure();
653658

654-
auto ret = new (ctx) ReturnStmt(SourceLoc(), initCall, /*implicit*/ true);
659+
// guard __isLocalActor(self) else {
660+
// return nil
661+
// }
662+
auto isLocalActorDecl = ctx.getIsLocalDistributedActor();
663+
DeclRefExpr *isLocalActorExpr =
664+
new (ctx) DeclRefExpr(ConcreteDeclRef(isLocalActorDecl), DeclNameLoc(), /*implicit=*/true,
665+
AccessSemantics::Ordinary,
666+
FunctionType::get({AnyFunctionType::Param(ctx.getAnyObjectType())},
667+
ctx.getBoolType()));
668+
Expr *selfForIsLocalArg = DerivedConformance::createSelfDeclRef(getter);
669+
selfForIsLocalArg->setType(selfType);
670+
auto *argListForIsLocal =
671+
ArgumentList::forImplicitSingle(ctx, Identifier(),
672+
ErasureExpr::create(ctx, selfForIsLocalArg, ctx.getAnyObjectType(), {}, {}));
673+
CallExpr *isLocalActorCall = CallExpr::createImplicit(ctx, isLocalActorExpr, argListForIsLocal);
674+
isLocalActorCall->setType(ctx.getBoolType());
675+
isLocalActorCall->setThrows(false);
676+
auto returnNilIfRemoteStmt = DerivedConformance::returnNilIfFalseGuardTypeChecked(
677+
ctx, isLocalActorCall, /*optionalWrappedType=*/initCall->getType());
678+
679+
680+
// Finalize preparing the unowned executor for returning.
681+
auto wrappedCall = new (ctx) InjectIntoOptionalExpr(initCall, initCall->getType()->wrapInOptionalType());
682+
683+
auto ret = new (ctx) ReturnStmt(SourceLoc(), wrappedCall, /*implicit*/ true);
655684

656685
auto body = BraceStmt::create(
657-
ctx, SourceLoc(), { ret }, SourceLoc(), /*implicit=*/true);
686+
ctx, SourceLoc(), { returnNilIfRemoteStmt, ret }, SourceLoc(), /*implicit=*/true);
658687
return { body, /*isTypeChecked=*/true };
659688
}
660689

661-
/// Derive the declaration of DistributedActor's unownedExecutor property.
662-
static ValueDecl *deriveDistributedActor_unownedExecutor(DerivedConformance &derived) {
690+
/// Derive the declaration of DistributedActor's localUnownedExecutor property.
691+
static ValueDecl *deriveDistributedActor_localUnownedExecutor(DerivedConformance &derived) {
663692
ASTContext &ctx = derived.Context;
664693

665-
if (auto classDecl = dyn_cast<ClassDecl>(derived.Nominal)) {
666-
if (auto existing = classDecl->getUnownedExecutorProperty()) {
667-
return const_cast<VarDecl*>(existing);
668-
}
669-
}
670-
671694
// Retrieve the types and declarations we'll need to form this operation.
672695
auto executorDecl = ctx.getUnownedSerialExecutorDecl();
673696
if (!executorDecl) {
@@ -676,16 +699,28 @@ static ValueDecl *deriveDistributedActor_unownedExecutor(DerivedConformance &der
676699
return nullptr;
677700
}
678701
Type executorType = executorDecl->getDeclaredInterfaceType();
702+
Type optionalExecutorType = executorType->wrapInOptionalType();
703+
704+
if (auto classDecl = dyn_cast<ClassDecl>(derived.Nominal)) {
705+
if (auto existing = classDecl->getLocalUnownedExecutorProperty()) {
706+
if (existing->getInterfaceType()->isEqual(optionalExecutorType)) {
707+
return const_cast<VarDecl *>(existing);
708+
} else {
709+
// bad type, should be diagnosed elsewhere
710+
return nullptr;
711+
}
712+
}
713+
}
679714

680715
auto propertyPair = derived.declareDerivedProperty(
681-
DerivedConformance::SynthesizedIntroducer::Var, ctx.Id_unownedExecutor,
682-
executorType, executorType,
716+
DerivedConformance::SynthesizedIntroducer::Var, ctx.Id_localUnownedExecutor,
717+
optionalExecutorType, optionalExecutorType,
683718
/*static*/ false, /*final*/ false);
684719
auto property = propertyPair.first;
685720
property->setSynthesized(true);
686721
property->getAttrs().add(new (ctx) SemanticsAttr(SEMANTICS_DEFAULT_ACTOR,
687722
SourceLoc(), SourceRange(),
688-
/*implicit*/ true));
723+
/*implicit*/ true));
689724
property->getAttrs().add(new (ctx) NonisolatedAttr(/*IsImplicit=*/true));
690725

691726
// Make the property implicitly final.
@@ -703,8 +738,8 @@ static ValueDecl *deriveDistributedActor_unownedExecutor(DerivedConformance &der
703738
property, asAvailableAs, ctx);
704739

705740
auto getter =
706-
derived.addGetterToReadOnlyDerivedProperty(property, executorType);
707-
getter->setBodySynthesizer(deriveBodyDistributedActor_unownedExecutor);
741+
derived.addGetterToReadOnlyDerivedProperty(property, optionalExecutorType);
742+
getter->setBodySynthesizer(deriveBodyDistributedActor_localUnownedExecutor);
708743

709744
// IMPORTANT: MUST BE AFTER [id, actorSystem].
710745
if (auto id = derived.Nominal->getDistributedActorIDProperty()) {
@@ -747,24 +782,24 @@ static void assertRequiredSynthesizedPropertyOrder(DerivedConformance &derived,
747782
if (auto id = Nominal->getDistributedActorIDProperty()) {
748783
if (auto system = Nominal->getDistributedActorSystemProperty()) {
749784
if (auto classDecl = dyn_cast<ClassDecl>(derived.Nominal)) {
750-
if (auto unownedExecutor = classDecl->getUnownedExecutorProperty()) {
751-
int idIdx, actorSystemIdx, unownedExecutorIdx = 0;
785+
if (auto localUnownedExecutor = classDecl->getLocalUnownedExecutorProperty()) {
786+
int idIdx, actorSystemIdx, localUnownedExecutorIdx = 0;
752787
int idx = 0;
753788
for (auto member: Nominal->getMembers()) {
754789
if (auto binding = dyn_cast<PatternBindingDecl>(member)) {
755790
if (binding->getSingleVar()->getName() == Context.Id_id) {
756791
idIdx = idx;
757792
} else if (binding->getSingleVar()->getName() == Context.Id_actorSystem) {
758793
actorSystemIdx = idx;
759-
} else if (binding->getSingleVar()->getName() == Context.Id_unownedExecutor) {
760-
unownedExecutorIdx = idx;
794+
} else if (binding->getSingleVar()->getName() == Context.Id_localUnownedExecutor) {
795+
localUnownedExecutorIdx = idx;
761796
}
762797
idx += 1;
763798
}
764799
}
765-
if (idIdx + actorSystemIdx + unownedExecutorIdx >= 0 + 1 + 2) {
800+
if (idIdx + actorSystemIdx + localUnownedExecutorIdx >= 0 + 1 + 2) {
766801
// we have found all the necessary fields, let's assert their order
767-
assert(idIdx < actorSystemIdx < unownedExecutorIdx && "order of fields MUST be exact.");
802+
assert(idIdx < actorSystemIdx < localUnownedExecutorIdx && "order of fields MUST be exact.");
768803
}
769804
}
770805
}
@@ -786,8 +821,8 @@ ValueDecl *DerivedConformance::deriveDistributedActor(ValueDecl *requirement) {
786821
derivedValue = deriveDistributedActor_id(*this);
787822
} else if (var->getName() == Context.Id_actorSystem) {
788823
derivedValue = deriveDistributedActor_actorSystem(*this);
789-
} else if (var->getName() == Context.Id_unownedExecutor) {
790-
derivedValue = deriveDistributedActor_unownedExecutor(*this);
824+
} else if (var->getName() == Context.Id_localUnownedExecutor) {
825+
derivedValue = deriveDistributedActor_localUnownedExecutor(*this);
791826
}
792827

793828
assertRequiredSynthesizedPropertyOrder(*this, derivedValue);

lib/Sema/DerivedConformances.cpp

+36-5
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,7 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
335335

336336
// Actor.unownedExecutor
337337
if (name.isSimpleName(ctx.Id_unownedExecutor)) {
338-
if (nominal->isDistributedActor()) {
339-
return getRequirement(KnownProtocolKind::DistributedActor);
340-
} else {
341-
return getRequirement(KnownProtocolKind::Actor);
342-
}
338+
return getRequirement(KnownProtocolKind::Actor);
343339
}
344340

345341
// DistributedActor.id
@@ -350,6 +346,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
350346
if (name.isSimpleName(ctx.Id_actorSystem))
351347
return getRequirement(KnownProtocolKind::DistributedActor);
352348

349+
// DistributedActor.localUnownedExecutor
350+
if (name.isSimpleName(ctx.Id_localUnownedExecutor)) {
351+
return getRequirement(KnownProtocolKind::DistributedActor);
352+
}
353+
353354
return nullptr;
354355
}
355356

@@ -674,6 +675,36 @@ GuardStmt *DerivedConformance::returnFalseIfNotEqualGuard(ASTContext &C,
674675
auto falseExpr = new (C) BooleanLiteralExpr(false, SourceLoc(), true);
675676
return returnIfNotEqualGuard(C, lhsExpr, rhsExpr, falseExpr);
676677
}
678+
/// Returns a generated guard statement that checks whether the given expr is true.
679+
/// If it is false, the else block for the guard returns `nil`.
680+
/// \p C The AST context.
681+
/// \p testExpr The expression that should be tested.
682+
/// \p baseType The wrapped type of the to-be-returned Optional<Wrapped>.
683+
GuardStmt *DerivedConformance::returnNilIfFalseGuardTypeChecked(ASTContext &C,
684+
Expr *testExpr,
685+
Type optionalWrappedType) {
686+
auto nilExpr = new (C) NilLiteralExpr(SourceLoc(), /*implicit=*/true);
687+
nilExpr->setType(optionalWrappedType->wrapInOptionalType());
688+
689+
SmallVector<StmtConditionElement, 1> conditions;
690+
SmallVector<ASTNode, 1> statements;
691+
692+
auto returnStmt = new (C) ReturnStmt(SourceLoc(), nilExpr);
693+
statements.push_back(returnStmt);
694+
695+
// Next, generate the condition being checked.
696+
// auto cmpFuncExpr = new (C) UnresolvedDeclRefExpr(
697+
// DeclNameRef(C.Id_EqualsOperator), DeclRefKind::BinaryOperator,
698+
// DeclNameLoc());
699+
// auto *cmpExpr = BinaryExpr::create(C, lhsExpr, cmpFuncExpr, rhsExpr,
700+
// /*implicit*/ true);
701+
conditions.emplace_back(testExpr);
702+
703+
// Build and return the complete guard statement.
704+
// guard lhs == rhs else { return lhs < rhs }
705+
auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc());
706+
return new (C) GuardStmt(SourceLoc(), C.AllocateCopy(conditions), body);
707+
}
677708
/// Returns a generated guard statement that checks whether the given lhs and
678709
/// rhs expressions are equal. If not equal, the else block for the guard
679710
/// returns lhs < rhs.

lib/Sema/DerivedConformances.h

+6
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,12 @@ class DerivedConformance {
409409
// return false
410410
static GuardStmt *returnFalseIfNotEqualGuard(ASTContext &C, Expr *lhsExpr,
411411
Expr *rhsExpr);
412+
413+
// Return `nil` is the `testExp` is `false`.
414+
static GuardStmt *returnNilIfFalseGuardTypeChecked(ASTContext &C,
415+
Expr *testExpr,
416+
Type optionalWrappedType);
417+
412418
// return lhs < rhs
413419
static GuardStmt *
414420
returnComparisonIfNotEqualGuard(ASTContext &C, Expr *lhsExpr, Expr *rhsExpr);

0 commit comments

Comments
 (0)