Skip to content

Commit 45eaf33

Browse files
committed
[Distributed] Require that SerReq can only be used with protocols
1 parent 49eccc4 commit 45eaf33

6 files changed

+191
-4
lines changed

include/swift/AST/DiagnosticsSema.def

+3
Original file line numberDiff line numberDiff line change
@@ -4631,6 +4631,9 @@ ERROR(distributed_actor_func_param_not_codable,none,
46314631
ERROR(distributed_actor_target_result_not_codable,none,
46324632
"result type %0 of %1 %2 does not conform to serialization requirement '%3'",
46334633
(Type, DescriptiveDeclKind, Identifier, StringRef))
4634+
ERROR(distributed_actor_system_serialization_req_must_be_protocol,none,
4635+
"'SerializationRequirement' type witness %0 must be a protocol type",
4636+
(Type))
46344637
ERROR(distributed_actor_remote_func_implemented_manually,none,
46354638
"distributed instance method's %0 remote counterpart %1 cannot not be implemented manually.",
46364639
(Identifier, Identifier))

lib/AST/DistributedDecl.cpp

+16-4
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,24 @@ swift::getDistributedSerializationRequirements(
271271
if (existentialRequirementTy->isAny())
272272
return true; // we're done here, any means there are no requirements
273273

274-
auto serialReqType = existentialRequirementTy->castTo<ExistentialType>()
275-
->getConstraintType()
276-
->getDesugaredType();
274+
if (auto alias = dyn_cast<TypeAliasType>(existentialRequirementTy.getPointer())) {
275+
auto ty = alias->getDesugaredType();
276+
if (isa<ClassType>(ty) || isa<StructType>(ty) || isa<EnumType>(ty)) {
277+
// SerializationRequirement cannot be class or struct nowadays
278+
return false;
279+
}
280+
}
281+
282+
ExistentialType *serialReqType = existentialRequirementTy
283+
->castTo<ExistentialType>();
284+
if (!serialReqType || serialReqType->hasError()) {
285+
return false;
286+
}
287+
288+
auto desugaredTy = serialReqType->getConstraintType()->getDesugaredType();
277289
auto flattenedRequirements =
278290
flattenDistributedSerializationTypeToRequiredProtocols(
279-
serialReqType);
291+
desugaredTy);
280292
for (auto p : flattenedRequirements) {
281293
requirementProtos.insert(p);
282294
}

lib/Sema/TypeCheckDistributed.cpp

+41
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,47 @@ static bool checkDistributedTargetResultType(
431431
return false;
432432
}
433433

434+
bool swift::checkDistributedActorSystem(const NominalTypeDecl *system) {
435+
auto nominal = const_cast<NominalTypeDecl *>(system);
436+
437+
// ==== Ensure the Distributed module is available,
438+
// without it there's no reason to check the decl in more detail anyway.
439+
if (!swift::ensureDistributedModuleLoaded(nominal))
440+
return true;
441+
442+
auto &C = nominal->getASTContext();
443+
auto DAS = C.getDistributedActorSystemDecl();
444+
445+
// === AssociatedTypes
446+
// --- SerializationRequirement MUST be a protocol TODO(distributed): rdar://91663941
447+
// we may lift this in the future and allow classes but this requires more
448+
// work to enable associatedtypes to be constrained to class or protocol,
449+
// which then will unlock using them as generic constraints in protocols.
450+
Type requirementTy = getDistributedSerializationRequirementType(nominal, DAS);
451+
requirementTy->dump();
452+
453+
if (auto existentialTy = requirementTy->getAs<ExistentialType>()) {
454+
requirementTy = existentialTy->getConstraintType();
455+
}
456+
457+
if (auto alias = dyn_cast<TypeAliasType>(requirementTy.getPointer())) {
458+
auto concreteReqTy = alias->getDesugaredType();
459+
if (auto comp = dyn_cast<ProtocolCompositionType>(concreteReqTy)) {
460+
// ok, protocol composition is fine as requirement,
461+
// since special case of just a single protocol
462+
} else if (auto proto = dyn_cast<ProtocolType>(concreteReqTy)) {
463+
// ok, protocols is exactly what we want to be used as constraints here
464+
} else {
465+
nominal->diagnose(diag::distributed_actor_system_serialization_req_must_be_protocol,
466+
requirementTy);
467+
return true;
468+
}
469+
}
470+
471+
// all good, didn't find any errors
472+
return false;
473+
}
474+
434475
/// Check whether the function is a proper distributed function
435476
///
436477
/// \returns \c true if there was a problem with adding the attribute, \c false

lib/Sema/TypeCheckDistributed.h

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ bool checkDistributedActorSystemAdHocProtocolRequirements(
5050
Type Adoptee,
5151
bool diagnose);
5252

53+
/// Check 'DistributedActorSystem' implementations for additional restrictions.
54+
bool checkDistributedActorSystem(const NominalTypeDecl *system);
55+
5356
/// Typecheck a distributed method declaration
5457
bool checkDistributedFunction(AbstractFunctionDecl *decl);
5558

lib/Sema/TypeCheckProtocol.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -6333,6 +6333,9 @@ void TypeChecker::checkConformancesInContext(IterableDeclContext *idc) {
63336333
}
63346334
}
63356335
}
6336+
} else if (proto->isSpecificProtocol(
6337+
KnownProtocolKind::DistributedActorSystem)) {
6338+
checkDistributedActorSystem(nominal);
63366339
} else if (proto->isSpecificProtocol(KnownProtocolKind::Actor)) {
63376340
if (auto classDecl = dyn_cast<ClassDecl>(nominal)) {
63386341
if (!classDecl->isExplicitActor()) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend-emit-module -emit-module-path %t/FakeDistributedActorSystems.swiftmodule -module-name FakeDistributedActorSystems -disable-availability-checking %S/Inputs/FakeDistributedActorSystems.swift
3+
// RUN: %target-swift-frontend -typecheck -verify -verify-ignore-unknown -disable-availability-checking -I %t 2>&1 %s
4+
// REQUIRES: concurrency
5+
// REQUIRES: distributed
6+
7+
import Distributed
8+
import FakeDistributedActorSystems
9+
10+
final class SomeClazz: Sendable {}
11+
final class SomeStruct: Sendable {}
12+
final class SomeEnum: Sendable {}
13+
14+
// TODO(distributed): improve to diagnose ON the typealias rather than on the system
15+
// expected-error@+1{{'SerializationRequirement' type witness 'System.SerializationRequirement' (aka 'SomeClazz') must be a protocol type}}
16+
final class System: DistributedActorSystem {
17+
// ignore those since they all fail with the SerializationRequirement being invalid:
18+
// expected-error@-2{{type 'System' does not conform to protocol 'DistributedActorSystem'}}
19+
// expected-note@-3{{protocol 'System' requires function 'remoteCallVoid'}}
20+
// expected-error@-4{{class 'System' is missing witness for protocol requirement 'remoteCall'}}
21+
// expected-note@-5{{protocol 'System' requires function 'remoteCall' with signature:}}
22+
// expected-error@-6{{class 'System' is missing witness for protocol requirement 'remoteCallVoid'}}
23+
typealias ActorID = String
24+
typealias InvocationEncoder = ClassInvocationEncoder
25+
typealias InvocationDecoder = ClassInvocationDecoder
26+
27+
typealias ResultHandler = DistributedTargetInvocationResultHandler
28+
// expected-note@-1{{possibly intended match 'System.ResultHandler' (aka 'DistributedTargetInvocationResultHandler') does not conform to 'DistributedTargetInvocationResultHandler'}}
29+
30+
typealias SerializationRequirement = SomeClazz
31+
32+
func resolve<Act>(id: ActorID, as actorType: Act.Type)
33+
throws -> Act? where Act: DistributedActor {
34+
fatalError()
35+
}
36+
37+
func assignID<Act>(_ actorType: Act.Type) -> ActorID
38+
where Act: DistributedActor {
39+
fatalError()
40+
}
41+
42+
func actorReady<Act>(_ actor: Act)
43+
where Act: DistributedActor,
44+
Act.ID == ActorID {
45+
fatalError()
46+
}
47+
48+
func resignID(_ id: ActorID) {
49+
fatalError()
50+
}
51+
52+
func makeInvocationEncoder() -> InvocationEncoder {
53+
fatalError()
54+
}
55+
56+
func remoteCall<Act, Err, Res>(
57+
on actor: Act,
58+
target: RemoteCallTarget,
59+
invocation: inout InvocationEncoder,
60+
throwing errorType: Err.Type,
61+
returning returnType: Res.Type
62+
) async throws -> Res
63+
where Act: DistributedActor,
64+
Act.ID == ActorID,
65+
Err: Error,
66+
Res: SerializationRequirement {
67+
fatalError()
68+
}
69+
70+
func remoteCallVoid<Act, Err>(
71+
on actor: Act,
72+
target: RemoteCallTarget,
73+
invocation: inout InvocationEncoder,
74+
throwing errorType: Err.Type
75+
) async throws
76+
where Act: DistributedActor,
77+
Act.ID == ActorID,
78+
Err: Error {
79+
fatalError()
80+
}
81+
}
82+
83+
struct ClassInvocationEncoder: DistributedTargetInvocationEncoder {
84+
// expected-note@-1{{protocol 'ClassInvocationEncoder' requires function 'recordArgument' with signature:}}
85+
// expected-error@-2{{struct 'ClassInvocationEncoder' is missing witness for protocol requirement 'recordArgument'}}
86+
// expected-note@-3{{protocol 'ClassInvocationEncoder' requires function 'recordReturnType' with signature:}}
87+
// expected-error@-4{{struct 'ClassInvocationEncoder' is missing witness for protocol requirement 'recordReturnType'}}
88+
typealias SerializationRequirement = SomeClazz
89+
90+
public mutating func recordGenericSubstitution<T>(_ type: T.Type) throws {}
91+
public mutating func recordArgument<Value: SerializationRequirement>(
92+
_ argument: RemoteCallArgument<Value>) throws {}
93+
public mutating func recordErrorType<E: Error>(_ type: E.Type) throws {}
94+
public mutating func recordReturnType<R: SerializationRequirement>(_ type: R.Type) throws {}
95+
public mutating func doneRecording() throws {}
96+
}
97+
98+
final class ClassInvocationDecoder: DistributedTargetInvocationDecoder {
99+
// expected-error@-1{{class 'ClassInvocationDecoder' is missing witness for protocol requirement 'decodeNextArgument'}}
100+
// expected-note@-2{{protocol 'ClassInvocationDecoder' requires function 'decodeNextArgument'}}
101+
typealias SerializationRequirement = SomeClazz
102+
103+
public func decodeGenericSubstitutions() throws -> [Any.Type] {
104+
fatalError()
105+
}
106+
107+
public func decodeNextArgument<Argument: SerializationRequirement>() throws -> Argument {
108+
fatalError()
109+
}
110+
111+
public func decodeErrorType() throws -> Any.Type? {
112+
fatalError()
113+
}
114+
115+
public func decodeReturnType() throws -> Any.Type? {
116+
fatalError()
117+
}
118+
}
119+
120+
struct DistributedTargetInvocationResultHandler {
121+
typealias SerializationRequirement = SomeClazz
122+
func onReturn<Success: SomeClazz>(value: Success) async throws {}
123+
func onReturnVoid() async throws {}
124+
func onThrow<Err: Error>(error: Err) async throws {}
125+
}

0 commit comments

Comments
 (0)