Skip to content

Commit 731d903

Browse files
committed
Sema: Add ConstraintKind::ShapeOf
1 parent 971d581 commit 731d903

File tree

7 files changed

+128
-12
lines changed

7 files changed

+128
-12
lines changed

include/swift/Sema/Constraint.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ enum class ConstraintKind : char {
223223
///
224224
/// Binds the RHS type to a tuple of the params of a function typed LHS. Note
225225
/// this discards function parameter flags.
226-
BindTupleOfFunctionParams
226+
BindTupleOfFunctionParams,
227+
/// The first type is a type pack, and the second type is its reduced shape.
228+
ShapeOf,
227229
};
228230

229231
/// Classification of the different kinds of constraints.
@@ -705,6 +707,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
705707
case ConstraintKind::KeyPathApplication:
706708
case ConstraintKind::Defaultable:
707709
case ConstraintKind::BindTupleOfFunctionParams:
710+
case ConstraintKind::ShapeOf:
708711
return ConstraintClassification::TypeProperty;
709712

710713
case ConstraintKind::Disjunction:

include/swift/Sema/ConstraintSystem.h

+6
Original file line numberDiff line numberDiff line change
@@ -5613,6 +5613,12 @@ class ConstraintSystem {
56135613
ASTNode element, ContextualTypeInfo context, bool isDiscarded,
56145614
TypeMatchOptions flags, ConstraintLocatorBuilder locator);
56155615

5616+
/// Simplify a shape constraint by binding the reduced shape of the
5617+
/// left hand side to the right hand side.
5618+
SolutionKind simplifyShapeOfConstraint(
5619+
Type type1, Type type2, TypeMatchOptions flags,
5620+
ConstraintLocatorBuilder locator);
5621+
56165622
public: // FIXME: Public for use by static functions.
56175623
/// Simplify a conversion constraint with a fix applied to it.
56185624
SolutionKind simplifyFixConstraint(ConstraintFix *fix, Type type1, Type type2,

lib/Sema/CSBindings.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1397,6 +1397,7 @@ void PotentialBindings::infer(Constraint *constraint) {
13971397
case ConstraintKind::Conjunction:
13981398
case ConstraintKind::BindTupleOfFunctionParams:
13991399
case ConstraintKind::PackElementOf:
1400+
case ConstraintKind::ShapeOf:
14001401
// Constraints from which we can't do anything.
14011402
break;
14021403

lib/Sema/CSSimplify.cpp

+95-8
Original file line numberDiff line numberDiff line change
@@ -2425,6 +2425,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
24252425
case ConstraintKind::SyntacticElement:
24262426
case ConstraintKind::BindTupleOfFunctionParams:
24272427
case ConstraintKind::PackElementOf:
2428+
case ConstraintKind::ShapeOf:
24282429
llvm_unreachable("Bad constraint kind in matchTupleTypes()");
24292430
}
24302431

@@ -2597,6 +2598,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1,
25972598
case ConstraintKind::SyntacticElement:
25982599
case ConstraintKind::BindTupleOfFunctionParams:
25992600
case ConstraintKind::PackElementOf:
2601+
case ConstraintKind::ShapeOf:
26002602
return true;
26012603
}
26022604

@@ -3013,6 +3015,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
30133015
case ConstraintKind::SyntacticElement:
30143016
case ConstraintKind::BindTupleOfFunctionParams:
30153017
case ConstraintKind::PackElementOf:
3018+
case ConstraintKind::ShapeOf:
30163019
llvm_unreachable("Not a relational constraint");
30173020
}
30183021

@@ -6366,6 +6369,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
63666369
case ConstraintKind::SyntacticElement:
63676370
case ConstraintKind::BindTupleOfFunctionParams:
63686371
case ConstraintKind::PackElementOf:
6372+
case ConstraintKind::ShapeOf:
63696373
llvm_unreachable("Not a relational constraint");
63706374
}
63716375
}
@@ -7442,17 +7446,14 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifySubclassOfConstraint(
74427446
return SolutionKind::Solved;
74437447
}
74447448

