Skip to content

Commit 23ace7d

Browse files
committed
[ConstraintSystem] Store declaration context in which application occurs in ApplicableFunction constraint
This is required because `ApplicableFunction` constraint can inject member reference constraints that require a declaration context. For example, `_ = { Double(...) }` would now produce a disjunction for `Double.init` where overload choice declaration contexts point to the closure instead of the enclosing context. This addresses a long-standing FIXME in `simplifyApplicableFnConstraint` and helps with disjunction optimizer because its correctness depends on correct identification of declaration contexts where applications happen.
1 parent d24f22e commit 23ace7d

File tree

5 files changed

+185
-35
lines changed

5 files changed

+185
-35
lines changed

include/swift/Sema/Constraint.h

+42-5
Original file line numberDiff line numberDiff line change
@@ -380,10 +380,6 @@ class Constraint final : public llvm::ilist_node<Constraint>,
380380
/// The kind of function reference, for member references.
381381
unsigned TheFunctionRefInfo : 3;
382382

383-
/// The trailing closure matching for an applicable function constraint,
384-
/// if any. 0 = None, 1 = Forward, 2 = Backward.
385-
unsigned trailingClosureMatching : 2;
386-
387383
union {
388384
struct {
389385
/// The first type.
@@ -439,6 +435,21 @@ class Constraint final : public llvm::ilist_node<Constraint>,
439435
/// Identifies whether result of this node is unused.
440436
bool IsDiscarded;
441437
} SyntacticElement;
438+
439+
struct {
440+
/// The function type that is being applied where parameters
441+
/// represent argument types passed to callee and result type
442+
/// represents result type of the application.
443+
FunctionType *AppliedFn;
444+
/// The type being called, primarily a function type, but could
445+
/// be a metatype, a tuple or a nominal type.
446+
Type Callee;
447+
/// The trailing closure matching for an applicable function constraint,
448+
/// if any. 0 = None, 1 = Forward, 2 = Backward.
449+
unsigned TrailingClosureMatching : 2;
450+
/// The declaration context in which the application appears.
451+
DeclContext *UseDC;
452+
} Apply;
442453
};
443454

444455
/// The locator that describes where in the expression this
@@ -495,6 +506,11 @@ class Constraint final : public llvm::ilist_node<Constraint>,
495506
ConstraintLocator *locator,
496507
SmallPtrSetImpl<TypeVariableType *> &typeVars);
497508

509+
Constraint(FunctionType *appliedFn, Type calleeType,
510+
unsigned trailingClosureMatching, DeclContext *useDC,
511+
ConstraintLocator *locator,
512+
SmallPtrSetImpl<TypeVariableType *> &typeVars);
513+
498514
/// Retrieve the type variables buffer, for internal mutation.
499515
MutableArrayRef<TypeVariableType *> getTypeVariablesBuffer() {
500516
return { getTrailingObjects<TypeVariableType *>(), NumTypeVariables };
@@ -580,7 +596,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
580596

581597
/// Create a new Applicable Function constraint.
582598
static Constraint *createApplicableFunction(
583-
ConstraintSystem &cs, Type argumentFnType, Type calleeType,
599+
ConstraintSystem &cs, FunctionType *argumentFnType, Type calleeType,
584600
std::optional<TrailingClosureMatching> trailingClosureMatching,
585601
DeclContext *useDC, ConstraintLocator *locator);
586602

@@ -739,6 +755,9 @@ class Constraint final : public llvm::ilist_node<Constraint>,
739755
case ConstraintKind::SyntacticElement:
740756
llvm_unreachable("closure body element constraint has no type operands");
741757

758+
case ConstraintKind::ApplicableFunction:
759+
return Apply.AppliedFn;
760+
742761
default:
743762
return Types.First;
744763
}
@@ -758,6 +777,9 @@ class Constraint final : public llvm::ilist_node<Constraint>,
758777
case ConstraintKind::ValueWitness:
759778
return Member.Second;
760779

780+
case ConstraintKind::ApplicableFunction:
781+
return Apply.Callee;
782+
761783
default:
762784
return Types.Second;
763785
}
@@ -851,6 +873,21 @@ class Constraint final : public llvm::ilist_node<Constraint>,
851873
return Member.UseDC;
852874
}
853875

