Skip to content

Commit 7499c22

Browse files
committed
AST: Requestify lookup of protocol referenced by ImplementsAttr
Direct lookup relied in primary file checking to have filled in the protocol type stored in the ImplementsAttr. This was already wrong with multi-file test cases in non-WMO mode, and crashed in the ASTPrinter if printing a declaration in a non-primary file. I don't have a standalone test case that is independent of my upcoming ASTPrinter changes, but this is a nice cleanup regardless.
1 parent 8ccabbd commit 7499c22

10 files changed

+118
-89
lines changed

Diff for: include/swift/AST/Attr.h

+10-7
Original file line numberDiff line numberDiff line change
@@ -1551,25 +1551,28 @@ class SpecializeAttr final
15511551
/// The @_implements attribute, which treats a decl as the implementation for
15521552
/// some named protocol requirement (but otherwise not-visible by that name).
15531553
class ImplementsAttr : public DeclAttribute {
1554-
TypeExpr *ProtocolType;
1554+
TypeRepr *TyR;
15551555
DeclName MemberName;
15561556
DeclNameLoc MemberNameLoc;
15571557

1558-
public:
15591558
ImplementsAttr(SourceLoc atLoc, SourceRange Range,
1560-
TypeExpr *ProtocolType,
1559+
TypeRepr *TyR,
15611560
DeclName MemberName,
15621561
DeclNameLoc MemberNameLoc);
15631562

1563+
public:
15641564
static ImplementsAttr *create(ASTContext &Ctx, SourceLoc atLoc,
15651565
SourceRange Range,
1566-
TypeExpr *ProtocolType,
1566+
TypeRepr *TyR,
15671567
DeclName MemberName,
15681568
DeclNameLoc MemberNameLoc);
15691569

1570-
void setProtocolType(Type ty);
1571-
Type getProtocolType() const;
1572-
TypeRepr *getProtocolTypeRepr() const;
1570+
static ImplementsAttr *create(DeclContext *DC,
1571+
ProtocolDecl *Proto,
1572+
DeclName MemberName);
1573+
1574+
ProtocolDecl *getProtocol(DeclContext *dc) const;
1575+
TypeRepr *getProtocolTypeRepr() const { return TyR; }
15731576

15741577
DeclName getMemberName() const { return MemberName; }
15751578
DeclNameLoc getMemberNameLoc() const { return MemberNameLoc; }

Diff for: include/swift/AST/NameLookupRequests.h

+19
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,25 @@ class PotentialMacroExpansionsInContextRequest
912912
bool isCached() const { return true; }
913913
};
914914

915+
/// Resolves the protocol referenced by an @_implements attribute.
916+
class ImplementsAttrProtocolRequest
917+
: public SimpleRequest<ImplementsAttrProtocolRequest,
918+
ProtocolDecl *(const ImplementsAttr *, DeclContext *),
919+
RequestFlags::Cached> {
920+
public:
921+
using SimpleRequest::SimpleRequest;
922+
923+
private:
924+
friend SimpleRequest;
925+
926+
// Evaluation.
927+
ProtocolDecl *evaluate(Evaluator &evaluator, const ImplementsAttr *attr,
928+
DeclContext *dc) const;
929+
930+
public:
931+
bool isCached() const { return true; }
932+
};
933+
915934
#define SWIFT_TYPEID_ZONE NameLookup
916935
#define SWIFT_TYPEID_HEADER "swift/AST/NameLookupTypeIDZone.def"
917936
#include "swift/Basic/DefineTypeIDZone.h"

Diff for: include/swift/AST/NameLookupTypeIDZone.def

+2
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,5 @@ SWIFT_REQUEST(NameLookup, HasDynamicCallableAttributeRequest,
109109
bool(NominalTypeDecl *), Cached, NoLocationInfo)
110110
SWIFT_REQUEST(NameLookup, PotentialMacroExpansionsInContextRequest,
111111
PotentialMacroExpansions(TypeOrExtension), Cached, NoLocationInfo)
112+
SWIFT_REQUEST(NameLookup, ImplementsAttrProtocolRequest,
113+
ProtocolDecl *(const ImplementsAttr *, DeclContext *), Cached, NoLocationInfo)

Diff for: lib/AST/Attr.cpp

