Skip to content

Commit 50ddd7c

Browse files
authored
Merge pull request #36575 from hborla/simd-arithmetic-operator-partition
[ConstraintSystem] Treat arithmetic SIMD operators like other generic operators when partitioning an overload set.
2 parents 371cdcb + 06a7950 commit 50ddd7c

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

include/swift/AST/KnownProtocols.def

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ PROTOCOL(Error)
6969
PROTOCOL_(ErrorCodeProtocol)
7070
PROTOCOL(OptionSet)
7171
PROTOCOL(CaseIterable)
72+
PROTOCOL(SIMD)
7273
PROTOCOL(SIMDScalar)
7374
PROTOCOL(BinaryInteger)
7475
PROTOCOL(RangeReplaceableCollection)

lib/IRGen/GenMeta.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -5251,6 +5251,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
52515251
case KnownProtocolKind::Hashable:
52525252
case KnownProtocolKind::CaseIterable:
52535253
case KnownProtocolKind::Comparable:
5254+
case KnownProtocolKind::SIMD:
52545255
case KnownProtocolKind::SIMDScalar:
52555256
case KnownProtocolKind::BinaryInteger:
52565257
case KnownProtocolKind::ObjectiveCBridgeable:

lib/Sema/CSSolver.cpp

+16-2
Original file line numberDiff line numberDiff line change
@@ -1946,6 +1946,7 @@ void DisjunctionChoiceProducer::partitionGenericOperators(
19461946
SmallVector<unsigned, 4> concreteOverloads;
19471947
SmallVector<unsigned, 4> numericOverloads;
19481948
SmallVector<unsigned, 4> sequenceOverloads;
1949+
SmallVector<unsigned, 4> simdOverloads;
19491950
SmallVector<unsigned, 4> otherGenericOverloads;
19501951

19511952
auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, KnownProtocolKind kind) -> bool {
@@ -1967,7 +1968,10 @@ void DisjunctionChoiceProducer::partitionGenericOperators(
19671968
unsigned index = *iter;
19681969
auto *decl = Choices[index]->getOverloadChoice().getDecl();
19691970
auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl();
1970-
if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
1971+
1972+
if (isSIMDOperator(decl)) {
1973+
simdOverloads.push_back(index);
1974+
} else if (!decl->getInterfaceType()->is<GenericFunctionType>()) {
19711975
concreteOverloads.push_back(index);
19721976
} else if (refinesOrConformsTo(nominal, KnownProtocolKind::AdditiveArithmetic)) {
19731977
numericOverloads.push_back(index);
@@ -2024,11 +2028,20 @@ void DisjunctionChoiceProducer::partitionGenericOperators(
20242028
sequenceOverloads.clear();
20252029
break;
20262030
}
2031+
2032+
if (TypeChecker::conformsToKnownProtocol(
2033+
argType, KnownProtocolKind::SIMD,
2034+
CS.DC->getParentModule())) {
2035+
first = std::copy(simdOverloads.begin(), simdOverloads.end(), first);
2036+
simdOverloads.clear();
2037+
break;
2038+
}
20272039
}
20282040

20292041
first = std::copy(otherGenericOverloads.begin(), otherGenericOverloads.end(), first);
20302042
first = std::copy(numericOverloads.begin(), numericOverloads.end(), first);
20312043
first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first);
2044+
first = std::copy(simdOverloads.begin(), simdOverloads.end(), first);
20322045
}
20332046

20342047
void DisjunctionChoiceProducer::partitionDisjunction(
@@ -2113,7 +2126,8 @@ void DisjunctionChoiceProducer::partitionDisjunction(
21132126
}
21142127

21152128
// Partition SIMD operators.
2116-
if (isOperatorDisjunction(Disjunction)) {
2129+
if (isOperatorDisjunction(Disjunction) &&
2130+
!Choices[0]->getOverloadChoice().getName().getBaseIdentifier().isArithmeticOperator()) {
21172131
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
21182132
if (isSIMDOperator(constraint->getOverloadChoice().getDecl())) {
21192133
simdOperators.push_back(index);

0 commit comments

Comments
 (0)