Skip to content

Commit 1691f32

Browse files
committed
[CSGen] Type-check capture list together with closure body
Delay constraint generation for capture list until body of the associated closure is resolved. This means that we can unify capture checking with that of regular pattern bindings for multi-statement closures.
1 parent f853e8f commit 1691f32

File tree

3 files changed

+60
-31
lines changed

3 files changed

+60
-31
lines changed

Diff for: lib/Sema/CSGen.cpp

+8-17
Original file line numberDiff line numberDiff line change
@@ -2985,7 +2985,14 @@ namespace {
29852985
}
29862986
} collectVarRefs(CS);
29872987

2988-
closure->walk(collectVarRefs);
2988+
// Walk the capture list if this closure has one, because it could
2989+
// reference declarations from the outer closure.
2990+
if (auto *captureList =
2991+
getAsExpr<CaptureListExpr>(CS.getParentExpr(closure))) {
2992+
captureList->walk(collectVarRefs);
2993+
} else {
2994+
closure->walk(collectVarRefs);
2995+
}
29892996

29902997
auto inferredType = inferClosureType(closure);
29912998
if (!inferredType || inferredType->hasError())
@@ -2994,15 +3001,6 @@ namespace {
29943001
SmallVector<TypeVariableType *, 4> referencedVars{
29953002
collectVarRefs.varRefs.begin(), collectVarRefs.varRefs.end()};
29963003

2997-
if (auto *captureList =
2998-
getAsExpr<CaptureListExpr>(CS.getParentExpr(closure))) {
2999-
for (const auto &capture : captureList->getCaptureList()) {
3000-
if (auto *typeVar =
3001-
CS.getType(capture.getVar())->getAs<TypeVariableType>())
3002-
referencedVars.push_back(typeVar);
3003-
}
3004-
}
3005-
30063004
CS.addUnsolvedConstraint(Constraint::create(
30073005
CS, ConstraintKind::DefaultClosureType, closureType, inferredType,
30083006
locator, referencedVars));
@@ -4114,13 +4112,6 @@ namespace {
41144112
// Generate constraints for each of the entries in the capture list.
41154113
if (auto captureList = dyn_cast<CaptureListExpr>(expr)) {
41164114
TypeChecker::diagnoseDuplicateCaptureVars(captureList);
4117-
4118-
auto &CS = CG.getConstraintSystem();
4119-
for (const auto &capture : captureList->getCaptureList()) {
4120-
SyntacticElementTarget target(capture.PBD);
4121-
if (CS.generateConstraints(target))
4122-
return Action::Stop();
4123-
}
41244115
}
41254116

41264117
// Both multi- and single-statement closures now behave the same way

Diff for: lib/Sema/CSSyntacticElement.cpp

+40-2
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,30 @@ class SyntacticElementConstraintGenerator
10121012
void visitBraceStmt(BraceStmt *braceStmt) {
10131013
auto &ctx = cs.getASTContext();
10141014

1015+
CaptureListExpr *captureList = nullptr;
1016+
{
1017+
if (locator->directlyAt<ClosureExpr>()) {
1018+
auto *closure = castToExpr<ClosureExpr>(locator->getAnchor());
1019+
captureList = getAsExpr<CaptureListExpr>(cs.getParentExpr(closure));
1020+
}
1021+
}
1022+
10151023
if (context.isSingleExpressionClosure(cs)) {
1024+
// Generate constraints for the capture list first.
1025+
//
1026+
// TODO: This should be a conjunction connected to
1027+
// the closure body to make sure that each capture
1028+
// is solved in isolation.
1029+
if (captureList) {
1030+
for (const auto &capture : captureList->getCaptureList()) {
1031+
SyntacticElementTarget target(capture.PBD);
1032+
if (cs.generateConstraints(target)) {
1033+
hadError = true;
1034+
return;
1035+
}
1036+
}
1037+
}
1038+
10161039
for (auto node : braceStmt->getElements()) {
10171040
if (auto expr = node.dyn_cast<Expr *>()) {
10181041
auto generatedExpr = cs.generateConstraints(
@@ -1029,6 +1052,8 @@ class SyntacticElementConstraintGenerator
10291052
return;
10301053
}
10311054

1055+
SmallVector<ElementInfo, 4> elements;
1056+
10321057
// If this brace statement represents a body of an empty or
10331058
// multi-statement closure.
10341059
if (locator->directlyAt<ClosureExpr>()) {
@@ -1054,10 +1079,24 @@ class SyntacticElementConstraintGenerator
10541079
cs.getConstraintLocator(closure, ConstraintLocator::ClosureResult));
10551080
}
10561081

1082+
// If this multi-statement closure has captures, let's solve
1083+
// them first.
1084+
if (captureList) {
1085+
for (const auto &capture : captureList->getCaptureList())
1086+
visitPatternBinding(capture.PBD, elements);
1087+
}
1088+
10571089
// Let's not walk into the body if empty or multi-statement closure
10581090
// doesn't participate in inference.
1059-
if (!cs.participatesInInference(closure))
1091+
if (!cs.participatesInInference(closure)) {
1092+
// Although the body doesn't participate in inference we still
1093+
// want to type-check captures to make sure that the context
1094+
// is valid.
1095+
if (captureList)
1096+
createConjunction(cs, elements, locator);
1097+
10601098
return;
1099+
}
10611100
}
10621101

10631102
if (isChildOf(StmtKind::Case)) {
@@ -1070,7 +1109,6 @@ class SyntacticElementConstraintGenerator
10701109
}
10711110
}
10721111

1073-
SmallVector<ElementInfo, 4> elements;
10741112
for (auto element : braceStmt->getElements()) {
10751113
if (cs.isForCodeCompletion() &&
10761114
!cs.containsIDEInspectionTarget(element)) {

Diff for: unittests/Sema/ConstraintGenerationTests.cpp

+12-12
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ TEST_F(SemaTest, TestImplicitConditionalCastConstraintGeneration) {
116116
->isEqual(getStdlibType("Int")));
117117
}
118118

119-
TEST_F(SemaTest, TestCaptureListIsConnectedToTheClosure) {
119+
TEST_F(SemaTest, TestCaptureListIsNotOpenedEarly) {
120120
ConstraintSystem cs(DC, ConstraintSystemOptions());
121121

122122
DeclAttributes attrs;
@@ -160,19 +160,19 @@ TEST_F(SemaTest, TestCaptureListIsConnectedToTheClosure) {
160160
auto *processed = cs.generateConstraints(assign, DC);
161161
ASSERT_NE(processed, nullptr);
162162

163-
auto *closureType = cs.getType(closure)->castTo<TypeVariableType>();
164-
auto &CG = cs.getConstraintGraph();
163+
for (const auto &capture : captureList->getCaptureList()) {
164+
ASSERT_FALSE(cs.hasType(capture.getVar()));
165+
}
165166

166-
for (auto *constraint : CG[closureType].getConstraints()) {
167-
if (constraint->getKind() != ConstraintKind::DefaultClosureType)
168-
continue;
167+
auto *closureType = cs.getType(closure)->castTo<TypeVariableType>();
169168

170-
for (const auto &capture : captureList->getCaptureList()) {
171-
auto *capturedVar =
172-
cs.getType(capture.getVar())->castTo<TypeVariableType>();
169+
ASTExtInfo extInfo;
170+
ASSERT_TRUE(cs.resolveClosure(
171+
closureType,
172+
FunctionType::get(/*params*/ {}, Context.TheEmptyTupleType, extInfo),
173+
cs.getConstraintLocator(closure)));
173174

174-
ASSERT_TRUE(
175-
llvm::is_contained(constraint->getTypeVariables(), capturedVar));
176-
}
175+
for (const auto &capture : captureList->getCaptureList()) {
176+
ASSERT_TRUE(cs.hasType(capture.getVar()));
177177
}
178178
}

0 commit comments

Comments
 (0)