+22-15
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,10 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
12391239
Printer.printAttrName("@_implements");
12401240
Printer << "(";
12411241
auto *attr = cast<ImplementsAttr>(this);
1242-
attr->getProtocolType().print(Printer, Options);
1242+
if (auto *proto = attr->getProtocol(D->getDeclContext()))
1243+
proto->getDeclaredInterfaceType()->print(Printer, Options);
1244+
else
1245+
attr->getProtocolTypeRepr()->print(Printer, Options);
12431246
Printer << ", " << attr->getMemberName() << ")";
12441247
break;
12451248
}
@@ -2360,37 +2363,41 @@ TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
23602363
}
23612364

23622365
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
2363-
TypeExpr *ProtocolType,
2366+
TypeRepr *TyR,
23642367
DeclName MemberName,
23652368
DeclNameLoc MemberNameLoc)
23662369
: DeclAttribute(DAK_Implements, atLoc, range, /*Implicit=*/false),
2367-
ProtocolType(ProtocolType),
2370+
TyR(TyR),
23682371
MemberName(MemberName),
23692372
MemberNameLoc(MemberNameLoc) {
23702373
}
23712374

2372-
23732375
ImplementsAttr *ImplementsAttr::create(ASTContext &Ctx, SourceLoc atLoc,
23742376
SourceRange range,
2375-
TypeExpr *ProtocolType,
2377+
TypeRepr *TyR,
23762378
DeclName MemberName,
23772379
DeclNameLoc MemberNameLoc) {
23782380
void *mem = Ctx.Allocate(sizeof(ImplementsAttr), alignof(ImplementsAttr));
2379-
return new (mem) ImplementsAttr(atLoc, range, ProtocolType,
2381+
return new (mem) ImplementsAttr(atLoc, range, TyR,
23802382
MemberName, MemberNameLoc);
23812383
}
23822384

2383-
void ImplementsAttr::setProtocolType(Type ty) {
2384-
assert(ty);
2385-
ProtocolType->setType(MetatypeType::get(ty));
2386-
}
2387-
2388-
Type ImplementsAttr::getProtocolType() const {
2389-
return ProtocolType->getInstanceType();
2385+
ImplementsAttr *ImplementsAttr::create(DeclContext *DC,
2386+
ProtocolDecl *Proto,
2387+
DeclName MemberName) {
2388+
auto &ctx = DC->getASTContext();
2389+
void *mem = ctx.Allocate(sizeof(ImplementsAttr), alignof(ImplementsAttr));
2390+
auto *attr = new (mem) ImplementsAttr(
2391+
SourceLoc(), SourceRange(), nullptr,
2392+
MemberName, DeclNameLoc());
2393+
ctx.evaluator.cacheOutput(ImplementsAttrProtocolRequest{attr, DC},
2394+
std::move(Proto));
2395+
return attr;
23902396
}
23912397

2392-
TypeRepr *ImplementsAttr::getProtocolTypeRepr() const {
2393-
return ProtocolType->getTypeRepr();
2398+
ProtocolDecl *ImplementsAttr::getProtocol(DeclContext *dc) const {
2399+
return evaluateOrDefault(dc->getASTContext().evaluator,
2400+
ImplementsAttrProtocolRequest{this, dc}, nullptr);
23942401
}
23952402

23962403
CustomAttr::CustomAttr(SourceLoc atLoc, SourceRange range, TypeExpr *type,

Diff for: lib/AST/NameLookup.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -3644,6 +3644,28 @@ bool TypeBase::hasDynamicCallableAttribute() {
36443644
});
36453645
}
36463646

3647+
ProtocolDecl *ImplementsAttrProtocolRequest::evaluate(
3648+
Evaluator &evaluator, const ImplementsAttr *attr, DeclContext *dc) const {
3649+
3650+
auto typeRepr = attr->getProtocolTypeRepr();
3651+
3652+
ASTContext &ctx = dc->getASTContext();
3653+
DirectlyReferencedTypeDecls referenced =
3654+
directReferencesForTypeRepr(evaluator, ctx, typeRepr, dc);
3655+
3656+
// Resolve those type declarations to nominal type declarations.
3657+
SmallVector<ModuleDecl *, 2> modulesFound;
3658+
bool anyObject = false;
3659+
auto nominalTypes
3660+
= resolveTypeDeclsToNominal(evaluator, ctx, referenced, modulesFound,
3661+
anyObject);
3662+
3663+
if (nominalTypes.empty())
3664+
return nullptr;
3665+
3666+
return dyn_cast<ProtocolDecl>(nominalTypes.front());
3667+
}
3668+
36473669
void FindLocalVal::checkPattern(const Pattern *Pat, DeclVisibilityKind Reason) {
36483670
Pat->forEachVariable([&](VarDecl *VD) { checkValueDecl(VD, Reason); });
36493671
}

Diff for: lib/Parse/ParseDecl.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -1218,10 +1218,9 @@ Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) {
12181218
}
12191219

12201220
// FIXME(ModQual): Reject module qualification on MemberName.
1221-
auto *TE = new (Context) TypeExpr(ProtocolType.get());
12221221
return ParserResult<ImplementsAttr>(
12231222
ImplementsAttr::create(Context, AtLoc, SourceRange(Loc, rParenLoc),
1224-
TE, MemberName.getFullName(),
1223+
ProtocolType.get(), MemberName.getFullName(),
12251224
MemberNameLoc));
12261225
}
12271226