876+
FunctionType *getAppliedFunctionType() const {
877+
assert(Kind == ConstraintKind::ApplicableFunction);
878+
return Apply.AppliedFn;
879+
}
880+
881+
Type getCalleeType() const {
882+
assert(Kind == ConstraintKind::ApplicableFunction);
883+
return Apply.Callee;
884+
}
885+
886+
DeclContext *getApplicationDC() const {
887+
assert(Kind == ConstraintKind::ApplicableFunction);
888+
return Apply.UseDC;
889+
}
890+
854891
ASTNode getSyntacticElement() const {
855892
assert(Kind == ConstraintKind::SyntacticElement);
856893
return SyntacticElement.Element;

include/swift/Sema/ConstraintSystem.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -4965,7 +4965,7 @@ class ConstraintSystem {
49654965

49664966
/// Attempt to simplify the ApplicableFunction constraint.
49674967
SolutionKind simplifyApplicableFnConstraint(
4968-
Type type1, Type type2,
4968+
FunctionType *appliedFn, Type calleeTy,
49694969
std::optional<TrailingClosureMatching> trailingClosureMatching,
49704970
DeclContext *useDC,
49714971
TypeMatchOptions flags, ConstraintLocatorBuilder locator);

lib/Sema/CSSimplify.cpp

+8-12
Original file line numberDiff line numberDiff line change
@@ -13107,16 +13107,12 @@ createImplicitRootForCallAsFunction(ConstraintSystem &cs, Type refType,
1310713107
}
1310813108

1310913109
ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
13110-
Type type1, Type type2,
13110+
FunctionType *func1, Type type2,
1311113111
std::optional<TrailingClosureMatching> trailingClosureMatching,
1311213112
DeclContext *useDC,
1311313113
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
1311413114
auto &ctx = getASTContext();
1311513115

13116-
// By construction, the left hand side is a type that looks like the
13117-
// following: $T1 -> $T2.
13118-
auto func1 = type1->castTo<FunctionType>();
13119-
1312013116
// Before stripping lvalue-ness and optional types, save the original second
1312113117
// type for handling `func callAsFunction` and `@dynamicCallable`
1312213118
// applications. This supports the following cases:
@@ -13173,7 +13169,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1317313169
auto formUnsolved = [&](bool activate = false) {
1317413170
if (flags.contains(TMF_GenerateConstraints)) {
1317513171
auto *application = Constraint::createApplicableFunction(
13176-
*this, type1, type2, trailingClosureMatching, useDC,
13172+
*this, func1, type2, trailingClosureMatching, useDC,
1317713173
getConstraintLocator(locator));
1317813174

1317913175
addUnsolvedConstraint(application);
@@ -13203,7 +13199,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1320313199
// If the types are obviously equivalent, we're done. This optimization
1320413200
// is not valid for operators though, where an inout parameter does not
1320513201
// have an explicit inout argument.
13206-
if (type1.getPointer() == desugar2) {
13202+
if (func1 == desugar2) {
1320713203
// Note that this could throw.
1320813204
recordPotentialThrowSite(
1320913205
PotentialThrowSite::Application, Type(desugar2), outerLocator);
@@ -13370,10 +13366,10 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1337013366

1337113367
auto applyLocator = getConstraintLocator(locator);
1337213368
auto forwardConstraint = Constraint::createApplicableFunction(
13373-
*this, type1, type2, TrailingClosureMatching::Forward, useDC,
13369+
*this, func1, type2, TrailingClosureMatching::Forward, useDC,
1337413370
applyLocator);
1337513371
auto backwardConstraint = Constraint::createApplicableFunction(
13376-
*this, type1, type2, TrailingClosureMatching::Backward, useDC,
13372+
*this, func1, type2, TrailingClosureMatching::Backward, useDC,
1337713373
applyLocator);
1337813374
addDisjunctionConstraint({forwardConstraint, backwardConstraint},
1337913375
applyLocator);
@@ -13472,7 +13468,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1347213468

1347313469
// Handle applications of @dynamicCallable types.
1347413470
auto result = simplifyDynamicCallableApplicableFnConstraint(
13475-
type1, origType2, subflags, locator);
13471+
func1, origType2, subflags, locator);
1347613472

1347713473
if (shouldAttemptFixes() && result == SolutionKind::Error) {
1347813474
// Skip this fix if the type is not yet resolved or
@@ -16204,9 +16200,9 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
1620416200

1620516201
case ConstraintKind::ApplicableFunction:
1620616202
return simplifyApplicableFnConstraint(
16207-
constraint.getFirstType(), constraint.getSecondType(),
16203+
constraint.getAppliedFunctionType(), constraint.getCalleeType(),
1620816204
constraint.getTrailingClosureMatching(),
16209-
/*FIXME*/DC, /*flags=*/std::nullopt,
16205+
constraint.getApplicationDC(), /*flags=*/std::nullopt,
1621016206
constraint.getLocator());
1621116207

1621216208
case ConstraintKind::DynamicCallableApplicableFunction:

lib/Sema/Constraint.cpp

+39-17
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,9 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
8787
assert(!First.isNull());
8888
assert(!Second.isNull());
8989
break;
90-
case ConstraintKind::ApplicableFunction:
9190
case ConstraintKind::DynamicCallableApplicableFunction:
9291
assert(First->is<FunctionType>()
9392
&& "The left-hand side type should be a function type");
94-
trailingClosureMatching = 0;
9593
break;
9694

9795
case ConstraintKind::ValueMember:
@@ -120,6 +118,10 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
120118

121119
case ConstraintKind::SyntacticElement:
122120
llvm_unreachable("Syntactic element constraint should use create()");
121+
122+
case ConstraintKind::ApplicableFunction:
123+
llvm_unreachable(
124+
"Application constraint should use create()");
123125
}
124126

125127
std::uninitialized_copy(typeVars.begin(), typeVars.end(),
@@ -287,6 +289,27 @@ Constraint::Constraint(ASTNode node, ContextualTypeInfo context,
287289
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
288290
}
289291

292+
Constraint::Constraint(FunctionType *appliedFn, Type calleeType,
293+
unsigned trailingClosureMatching, DeclContext *useDC,
294+
ConstraintLocator *locator,
295+
SmallPtrSetImpl<TypeVariableType *> &typeVars)
296+
: Kind(ConstraintKind::ApplicableFunction), HasFix(false),
297+
HasRestriction(false), IsActive(false), IsDisabled(false),
298+
IsDisabledForPerformance(false), RememberChoice(false), IsFavored(false),
299+
IsIsolated(false), NumTypeVariables(typeVars.size()), Locator(locator) {
300+
assert(appliedFn);
301+
assert(calleeType);
302+
assert(trailingClosureMatching >= 0 && trailingClosureMatching <= 2);
303+
assert(useDC);
304+
305+
Apply.AppliedFn = appliedFn;
306+
Apply.Callee = calleeType;
307+
Apply.TrailingClosureMatching = trailingClosureMatching;
308+
Apply.UseDC = useDC;
309+
310+
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
311+
}
312+
290313
ProtocolDecl *Constraint::getProtocol() const {
291314
assert((Kind == ConstraintKind::ConformsTo ||
292315
Kind == ConstraintKind::LiteralConformsTo ||
@@ -986,7 +1009,7 @@ Constraint *Constraint::createConjunction(
9861009
}
9871010

9881011
Constraint *Constraint::createApplicableFunction(
989-
ConstraintSystem &cs, Type argumentFnType, Type calleeType,
1012+
ConstraintSystem &cs, FunctionType *argumentFnType, Type calleeType,
9901013
std::optional<TrailingClosureMatching> trailingClosureMatching,
9911014
DeclContext *useDC, ConstraintLocator *locator) {
9921015
// Collect type variables.
@@ -996,30 +1019,29 @@ Constraint *Constraint::createApplicableFunction(
9961019
if (calleeType->hasTypeVariable())
9971020
calleeType->getTypeVariables(typeVars);
9981021

999-
// Create the constraint.
1000-
auto size =
1001-
totalSizeToAlloc<TypeVariableType *, ConstraintFix *, OverloadChoice>(
1002-
typeVars.size(), /*hasFix=*/0, /*hasOverloadChoice=*/0);
1003-
void *mem = cs.getAllocator().Allocate(size, alignof(Constraint));
1004-
auto constraint = new (mem) Constraint(
1005-
ConstraintKind::ApplicableFunction, argumentFnType, calleeType, locator,
1006-
typeVars);
1007-
10081022
// Encode the trailing closure matching.
1023+
unsigned rawTrailingClosureMatching = 0;
10091024
if (trailingClosureMatching) {
10101025
switch (*trailingClosureMatching) {
10111026
case TrailingClosureMatching::Forward:
1012-
constraint->trailingClosureMatching = 1;
1027+
rawTrailingClosureMatching = 1;
10131028
break;
10141029

10151030
case TrailingClosureMatching::Backward:
1016-
constraint->trailingClosureMatching = 2;
1031+
rawTrailingClosureMatching = 2;
10171032
break;
10181033
}
1019-
} else {
1020-
constraint->trailingClosureMatching = 0;
10211034
}
10221035

1036+
// Create the constraint.
1037+
auto size =
1038+
totalSizeToAlloc<TypeVariableType *, ConstraintFix *, OverloadChoice>(
1039+
typeVars.size(), /*hasFix=*/0, /*hasOverloadChoice=*/0);
1040+
void *mem = cs.getAllocator().Allocate(size, alignof(Constraint));
1041+
auto constraint = new (mem)
1042+
Constraint(argumentFnType, calleeType, rawTrailingClosureMatching, useDC,
1043+
locator, typeVars);
1044+
10231045
return constraint;
10241046
}
10251047

@@ -1051,7 +1073,7 @@ Constraint *Constraint::createSyntacticElement(ConstraintSystem &cs,
10511073
std::optional<TrailingClosureMatching>
10521074
Constraint::getTrailingClosureMatching() const {
10531075
assert(Kind == ConstraintKind::ApplicableFunction);
1054-
switch (trailingClosureMatching) {
1076+
switch (Apply.TrailingClosureMatching) {
10551077
case 0:
10561078
return std::nullopt;
10571079
case 1: return TrailingClosureMatching::Forward;

unittests/Sema/ConstraintSimplificationTests.cpp

+95
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,98 @@ TEST_F(SemaTest, TestClosureInferenceFromOptionalContext) {
127127
ASSERT_TRUE(cs.simplifyType(paramTy)->isEqual(getStdlibType("Int")));
128128
ASSERT_TRUE(cs.simplifyType(resultTy)->isEqual(Context.TheEmptyTupleType));
129129
}
130+
131+
/// Emulates code like this:
132+
///
133+
/// func test(_: (Int) -> Void) {}
134+
///
135+
/// test { Double($0) }
136+
///
137+
/// To make sure that constructor application sets correct
138+
/// declaration context for implicit `.init` member.
139+
TEST_F(SemaTest, TestInitializerUseDCIsSetCorrectlyInClosure) {
140+
ConstraintSystem cs(DC, ConstraintSystemOptions());
141+
142+
DeclAttributes closureAttrs;
143+
144+
// Anonymous closure parameter
145+
auto paramName = Context.getIdentifier("0");
146+
147+
auto *paramDecl =
148+
new (Context) ParamDecl(/*specifierLoc=*/SourceLoc(),
149+
/*argumentNameLoc=*/SourceLoc(), paramName,
150+
/*parameterNameLoc=*/SourceLoc(), paramName, DC);
151+
152+
paramDecl->setSpecifier(ParamSpecifier::Default);
153+
154+
auto *closure = new (Context) ClosureExpr(
155+
closureAttrs,
156+
/*bracketRange=*/SourceRange(),
157+
/*capturedSelfDecl=*/nullptr, ParameterList::create(Context, {paramDecl}),
158+
/*asyncLoc=*/SourceLoc(),
159+
/*throwsLoc=*/SourceLoc(),
160+
/*thrownType=*/nullptr,
161+
/*arrowLoc=*/SourceLoc(),
162+
/*inLoc=*/SourceLoc(),
163+
/*explicitResultType=*/nullptr,
164+
/*parent=*/DC);
165+
closure->setDiscriminator(0);
166+
167+
closure->setImplicit();
168+
169+
// Double($0)
170+
auto initCall = CallExpr::createImplicit(
171+
Context, TypeExpr::createImplicit(getStdlibType("Double"), Context),
172+
ArgumentList::forImplicitUnlabeled(
173+
Context, {new (Context) DeclRefExpr(ConcreteDeclRef(paramDecl),
174+
/*Loc*/ DeclNameLoc(),
175+
/*Implicit=*/true)}));
176+
177+
closure->setBody(BraceStmt::createImplicit(Context, {initCall}));
178+
179+
auto *closureLoc = cs.getConstraintLocator(closure);
180+
181+
auto *paramTy = cs.createTypeVariable(
182+
cs.getConstraintLocator(closure, LocatorPathElt::TupleElement(0)),
183+
/*options=*/TVO_CanBindToInOut);
184+
185+
auto *resultTy = cs.createTypeVariable(
186+
cs.getConstraintLocator(closure, ConstraintLocator::ClosureResult),
187+
/*options=*/0);
188+
189+
auto extInfo = FunctionType::ExtInfo();
190+
191+
auto defaultTy = FunctionType::get({FunctionType::Param(paramTy, paramName)},
192+
resultTy, extInfo);
193+
194+
cs.setClosureType(closure, defaultTy);
195+
196+
auto *closureTy = cs.createTypeVariable(closureLoc, /*options=*/0);
197+
cs.setType(closure, closureTy);
198+
199+
cs.addUnsolvedConstraint(Constraint::create(
200+
cs, ConstraintKind::FallbackType, closureTy, defaultTy,
201+
cs.getConstraintLocator(closure), /*referencedVars=*/{}));
202+
203+
auto contextualTy =
204+
FunctionType::get({FunctionType::Param(getStdlibType("Int"))},
205+
Context.TheEmptyTupleType, extInfo);
206+
207+
cs.resolveClosure(closureTy, contextualTy, closureLoc);
208+
209+
auto &graph = cs.getConstraintGraph();
210+
211+
for (const auto &component :
212+
graph.computeConnectedComponents(cs.getTypeVariables())) {
213+
for (auto *constraint : component.getConstraints()) {
214+
if (constraint->getKind() != ConstraintKind::Disjunction)
215+
continue;
216+
217+
ASSERT_TRUE(constraint->getLocator()
218+
->isLastElement<LocatorPathElt::ConstructorMember>());
219+
220+
for (auto *choice : constraint->getNestedConstraints())
221+
ASSERT_EQ(choice->getOverloadUseDC(), closure);
222+
}
223+
}
224+
}

0 commit comments

Comments
 (0)