Skip to content

Commit 53e06d6

Browse files
committed
AST: Factor out a new Requirement::getProtocolDecl() utility method
1 parent 8ae1b76 commit 53e06d6

24 files changed

+74
-107
lines changed

include/swift/AST/Requirement.h

+2
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ class Requirement
7474
llvm_unreachable("Unhandled RequirementKind in switch.");
7575
}
7676

77+
ProtocolDecl *getProtocolDecl() const;
78+
7779
SWIFT_DEBUG_DUMP;
7880
void dump(raw_ostream &out) const;
7981
void print(raw_ostream &os, const PrintOptions &opts) const;

include/swift/IRGen/Linking.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,7 @@ class LinkEntity {
565565
for (const auto &reqt : proto->getRequirementSignature()) {
566566
if (reqt.getKind() == RequirementKind::Conformance &&
567567
reqt.getFirstType()->getCanonicalType() == associatedType &&
568-
reqt.getSecondType()->castTo<ProtocolType>()->getDecl() ==
569-
requirement) {
568+
reqt.getProtocolDecl() == requirement) {
570569
return index;
571570
}
572571
++index;
@@ -590,7 +589,7 @@ class LinkEntity {
590589
auto &reqt = proto->getRequirementSignature()[index];
591590
assert(reqt.getKind() == RequirementKind::Conformance);
592591
return { reqt.getFirstType()->getCanonicalType(),
593-
reqt.getSecondType()->castTo<ProtocolType>()->getDecl() };
592+
reqt.getProtocolDecl() };
594593
}
595594

596595
static std::pair<CanType, ProtocolDecl*>

include/swift/SIL/SILWitnessVisitor.h

+1-3
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ template <class T> class SILWitnessVisitor : public ASTVisitor<T> {
6464
case RequirementKind::Conformance: {
6565
auto type = reqt.getFirstType()->getCanonicalType();
6666
assert(type->isTypeParameter());
67-
auto requirement =
68-
cast<ProtocolType>(reqt.getSecondType()->getCanonicalType())
69-
->getDecl();
67+
auto requirement = reqt.getProtocolDecl();
7068

7169
// ObjC protocols do not have witnesses.
7270
if (!Lowering::TypeConverter::protocolRequiresWitnessTable(requirement))

lib/AST/ASTContext.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -709,8 +709,7 @@ FuncDecl *ASTContext::getPlusFunctionOnRangeReplaceableCollection() const {
709709
continue;
710710
for (auto Req: FD->getGenericRequirements()) {
711711
if (Req.getKind() == RequirementKind::Conformance &&
712-
Req.getSecondType()->getNominalOrBoundGenericNominal() ==
713-
getRangeReplaceableCollectionDecl()) {
712+
Req.getProtocolDecl() == getRangeReplaceableCollectionDecl()) {
714713
getImpl().PlusFunctionOnRangeReplaceableCollection = FD;
715714
}
716715
}

lib/AST/ASTMangler.cpp

+4-7
Original file line numberDiff line numberDiff line change
@@ -1584,8 +1584,7 @@ static bool containsRetroactiveConformance(
15841584
for (auto requirement : rootConformance->getConditionalRequirements()) {
15851585
if (requirement.getKind() != RequirementKind::Conformance)
15861586
continue;
1587-
ProtocolDecl *proto =
1588-
requirement.getSecondType()->castTo<ProtocolType>()->getDecl();
1587+
ProtocolDecl *proto = requirement.getProtocolDecl();
15891588
auto conformance = subMap.lookupConformance(
15901589
requirement.getFirstType()->getCanonicalType(), proto);
15911590
if (conformance.isInvalid()) {
@@ -2598,8 +2597,7 @@ void ASTMangler::appendRequirement(const Requirement &reqt) {
25982597
case RequirementKind::Layout: {
25992598
} break;
26002599
case RequirementKind::Conformance: {
2601-
Type SecondTy = reqt.getSecondType();
2602-
appendProtocolName(SecondTy->castTo<ProtocolType>()->getDecl());
2600+
appendProtocolName(reqt.getProtocolDecl());
26032601
} break;
26042602
case RequirementKind::Superclass:
26052603
case RequirementKind::SameType: {
@@ -3047,7 +3045,7 @@ static unsigned conformanceRequirementIndex(
30473045
continue;
30483046

30493047
if (req.getFirstType()->isEqual(entry.first) &&
3050-
req.getSecondType()->castTo<ProtocolType>()->getDecl() == entry.second)
3048+
req.getProtocolDecl() == entry.second)
30513049
return result;
30523050

30533051
++result;
@@ -3175,8 +3173,7 @@ void ASTMangler::appendConcreteProtocolConformance(
31753173
if (type->hasArchetype())
31763174
type = type->mapTypeOutOfContext();
31773175
CanType canType = type->getCanonicalType(CurGenericSignature);
3178-
auto proto =
3179-
conditionalReq.getSecondType()->castTo<ProtocolType>()->getDecl();
3176+
auto proto = conditionalReq.getProtocolDecl();
31803177

31813178
ProtocolConformanceRef conformance;
31823179

lib/AST/ASTPrinter.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -2533,8 +2533,7 @@ static bool usesFeatureRethrowsProtocol(
25332533
->getGenericSignatureOfContext()) {
25342534
for (const auto &req : genericSig->getRequirements()) {
25352535
if (req.getKind() == RequirementKind::Conformance &&
2536-
usesFeatureRethrowsProtocol(
2537-
req.getSecondType()->getAs<ProtocolType>()->getDecl(), checked))
2536+
usesFeatureRethrowsProtocol(req.getProtocolDecl(), checked))
25382537
return true;
25392538
}
25402539
}

lib/AST/ASTVerifier.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -2752,8 +2752,7 @@ class Verifier : public ASTWalker {
27522752
abort();
27532753
}
27542754

2755-
auto reqProto =
2756-
req.getSecondType()->castTo<ProtocolType>()->getDecl();
2755+
auto reqProto = req.getProtocolDecl();
27572756
if (reqProto != conformances[idx].getRequirement()) {
27582757
Out << "error: wrong protocol in signature conformances: have "
27592758
<< conformances[idx].getRequirement()->getName().str()

lib/AST/GenericSignature.cpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,8 @@ CanGenericSignature::getCanonical(TypeArrayView<GenericTypeParamType> params,
308308
assert(reqt.getKind() == RequirementKind::Conformance &&
309309
"Only conformance requirements can have multiples");
310310

311-
auto prevProto =
312-
prevReqt.getSecondType()->castTo<ProtocolType>()->getDecl();
313-
auto proto = reqt.getSecondType()->castTo<ProtocolType>()->getDecl();
311+
auto prevProto = prevReqt.getProtocolDecl();
312+
auto proto = reqt.getProtocolDecl();
314313
assert(TypeDecl::compare(prevProto, proto) < 0 &&
315314
"Out-of-order conformance requirements");
316315
}
@@ -545,8 +544,7 @@ bool GenericSignatureImpl::isRequirementSatisfied(
545544

546545
switch (requirement.getKind()) {
547546
case RequirementKind::Conformance: {
548-
auto protocolType = requirement.getSecondType()->castTo<ProtocolType>();
549-
auto protocol = protocolType->getDecl();
547+
auto *protocol = requirement.getProtocolDecl();
550548

551549
if (canFirstType->isTypeParameter())
552550
return requiresProtocol(canFirstType, protocol);
@@ -746,8 +744,7 @@ static bool hasConformanceInSignature(ArrayRef<Requirement> requirements,
746744
for (const auto &req: requirements) {
747745
if (req.getKind() == RequirementKind::Conformance &&
748746
req.getFirstType()->isEqual(subjectType) &&
749-
req.getSecondType()->castTo<ProtocolType>()->getDecl()
750-
== proto) {
747+
req.getProtocolDecl() == proto) {
751748
return true;
752749
}
753750
}
@@ -1007,7 +1004,6 @@ bool Requirement::isCanonical() const {
10071004
return true;
10081005
}
10091006

1010-
10111007
/// Get the canonical form of this requirement.
10121008
Requirement Requirement::getCanonical() const {
10131009
Type firstType = getFirstType();
@@ -1029,3 +1025,8 @@ Requirement Requirement::getCanonical() const {
10291025
}
10301026
llvm_unreachable("Unhandled RequirementKind in switch");
10311027
}
1028+
1029+
ProtocolDecl *Requirement::getProtocolDecl() const {
1030+
assert(getKind() == RequirementKind::Conformance);
1031+
return getSecondType()->castTo<ProtocolType>()->getDecl();
1032+
}

lib/AST/NameLookup.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -2700,8 +2700,7 @@ swift::getDirectlyInheritedNominalTypeDecls(
27002700
if (!req.getFirstType()->isEqual(protoSelfTy))
27012701
continue;
27022702

2703-
result.emplace_back(req.getSecondType()->castTo<ProtocolType>()->getDecl(),
2704-
loc);
2703+
result.emplace_back(req.getProtocolDecl(), loc);
27052704
}
27062705
return result;
27072706
}

lib/AST/ProtocolConformance.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,7 @@ void NormalProtocolConformance::setSignatureConformances(
572572
"Should have interface types here");
573573
assert(idx < conformances.size());
574574
assert(conformances[idx].isInvalid() ||
575-
conformances[idx].getRequirement() ==
576-
req.getSecondType()->castTo<ProtocolType>()->getDecl());
575+
conformances[idx].getRequirement() == req.getProtocolDecl());
577576
++idx;
578577
}
579578
}
@@ -772,7 +771,7 @@ NormalProtocolConformance::getAssociatedConformance(Type assocType,
772771
if (reqt.getKind() == RequirementKind::Conformance) {
773772
// Is this the conformance we're looking for?
774773
if (reqt.getFirstType()->isEqual(assocType) &&
775-
reqt.getSecondType()->castTo<ProtocolType>()->getDecl() == protocol)
774+
reqt.getProtocolDecl() == protocol)
776775
return getSignatureConformances()[conformanceIndex];
777776

778777
++conformanceIndex;
@@ -840,7 +839,7 @@ void NormalProtocolConformance::finishSignatureConformances() {
840839
auto *depMemTy = origTy->castTo<DependentMemberType>();
841840
substTy = recursivelySubstituteBaseType(module, this, depMemTy);
842841
}
843-
auto reqProto = req.getSecondType()->castTo<ProtocolType>()->getDecl();
842+
auto reqProto = req.getProtocolDecl();
844843

845844
// Looking up a conformance for a contextual type and mapping the
846845
// conformance context produces a more accurate result than looking

lib/AST/SubstitutionMap.cpp

+25-35
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,7 @@ SubstitutionMap SubstitutionMap::get(GenericSignature genericSig,
234234

235235
CanType depTy = req.getFirstType()->getCanonicalType();
236236
auto replacement = depTy.subst(subs, lookupConformance);
237-
auto protoType = req.getSecondType()->castTo<ProtocolType>();
238-
auto proto = protoType->getDecl();
237+
auto *proto = req.getProtocolDecl();
239238
auto conformance = lookupConformance(depTy, replacement, proto);
240239
conformances.push_back(conformance);
241240
}
@@ -334,38 +333,27 @@ SubstitutionMap::lookupConformance(CanType type, ProtocolDecl *proto) const {
334333

335334
auto genericSig = getGenericSignature();
336335

337-
// Fast path
338-
unsigned index = 0;
339-
for (auto reqt : genericSig->getRequirements()) {
340-
if (reqt.getKind() == RequirementKind::Conformance) {
341-
if (reqt.getFirstType()->isEqual(type) &&
342-
reqt.getSecondType()->isEqual(proto->getDeclaredInterfaceType()))
343-
return getConformances()[index];
336+
auto getSignatureConformance =
337+
[&](Type type, ProtocolDecl *proto) -> Optional<ProtocolConformanceRef> {
338+
unsigned index = 0;
339+
for (auto reqt : genericSig->getRequirements()) {
340+
if (reqt.getKind() == RequirementKind::Conformance) {
341+
if (reqt.getFirstType()->isEqual(type) &&
342+
reqt.getProtocolDecl() == proto)
343+
return getConformances()[index];
344344

345-
++index;
346-
}
347-
}
348-
349-
// Retrieve the starting conformance from the conformance map.
350-
auto getInitialConformance =
351-
[&](Type type, ProtocolDecl *proto) -> ProtocolConformanceRef {
352-
unsigned conformanceIndex = 0;
353-
for (const auto &req : getGenericSignature()->getRequirements()) {
354-
if (req.getKind() != RequirementKind::Conformance)
355-
continue;
356-
357-
// Is this the conformance we're looking for?
358-
if (req.getFirstType()->isEqual(type) &&
359-
req.getSecondType()->castTo<ProtocolType>()->getDecl() == proto) {
360-
return getConformances()[conformanceIndex];
345+
++index;
361346
}
362-
363-
++conformanceIndex;
364347
}
365348

366-
return ProtocolConformanceRef::forInvalid();
349+
return None;
367350
};
368351

352+
// Fast path -- check if the generic signature directly states the
353+
// conformance.
354+
if (auto directConformance = getSignatureConformance(type, proto))
355+
return *directConformance;
356+
369357
// Check whether the superclass conforms.
370358
if (auto superclass = genericSig->getSuperclassBound(type)) {
371359
LookUpConformanceInSignature lookup(getGenericSignature().getPointer());
@@ -388,16 +376,16 @@ SubstitutionMap::lookupConformance(CanType type, ProtocolDecl *proto) const {
388376
for (const auto &step : accessPath) {
389377
// For the first step, grab the initial conformance.
390378
if (conformance.isInvalid()) {
391-
conformance = getInitialConformance(step.first, step.second);
392-
if (conformance.isInvalid())
393-
return ProtocolConformanceRef::forInvalid();
379+
if (auto initialConformance = getSignatureConformance(
380+
step.first, step.second)) {
381+
conformance = *initialConformance;
382+
continue;
383+
}
394384

395-
continue;
385+
// We couldn't find the initial conformance, fail.
386+
return ProtocolConformanceRef::forInvalid();
396387
}
397388

398-
if (conformance.isInvalid())
399-
return conformance;
400-
401389
// If we've hit an abstract conformance, everything from here on out is
402390
// abstract.
403391
// FIXME: This may not always be true, but it holds for now.
@@ -436,6 +424,8 @@ SubstitutionMap::lookupConformance(CanType type, ProtocolDecl *proto) const {
436424

437425
// Get the associated conformance.
438426
conformance = concrete->getAssociatedConformance(step.first, step.second);
427+
if (conformance.isInvalid())
428+
return conformance;
439429
}
440430

441431
return conformance;

lib/AST/Type.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2012,7 +2012,7 @@ class IsBindableVisitor
20122012
if (req.getKind() != RequirementKind::Conformance) continue;
20132013

20142014
auto canTy = req.getFirstType()->getCanonicalType();
2015-
auto *proto = req.getSecondType()->castTo<ProtocolType>()->getDecl();
2015+
auto *proto = req.getProtocolDecl();
20162016
auto origConf = origSubMap.lookupConformance(canTy, proto);
20172017
auto substConf = substSubMap.lookupConformance(canTy, proto);
20182018

lib/IRGen/GenMeta.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -5243,8 +5243,7 @@ GenericRequirementsMetadata irgen::addGenericRequirements(
52435243
break;
52445244

52455245
case RequirementKind::Conformance: {
5246-
auto protocol = requirement.getSecondType()->castTo<ProtocolType>()
5247-
->getDecl();
5246+
auto protocol = requirement.getProtocolDecl();
52485247

52495248
// Marker protocols do not record generic requirements at all.
52505249
if (protocol->isMarkerProtocol()) {

lib/IRGen/GenProto.cpp

+3-5
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,7 @@ irgen::enumerateGenericSignatureRequirements(CanGenericSignature signature,
266266

267267
case RequirementKind::Conformance: {
268268
auto type = CanType(reqt.getFirstType());
269-
auto protocol =
270-
cast<ProtocolType>(CanType(reqt.getSecondType()))->getDecl();
269+
auto protocol = reqt.getProtocolDecl();
271270
if (Lowering::TypeConverter::protocolRequiresWitnessTable(protocol)) {
272271
callback({type, protocol});
273272
}
@@ -926,7 +925,7 @@ static bool isDependentConformance(
926925
if (req.getKind() != RequirementKind::Conformance)
927926
continue;
928927

929-
auto assocProtocol = req.getSecondType()->castTo<ProtocolType>()->getDecl();
928+
auto assocProtocol = req.getProtocolDecl();
930929
if (assocProtocol->isObjC())
931930
continue;
932931

@@ -2801,8 +2800,7 @@ void NecessaryBindings::addAbstractConditionalRequirements(
28012800
for (auto req : condRequirements) {
28022801
if (req.getKind() != RequirementKind::Conformance)
28032802
continue;
2804-
auto *proto =
2805-
req.getSecondType()->castTo<ProtocolType>()->getDecl();
2803+
auto *proto = req.getProtocolDecl();
28062804
auto ty = req.getFirstType()->getCanonicalType();
28072805
auto archetype = dyn_cast<ArchetypeType>(ty);
28082806
if (!archetype)

lib/SIL/IR/SILWitnessTable.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ bool SILWitnessTable::enumerateWitnessTableConditionalConformances(
187187
if (req.getKind() != RequirementKind::Conformance)
188188
continue;
189189

190-
auto proto = req.getSecondType()->castTo<ProtocolType>()->getDecl();
190+
auto proto = req.getProtocolDecl();
191191

192192
if (Lowering::TypeConverter::protocolRequiresWitnessTable(proto)) {
193193
if (fn(conformanceIndex, req.getFirstType()->getCanonicalType(), proto))

lib/SILOptimizer/Mandatory/Differentiation.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,7 @@ static bool diagnoseUnsatisfiedRequirements(ADContext &context,
266266
}
267267
// Check conformance requirements.
268268
case RequirementKind::Conformance: {
269-
auto protocolType = req.getSecondType()->castTo<ProtocolType>();
270-
auto protocol = protocolType->getDecl();
269+
auto *protocol = req.getProtocolDecl();
271270
assert(protocol && "Expected protocol in generic signature requirement");
272271
// If the first type does not conform to the second type in the current
273272
// module, then record the unsatisfied requirement.

lib/Sema/ConstraintSystem.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -1381,14 +1381,13 @@ void ConstraintSystem::openGenericRequirements(
13811381
auto kind = req.getKind();
13821382
switch (kind) {
13831383
case RequirementKind::Conformance: {
1384-
auto proto = req.getSecondType()->castTo<ProtocolType>();
1385-
auto protoDecl = proto->getDecl();
1384+
auto protoDecl = req.getProtocolDecl();
13861385
// Determine whether this is the protocol 'Self' constraint we should
13871386
// skip.
13881387
if (skipProtocolSelfConstraint && protoDecl == outerDC &&
13891388
protoDecl->getSelfInterfaceType()->isEqual(req.getFirstType()))
13901389
continue;
1391-
openedReq = Requirement(kind, openedFirst, proto);
1390+
openedReq = Requirement(kind, openedFirst, req.getSecondType());
13921391
break;
13931392
}
13941393
case RequirementKind::Superclass:

0 commit comments

Comments
 (0)