Skip to content

Commit f5ed53d

Browse files
committed
Revert "[stdlib] Distributed: Remove invokeOnReturn requirement and its synthesis"
This reverts commit 961aa30.
1 parent 11ef6e5 commit f5ed53d

File tree

3 files changed

+347
-6
lines changed

3 files changed

+347
-6
lines changed

lib/Sema/DerivedConformanceDistributedActor.cpp

+321
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,319 @@ static FuncDecl *deriveDistributedActor_resolve(DerivedConformance &derived) {
121121
return factoryDecl;
122122
}
123123

124+
/******************************************************************************/
125+
/*************** INVOKE HANDLER ON-RETURN FUNCTION ****************************/
126+
/******************************************************************************/
127+
128+
namespace {
129+
struct DoInvokeOnReturnContext {
130+
ParamDecl *handlerParam;
131+
ParamDecl *resultBufferParam;
132+
};
133+
} // namespace
134+
135+
static std::pair<BraceStmt *, bool>
136+
deriveBodyDistributed_doInvokeOnReturn(AbstractFunctionDecl *afd, void *arg) {
137+
auto &C = afd->getASTContext();
138+
auto *context = static_cast<DoInvokeOnReturnContext *>(arg);
139+
140+
// mock locations, we're a thunk and don't really need detailed locations
141+
const SourceLoc sloc = SourceLoc();
142+
const DeclNameLoc dloc = DeclNameLoc();
143+
bool implicit = true;
144+
145+
auto returnTypeParam = afd->getParameters()->get(0);
146+
SmallVector<ASTNode, 8> stmts;
147+
148+
VarDecl *resultVar =
149+
new (C) VarDecl(/*isStatic=*/false, VarDecl::Introducer::Let, sloc,
150+
C.getIdentifier("result"), afd);
151+
{
152+
auto resultLoadCall = CallExpr::createImplicit(
153+
C,
154+
UnresolvedDotExpr::createImplicit(
155+
C,
156+
/*base=*/
157+
new (C) DeclRefExpr(ConcreteDeclRef(context->resultBufferParam),
158+
dloc, implicit),
159+
/*baseName=*/DeclBaseName(C.getIdentifier("load")),
160+
/*argLabels=*/
161+
{C.getIdentifier("fromByteOffset"), C.getIdentifier("as")}),
162+
ArgumentList::createImplicit(
163+
C, {Argument(sloc, C.getIdentifier("as"),
164+
new (C) DeclRefExpr(ConcreteDeclRef(returnTypeParam),
165+
dloc, implicit))}));
166+
167+
auto resultPattern = NamedPattern::createImplicit(C, resultVar);
168+
auto resultPB = PatternBindingDecl::createImplicit(
169+
C, swift::StaticSpellingKind::None, resultPattern,
170+
/*expr=*/resultLoadCall, afd);
171+
172+
stmts.push_back(resultPB);
173+
stmts.push_back(resultVar);
174+
}
175+
176+
// call the ad-hoc `handler.onReturn`
177+
{
178+
// Find the ad-hoc requirement ensured function on the concrete handler:
179+
auto onReturnFunc = C.getOnReturnOnDistributedTargetInvocationResultHandler(
180+
context->handlerParam->getInterfaceType()->getAnyNominal());
181+
assert(onReturnFunc && "did not find ad-hoc requirement witness!");
182+
183+
Expr *callExpr = CallExpr::createImplicit(
184+
C,
185+
UnresolvedDotExpr::createImplicit(
186+
C,
187+
/*base=*/
188+
new (C) DeclRefExpr(ConcreteDeclRef(context->handlerParam), dloc,
189+
implicit),
190+
/*baseName=*/onReturnFunc->getBaseName(),
191+
/*paramList=*/onReturnFunc->getParameters()),
192+
ArgumentList::forImplicitCallTo(
193+
DeclNameRef(onReturnFunc->getName()),
194+
{new (C) DeclRefExpr(ConcreteDeclRef(resultVar), dloc, implicit)},
195+
C));
196+
callExpr = TryExpr::createImplicit(C, sloc, callExpr);
197+
callExpr = AwaitExpr::createImplicit(C, sloc, callExpr);
198+
199+
stmts.push_back(callExpr);
200+
}
201+
202+
auto body = BraceStmt::create(C, sloc, {stmts}, sloc, implicit);
203+
return {body, /*isTypeChecked=*/false};
204+
}
205+
206+
// Create local function:
207+
// func invokeOnReturn<R: Self.SerializationRequirement>(
208+
// _ returnType: R.Type
209+
// ) async throws {
210+
// let value = resultBuffer.load(as: returnType)
211+
// try await handler.onReturn(value: value)
212+
// }
213+
static FuncDecl* createLocalFunc_doInvokeOnReturn(
214+
ASTContext& C, FuncDecl* parentFunc,
215+
NominalTypeDecl* systemNominal,
216+
ParamDecl* handlerParam,
217+
ParamDecl* resultBufParam) {
218+
auto DC = parentFunc;
219+
auto DAS = C.getDistributedActorSystemDecl();
220+
auto doInvokeLocalFuncIdent = C.getIdentifier("doInvokeOnReturn");
221+
222+
// mock locations, we're a synthesized func and don't need real locations
223+
const SourceLoc sloc = SourceLoc();
224+
225+
// <R: Self.SerializationRequirement>
226+
// We create the generic param at invalid depth, which means it'll be filled
227+
// by semantic analysis.
228+
auto *resultGenericParamDecl = GenericTypeParamDecl::createImplicit(
229+
parentFunc, C.getIdentifier("R"), /*depth*/ 0, /*index*/ 0);
230+
GenericParamList *doInvokeGenericParamList =
231+
GenericParamList::create(C, sloc, {resultGenericParamDecl}, sloc);
232+
233+
auto returnTypeIdent = C.getIdentifier("returnType");
234+
auto resultTyParamDecl =
235+
ParamDecl::createImplicit(C,
236+
/*argument=*/returnTypeIdent,
237+
/*parameter=*/returnTypeIdent,
238+
resultGenericParamDecl->getInterfaceType(), DC);
239+
ParameterList *doInvokeParamsList =
240+
ParameterList::create(C, {resultTyParamDecl});
241+
242+
SmallVector<Requirement, 2> requirements;
243+
for (auto p : getDistributedSerializationRequirementProtocols(systemNominal, DAS)) {
244+
auto requirement =
245+
Requirement(RequirementKind::Conformance,
246+
resultGenericParamDecl->getDeclaredInterfaceType(),
247+
p->getDeclaredInterfaceType());
248+
requirements.push_back(requirement);
249+
}
250+
GenericSignature doInvokeGenSig =
251+
buildGenericSignature(C, parentFunc->getGenericSignature(),
252+
{resultGenericParamDecl->getDeclaredInterfaceType()
253+
->castTo<GenericTypeParamType>()},
254+
std::move(requirements),
255+
/*allowInverses=*/true);
256+
257+
FuncDecl *doInvokeOnReturnFunc = FuncDecl::createImplicit(
258+
C, swift::StaticSpellingKind::None,
259+
DeclName(C, doInvokeLocalFuncIdent, doInvokeParamsList),
260+
sloc,
261+
/*async=*/true,
262+
/*throws=*/true,
263+
/*ThrownType=*/Type(),
264+
doInvokeGenericParamList, doInvokeParamsList,
265+
/*returnType=*/C.TheEmptyTupleType, parentFunc);
266+
doInvokeOnReturnFunc->setImplicit();
267+
doInvokeOnReturnFunc->setSynthesized();
268+
doInvokeOnReturnFunc->setGenericSignature(doInvokeGenSig);
269+
270+
auto *doInvokeContext = C.Allocate<DoInvokeOnReturnContext>();
271+
doInvokeContext->handlerParam = handlerParam;
272+
doInvokeContext->resultBufferParam = resultBufParam;
273+
doInvokeOnReturnFunc->setBodySynthesizer(
274+
deriveBodyDistributed_doInvokeOnReturn, doInvokeContext);
275+
276+
return doInvokeOnReturnFunc;
277+
}
278+
279+
static std::pair<BraceStmt *, bool>
280+
deriveBodyDistributed_invokeHandlerOnReturn(AbstractFunctionDecl *afd,
281+
void *context) {
282+
auto implicit = true;
283+
ASTContext &C = afd->getASTContext();
284+
auto DC = afd->getDeclContext();
285+
auto DAS = C.getDistributedActorSystemDecl();
286+
287+
// mock locations, we're a thunk and don't really need detailed locations
288+
const SourceLoc sloc = SourceLoc();
289+
const DeclNameLoc dloc = DeclNameLoc();
290+
291+
NominalTypeDecl *nominal = dyn_cast<NominalTypeDecl>(DC);
292+
assert(nominal);
293+
294+
auto func = dyn_cast<FuncDecl>(afd);
295+
assert(func);
296+
297+
// === parameters
298+
auto params = func->getParameters();
299+
assert(params->size() == 3);
300+
auto handlerParam = params->get(0);
301+
auto resultBufParam = params->get(1);
302+
auto metatypeParam = params->get(2);
303+
304+
auto serializationRequirementTypeTy =
305+
getDistributedSerializationRequirementType(nominal, DAS);
306+
307+
auto serializationRequirementMetaTypeTy =
308+
ExistentialMetatypeType::get(serializationRequirementTypeTy);
309+
310+
// Statements
311+
SmallVector<ASTNode, 8> stmts;
312+
313+
// --- `let m = metatype as! SerializationRequirement.Type`
314+
VarDecl *metatypeVar =
315+
new (C) VarDecl(/*isStatic=*/false, VarDecl::Introducer::Let, sloc,
316+
C.getIdentifier("m"), func);
317+
{
318+
metatypeVar->setImplicit();
319+
metatypeVar->setSynthesized();
320+
321+
// metatype as! <<concrete SerializationRequirement.Type>>
322+
auto metatypeRef =
323+
new (C) DeclRefExpr(ConcreteDeclRef(metatypeParam), dloc, implicit);
324+
auto metatypeSRCastExpr = ForcedCheckedCastExpr::createImplicit(
325+
C, metatypeRef, serializationRequirementMetaTypeTy);
326+
327+
auto metatypePattern = NamedPattern::createImplicit(C, metatypeVar);
328+
auto metatypePB = PatternBindingDecl::createImplicit(
329+
C, swift::StaticSpellingKind::None, metatypePattern,
330+
/*expr=*/metatypeSRCastExpr, func);
331+
332+
stmts.push_back(metatypePB);
333+
stmts.push_back(metatypeVar);
334+
}
335+
336+
// --- Declare the local function `doInvokeOnReturn`...
337+
FuncDecl *doInvokeOnReturnFunc = createLocalFunc_doInvokeOnReturn(
338+
C, func,
339+
nominal, handlerParam, resultBufParam);
340+
stmts.push_back(doInvokeOnReturnFunc);
341+
342+
// --- try await _openExistential(metatypeVar, do: <<doInvokeLocalFunc>>)
343+
{
344+
auto openExistentialBaseIdent = C.getIdentifier("_openExistential");
345+
auto doIdent = C.getIdentifier("do");
346+
347+
auto openExArgs = ArgumentList::createImplicit(
348+
C, {
349+
Argument(sloc, Identifier(),
350+
new (C) DeclRefExpr(ConcreteDeclRef(metatypeVar), dloc,
351+
implicit)),
352+
Argument(sloc, doIdent,
353+
new (C) DeclRefExpr(ConcreteDeclRef(doInvokeOnReturnFunc),
354+
dloc, implicit)),
355+
});
356+
Expr *tryAwaitDoOpenExistential =
357+
CallExpr::createImplicit(C,
358+
UnresolvedDeclRefExpr::createImplicit(
359+
C, openExistentialBaseIdent),
360+
openExArgs);
361+
362+
tryAwaitDoOpenExistential =
363+
AwaitExpr::createImplicit(C, sloc, tryAwaitDoOpenExistential);
364+
tryAwaitDoOpenExistential =
365+
TryExpr::createImplicit(C, sloc, tryAwaitDoOpenExistential);
366+
367+
stmts.push_back(tryAwaitDoOpenExistential);
368+
}
369+
370+
auto body = BraceStmt::create(C, sloc, {stmts}, sloc, implicit);
371+
return {body, /*isTypeChecked=*/false};
372+
}
373+
374+
/// Synthesizes the
375+
///
376+
/// \verbatim
377+
/// static func invokeHandlerOnReturn(
378+
//// handler: ResultHandler,
379+
//// resultBuffer: UnsafeRawPointer,
380+
//// metatype _metatype: Any.Type
381+
//// ) async throws
382+
/// \endverbatim
383+
static FuncDecl *deriveDistributedActorSystem_invokeHandlerOnReturn(
384+
DerivedConformance &derived) {
385+
auto system = derived.Nominal;
386+
auto &C = system->getASTContext();
387+
388+
// auto serializationRequirementType = getDistributedActorSystemType(decl);
389+
auto resultHandlerType = getDistributedActorSystemResultHandlerType(system);
390+
auto unsafeRawPointerType = C.getUnsafeRawPointerType();
391+
auto anyTypeType = ExistentialMetatypeType::get(C.TheAnyType); // Any.Type
392+
393+
// auto serializationRequirementType =
394+
// getDistributedSerializationRequirementType(system, DAS);
395+
396+
// params:
397+
// - handler: Self.ResultHandler
398+
// - resultBuffer:
399+
// - metatype _metatype: Any.Type
400+
auto *params = ParameterList::create(
401+
C,
402+
/*LParenLoc=*/SourceLoc(),
403+
/*params=*/
404+
{
405+
ParamDecl::createImplicit(
406+
C, C.Id_handler, C.Id_handler,
407+
system->mapTypeIntoContext(resultHandlerType), system),
408+
ParamDecl::createImplicit(
409+
C, C.Id_resultBuffer, C.Id_resultBuffer,
410+
unsafeRawPointerType, system),
411+
ParamDecl::createImplicit(
412+
C, C.Id_metatype, C.Id_metatype,
413+
anyTypeType, system)
414+
},
415+
/*RParenLoc=*/SourceLoc());
416+
417+
// Func name: invokeHandlerOnReturn(handler:resultBuffer:metatype)
418+
DeclName name(C, C.Id_invokeHandlerOnReturn, params);
419+
420+
// Expected type: (Self.ResultHandler, UnsafeRawPointer, any Any.Type) async
421+
// throws -> ()
422+
auto *funcDecl =
423+
FuncDecl::createImplicit(C, StaticSpellingKind::None, name, SourceLoc(),
424+
/*async=*/true,
425+
/*throws=*/true,
426+
/*ThrownType=*/Type(),
427+
/*genericParams=*/nullptr, params,
428+
/*returnType*/ TupleType::getEmpty(C), system);
429+
funcDecl->setSynthesized(true);
430+
funcDecl->copyFormalAccessFrom(system, /*sourceIsParentContext=*/true);
431+
funcDecl->setBodySynthesizer(deriveBodyDistributed_invokeHandlerOnReturn);
432+
433+
derived.addMembersToConformanceContext({funcDecl});
434+
return funcDecl;
435+
}
436+
124437
/******************************************************************************/
125438
/******************************* PROPERTIES ***********************************/
126439
/******************************************************************************/
@@ -581,6 +894,14 @@ std::pair<Type, TypeDecl *> DerivedConformance::deriveDistributedActor(
581894

582895
ValueDecl *
583896
DerivedConformance::deriveDistributedActorSystem(ValueDecl *requirement) {
897+
if (auto func = dyn_cast<FuncDecl>(requirement)) {
898+
// just a simple name check is enough here,
899+
// if we are invoked here we know for sure it is for the "right" function
900+
if (func->getName().getBaseName() == Context.Id_invokeHandlerOnReturn) {
901+
return deriveDistributedActorSystem_invokeHandlerOnReturn(*this);
902+
}
903+
}
904+
584905
return nullptr;
585906
}
586907

lib/Sema/DerivedConformances.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
400400
}
401401
}
402402