7445-
auto formUnsolved = [&](bool activate = false) {
7449+
auto formUnsolved = [&]() {
74467450
// If we're supposed to generate constraints, do so.
74477451
if (flags.contains(TMF_GenerateConstraints)) {
7448-
auto *conformance = Constraint::create(
7452+
auto *subclassOf = Constraint::create(
74497453
*this, ConstraintKind::SubclassOf, type, classType,
74507454
getConstraintLocator(locator));
74517455

7452-
addUnsolvedConstraint(conformance);
7453-
if (activate)
7454-
activateConstraint(conformance);
7455-
7456+
addUnsolvedConstraint(subclassOf);
74567457
return SolutionKind::Solved;
74577458
}
74587459

@@ -12510,6 +12511,74 @@ ConstraintSystem::simplifyDynamicCallableApplicableFnConstraint(
1251012511
return SolutionKind::Solved;
1251112512
}
1251212513

12514+
static Type getReducedShape(Type type) {
12515+
// Pack archetypes know their reduced shape
12516+
if (auto *packArchetype = type->getAs<PackArchetypeType>())
12517+
return packArchetype->getShape();
12518+
12519+
// Reduced shape of pack is computed recursively
12520+
if (auto *packType = type->getAs<PackType>()) {
12521+
auto &ctx = type->getASTContext();
12522+
SmallVector<Type, 2> elts;
12523+
12524+
for (auto elt : packType->getElementTypes()) {
12525+
// T... => shape(T)...
12526+
if (auto *packExpansionType = elt->getAs<PackExpansionType>()) {
12527+
auto shapeType = getReducedShape(packExpansionType->getCountType());
12528+
assert(shapeType && "Should not end up here if pack type's shape "
12529+
"is still potentially unknown");
12530+
elts.push_back(PackExpansionType::get(shapeType, shapeType));
12531+
}
12532+
12533+
// Use () as a placeholder for scalar shape
12534+
elts.push_back(ctx.TheEmptyTupleType);
12535+
}
12536+
12537+
return PackType::get(ctx, elts);
12538+
}
12539+
12540+
// Getting the shape of any other type is an error.
12541+
return Type();
12542+
}
12543+
12544+
ConstraintSystem::SolutionKind ConstraintSystem::simplifyShapeOfConstraint(
12545+
Type type1, Type type2, TypeMatchOptions flags,
12546+
ConstraintLocatorBuilder locator) {
12547+
// Recursively replace all type variables with fixed bindings if
12548+
// possible.
12549+
type1 = simplifyType(type1, flags);
12550+
12551+
auto formUnsolved = [&]() {
12552+
// If we're supposed to generate constraints, do so.
12553+
if (flags.contains(TMF_GenerateConstraints)) {
12554+
auto *shapeOf = Constraint::create(
12555+
*this, ConstraintKind::ShapeOf, type1, type2,
12556+
getConstraintLocator(locator));
12557+
12558+
addUnsolvedConstraint(shapeOf);
12559+
return SolutionKind::Solved;
12560+
}
12561+
12562+
return SolutionKind::Unsolved;
12563+
};
12564+
12565+
// We can't compute a reduced shape if the input type still
12566+
// contains type variables that might bind to pack archetypes.
12567+
SmallPtrSet<TypeVariableType *, 2> typeVars;
12568+
type1->getTypeVariables(typeVars);
12569+
for (auto *typeVar : typeVars) {
12570+
if (typeVar->getImpl().canBindToPack())
12571+
return formUnsolved();
12572+
}
12573+
12574+
if (Type shape = getReducedShape(type1)) {
12575+
addConstraint(ConstraintKind::Bind, shape, type2, locator);
12576+
return SolutionKind::Solved;
12577+
}
12578+
12579+
return SolutionKind::Error;
12580+
}
12581+
1251312582
static llvm::PointerIntPair<Type, 3, unsigned>
1251412583
getBaseTypeForPointer(TypeBase *type) {
1251512584
unsigned unwrapCount = 0;
@@ -13847,6 +13916,9 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
1384713916
case ConstraintKind::PackElementOf:
1384813917
return simplifyPackElementOfConstraint(first, second, subflags, locator);
1384913918

13919+
case ConstraintKind::ShapeOf:
13920+
return simplifyShapeOfConstraint(first, second, subflags, locator);
13921+
1385013922
case ConstraintKind::ValueMember:
1385113923
case ConstraintKind::UnresolvedValueMember:
1385213924
case ConstraintKind::ValueWitness:
@@ -14010,8 +14082,18 @@ void ConstraintSystem::addConstraint(Requirement req,
1401014082
bool conformsToAnyObject = false;
1401114083
Optional<ConstraintKind> kind;
1401214084
switch (req.getKind()) {
14013-
case RequirementKind::SameShape:
14014-
llvm_unreachable("Same-shape requirement not supported here");
14085+
case RequirementKind::SameShape: {
14086+
auto type1 = req.getFirstType();
14087+
auto type2 = req.getSecondType();
14088+
14089+
// FIXME: Locator for diagnostics
14090+
auto typeVar = createTypeVariable(getConstraintLocator(locator),
14091+
TVO_CanBindToPack);
14092+
14093+
addConstraint(ConstraintKind::ShapeOf, type1, typeVar, locator);
14094+
addConstraint(ConstraintKind::ShapeOf, type2, typeVar, locator);
14095+
return;
14096+
}
1401514097

1401614098
case RequirementKind::Conformance:
1401714099
kind = ConstraintKind::ConformsTo;
@@ -14421,6 +14503,11 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
1442114503
return simplifyPackElementOfConstraint(
1442214504
constraint.getFirstType(), constraint.getSecondType(), /*flags*/None,
1442314505
constraint.getLocator());
14506+
14507+
case ConstraintKind::ShapeOf:
14508+
return simplifyShapeOfConstraint(
14509+
constraint.getFirstType(), constraint.getSecondType(), /*flags*/ None,
14510+
constraint.getLocator());
1442414511
}
1442514512

1442614513
llvm_unreachable("Unhandled ConstraintKind in switch.");

lib/Sema/Constraint.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
8080
case ConstraintKind::PropertyWrapper:
8181
case ConstraintKind::BindTupleOfFunctionParams:
8282
case ConstraintKind::PackElementOf:
83+
case ConstraintKind::ShapeOf:
8384
assert(!First.isNull());
8485
assert(!Second.isNull());
8586
break;
@@ -167,6 +168,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third,
167168
case ConstraintKind::SyntacticElement:
168169
case ConstraintKind::BindTupleOfFunctionParams:
169170
case ConstraintKind::PackElementOf:
171+
case ConstraintKind::ShapeOf:
170172
llvm_unreachable("Wrong constructor");
171173

172174
case ConstraintKind::KeyPath:
@@ -313,6 +315,7 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const {
313315
case ConstraintKind::PropertyWrapper:
314316
case ConstraintKind::BindTupleOfFunctionParams:
315317
case ConstraintKind::PackElementOf:
318+
case ConstraintKind::ShapeOf:
316319
return create(cs, getKind(), getFirstType(), getSecondType(), getLocator());
317320

318321
case ConstraintKind::ApplicableFunction:
@@ -553,6 +556,10 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm, unsigned inden
553556
Out << " element of pack expansion pattern ";
554557
break;
555558

559+
case ConstraintKind::ShapeOf:
560+
Out << " shape of ";
561+
break;
562+
556563
case ConstraintKind::Disjunction:
557564
llvm_unreachable("disjunction handled above");
558565
case ConstraintKind::Conjunction:
@@ -718,6 +725,7 @@ gatherReferencedTypeVars(Constraint *constraint,
718725
case ConstraintKind::PropertyWrapper:
719726
case ConstraintKind::BindTupleOfFunctionParams:
720727
case ConstraintKind::PackElementOf:
728+
case ConstraintKind::ShapeOf:
721729
constraint->getFirstType()->getTypeVariables(typeVars);
722730
constraint->getSecondType()->getTypeVariables(typeVars);
723731
break;

lib/Sema/ConstraintSystem.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -1773,8 +1773,6 @@ void ConstraintSystem::openGenericRequirement(
17731773

17741774
auto kind = req.getKind();
17751775
switch (kind) {
1776-
case RequirementKind::SameShape:
1777-
llvm_unreachable("Same-shape requirement not supported here");
17781776
case RequirementKind::Conformance: {
17791777
auto protoDecl = req.getProtocolDecl();
17801778
// Determine whether this is the protocol 'Self' constraint we should
@@ -1788,6 +1786,7 @@ void ConstraintSystem::openGenericRequirement(
17881786
}
17891787
case RequirementKind::Superclass:
17901788
case RequirementKind::SameType:
1789+
case RequirementKind::SameShape:
17911790
openedReq = Requirement(kind, openedFirst, substFn(req.getSecondType()));
17921791
break;
17931792
case RequirementKind::Layout:

test/Constraints/variadic_generic_constraints.swift

+13-1
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,16 @@ func takesParallelSequences<T..., U...>(t: T..., u: U...) where T: Sequence, U:
5151
takesParallelSequences() // ok
5252
takesParallelSequences(t: Array<Int>(), u: Set<Int>()) // ok
5353
takesParallelSequences(t: Array<String>(), Set<Int>(), u: Set<String>(), Array<Int>()) // ok
54-
takesParallelSequences(t: Array<String>(), Set<Int>(), u: Array<Int>(), Set<String>()) // expected-error {{global function 'takesParallelSequences(t:u:)' requires the types 'String' and 'Int' be equivalent}}
54+
takesParallelSequences(t: Array<String>(), Set<Int>(), u: Array<Int>(), Set<String>()) // expected-error {{global function 'takesParallelSequences(t:u:)' requires the types 'String' and 'Int' be equivalent}}
55+
56+
// Same-shape requirements
57+
58+
func zip<T..., U...>(t: T..., u: U...) -> ((T, U)...) {}
59+
60+
let _ = zip() // ok
61+
let _ = zip(t: 1, u: "hi") // ok
62+
let _ = zip(t: 1, 2, u: "hi", "hello") // ok
63+
let _ = zip(t: 1, 2, 3, u: "hi", "hello", "greetings") // ok
64+
65+
// FIXME: Bad diagnostic
66+
let _ = zip(t: 1, u: "hi", "hello", "greetings") // expected-error {{type of expression is ambiguous without more context}}

0 commit comments

Comments
 (0)