Diff for: lib/Sema/DerivedConformanceComparable.cpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -259,16 +259,12 @@ deriveComparable_lt(
259259
// Add the @_implements(Comparable, < (_:_:)) attribute
260260
if (generatedIdentifier != C.Id_LessThanOperator) {
261261
auto comparable = C.getProtocol(KnownProtocolKind::Comparable);
262-
auto comparableType = comparable->getDeclaredInterfaceType();
263-
auto comparableTypeExpr = TypeExpr::createImplicit(comparableType, C);
264262
SmallVector<Identifier, 2> argumentLabels = { Identifier(), Identifier() };
265263
auto comparableDeclName = DeclName(C, DeclBaseName(C.Id_LessThanOperator),
266264
argumentLabels);
267-
comparableDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(),
268-
SourceRange(),
269-
comparableTypeExpr,
270-
comparableDeclName,
271-
DeclNameLoc()));
265+
comparableDecl->getAttrs().add(ImplementsAttr::create(parentDC,
266+
comparable,
267+
comparableDeclName));
272268
}
273269

274270
if (!C.getLessThanIntDecl()) {

Diff for: lib/Sema/DerivedConformanceEquatableHashable.cpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -417,16 +417,12 @@ deriveEquatable_eq(
417417
// Add the @_implements(Equatable, ==(_:_:)) attribute
418418
if (generatedIdentifier != C.Id_EqualsOperator) {
419419
auto equatableProto = C.getProtocol(KnownProtocolKind::Equatable);
420-
auto equatableTy = equatableProto->getDeclaredInterfaceType();
421-
auto equatableTyExpr = TypeExpr::createImplicit(equatableTy, C);
422420
SmallVector<Identifier, 2> argumentLabels = { Identifier(), Identifier() };
423421
auto equalsDeclName = DeclName(C, DeclBaseName(C.Id_EqualsOperator),
424422
argumentLabels);
425-
eqDecl->getAttrs().add(new (C) ImplementsAttr(SourceLoc(),
426-
SourceRange(),
427-
equatableTyExpr,
428-
equalsDeclName,
429-
DeclNameLoc()));
423+
eqDecl->getAttrs().add(ImplementsAttr::create(parentDC,
424+
equatableProto,
425+
equalsDeclName));
430426
}
431427

432428
if (!C.getEqualIntDecl()) {

Diff for: lib/Sema/TypeCheckAttr.cpp

+33-46
Original file line numberDiff line numberDiff line change
@@ -3572,58 +3572,45 @@ void AttributeChecker::visitTypeEraserAttr(TypeEraserAttr *attr) {
35723572
void AttributeChecker::visitImplementsAttr(ImplementsAttr *attr) {
35733573
DeclContext *DC = D->getDeclContext();
35743574

3575-
Type T = attr->getProtocolType();
3576-
if (!T && attr->getProtocolTypeRepr()) {
3577-
auto context = TypeResolverContext::GenericRequirement;
3578-
T = TypeResolution::resolveContextualType(attr->getProtocolTypeRepr(), DC,
3579-
TypeResolutionOptions(context),
3580-
/*unboundTyOpener*/ nullptr,
3581-
/*placeholderHandler*/ nullptr,
3582-
/*packElementOpener*/ nullptr);
3583-
}
3584-
3585-
// Definite error-types were already diagnosed in resolveType.
3586-
if (T->hasError())
3575+
ProtocolDecl *PD = attr->getProtocol(DC);
3576+
3577+
if (!PD) {
3578+
diagnose(attr->getLocation(), diag::implements_attr_non_protocol_type)
3579+
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
35873580
return;
3588-
attr->setProtocolType(T);
3581+
}
35893582

3590-
// Check that we got a ProtocolType.
3591-
if (auto PT = T->getAs<ProtocolType>()) {
3592-
ProtocolDecl *PD = PT->getDecl();
3583+
// Check that the ProtocolType has the specified member.
3584+
LookupResult R =
3585+
TypeChecker::lookupMember(PD->getDeclContext(),
3586+
PD->getDeclaredInterfaceType(),
3587+
DeclNameRef(attr->getMemberName()));
3588+
if (!R) {
3589+
diagnose(attr->getLocation(),
3590+
diag::implements_attr_protocol_lacks_member,
3591+
PD->getName(), attr->getMemberName())
3592+
.highlight(attr->getMemberNameLoc().getSourceRange());
3593+
return;
3594+
}
35933595

3594-
// Check that the ProtocolType has the specified member.
3595-
LookupResult R =
3596-
TypeChecker::lookupMember(PD->getDeclContext(), PT,
3597-
DeclNameRef(attr->getMemberName()));
3598-
if (!R) {
3596+
// Check that the decl we're decorating is a member of a type that actually
3597+
// conforms to the specified protocol.
3598+
NominalTypeDecl *NTD = DC->getSelfNominalTypeDecl();
3599+
if (auto *OtherPD = dyn_cast<ProtocolDecl>(NTD)) {
3600+
if (!OtherPD->inheritsFrom(PD)) {
35993601
diagnose(attr->getLocation(),
3600-
diag::implements_attr_protocol_lacks_member,
3601-
PD->getName(), attr->getMemberName())
3602-
.highlight(attr->getMemberNameLoc().getSourceRange());
3603-
}
3604-
3605-
// Check that the decl we're decorating is a member of a type that actually
3606-
// conforms to the specified protocol.
3607-
NominalTypeDecl *NTD = DC->getSelfNominalTypeDecl();
3608-
if (auto *OtherPD = dyn_cast<ProtocolDecl>(NTD)) {
3609-
if (!OtherPD->inheritsFrom(PD)) {
3610-
diagnose(attr->getLocation(),
3611-
diag::implements_attr_protocol_not_conformed_to,
3612-
NTD->getName(), PD->getName())
3613-
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
3614-
}
3615-
} else {
3616-
SmallVector<ProtocolConformance *, 2> conformances;
3617-
if (!NTD->lookupConformance(PD, conformances)) {
3618-
diagnose(attr->getLocation(),
3619-
diag::implements_attr_protocol_not_conformed_to,
3620-
NTD->getName(), PD->getName())
3621-
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
3622-
}
3602+
diag::implements_attr_protocol_not_conformed_to,
3603+
NTD->getName(), PD->getName())
3604+
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
36233605
}
36243606
} else {
3625-
diagnose(attr->getLocation(), diag::implements_attr_non_protocol_type)
3626-
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
3607+
SmallVector<ProtocolConformance *, 2> conformances;
3608+
if (!NTD->lookupConformance(PD, conformances)) {
3609+
diagnose(attr->getLocation(),
3610+
diag::implements_attr_protocol_not_conformed_to,
3611+
NTD->getName(), PD->getName())
3612+
.highlight(attr->getProtocolTypeRepr()->getSourceRange());
3613+
}
36273614
}
36283615
}
36293616

Diff for: lib/Sema/TypeCheckProtocol.cpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -1237,11 +1237,9 @@ witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
12371237
assert(requirement->isProtocolRequirement());
12381238
auto *PD = cast<ProtocolDecl>(requirement->getDeclContext());
12391239
if (auto A = witness->getAttrs().getAttribute<ImplementsAttr>()) {
1240-
if (Type T = A->getProtocolType()) {
1241-
if (auto ProtoTy = T->getAs<ProtocolType>()) {
1242-
if (ProtoTy->getDecl() == PD) {
1243-
return A->getMemberName() == requirement->getName();
1244-
}
1240+
if (auto *OtherPD = A->getProtocol(witness->getDeclContext())) {
1241+
if (OtherPD == PD) {
1242+
return A->getMemberName() == requirement->getName();
12451243
}
12461244
}
12471245
}

0 commit comments

Comments
 (0)