403+
// DistributedActor.actorSystem
404+
if (name.isCompoundName() &&
405+
name.getBaseName() == ctx.Id_invokeHandlerOnReturn)
406+
return getRequirement(KnownProtocolKind::DistributedActorSystem);
407+
403408
return nullptr;
404409
}
405410

stdlib/public/Distributed/DistributedActorSystem.swift

+21-6
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,22 @@ public protocol DistributedActorSystem<SerializationRequirement>: Sendable {
413413
where Act: DistributedActor,
414414
Act.ID == ActorID,
415415
Err: Error
416+
417+
// Implementation notes:
418+
// The `metatype` must be the type of `Value`, and it must conform to
419+
// `SerializationRequirement`. If it does not, the method will crash at
420+
// runtime. This is because we cannot express
421+
// `Value: SerializationRequirement`, however the generic `Value` is still
422+
// useful since it allows us to avoid boxing the value into an existential,
423+
// before we'd right away unbox it as first thing in the implementation of
424+
// this function.
425+
/// Implementation synthesized by the compiler.
426+
/// Not intended to be invoked explicitly from user code!
427+
func invokeHandlerOnReturn(
428+
handler: ResultHandler,
429+
resultBuffer: UnsafeRawPointer,
430+
metatype: Any.Type
431+
) async throws
416432
}
417433

418434
// ==== ----------------------------------------------------------------------------------------------------------------
@@ -650,12 +666,11 @@ extension DistributedActorSystem {
650666
if returnType == Void.self {
651667
try await handler.onReturnVoid()
652668
} else {
653-
func invokeOnReturn<R>(_ returnType: R.Type) async throws {
654-
let value = resultBuffer.load(as: returnType)
655-
try await handler.onReturn(value: value)
656-
}
657-
658-
try await _openExistential(returnType, do: invokeOnReturn)
669+
try await self.invokeHandlerOnReturn(
670+
handler: handler,
671+
resultBuffer: resultBuffer,
672+
metatype: returnType
673+
)
659674
}
660675
} catch {
661676
try await handler.onThrow(error: error)

0 commit comments

Comments
 (0)