Skip to content

Commit b7860ea

Browse files
committed
[TypeChecker] Split for-in sequence into parsed and type-checked versions
1 parent 5f0dcb5 commit b7860ea

17 files changed

+64
-49
lines changed

include/swift/AST/Stmt.h

+9-5
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ class ForEachStmt : public LabeledStmt {
743743

744744
// Set by Sema:
745745
ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef();
746-
VarDecl *iteratorVar = nullptr;
746+
PatternBindingDecl *iteratorVar = nullptr;
747747
Expr *nextCall = nullptr;
748748
OpaqueValueExpr *elementExpr = nullptr;
749749
Expr *convertElementExpr = nullptr;
@@ -759,8 +759,8 @@ class ForEachStmt : public LabeledStmt {
759759
setPattern(Pat);
760760
}
761761

762-
void setIteratorVar(VarDecl *var) { iteratorVar = var; }
763-
VarDecl *getIteratorVar() const { return iteratorVar; }
762+
void setIteratorVar(PatternBindingDecl *var) { iteratorVar = var; }
763+
PatternBindingDecl *getIteratorVar() const { return iteratorVar; }
764764

765765
void setNextCall(Expr *next) { nextCall = next; }
766766
Expr *getNextCall() const { return nextCall; }
@@ -802,8 +802,12 @@ class ForEachStmt : public LabeledStmt {
802802
/// by this foreach loop, as it was written in the source code and
803803
/// subsequently type-checked. To determine the semantic behavior of this
804804
/// expression to extract a range, use \c getRangeInit().
805-
Expr *getSequence() const { return Sequence; }
806-
void setSequence(Expr *S) { Sequence = S; }
805+
Expr *getParsedSequence() const { return Sequence; }
806+
void setParsedSequence(Expr *S) { Sequence = S; }
807+
808+
/// Type-checked version of the sequence or nullptr if this statement
809+
/// yet to be type-checked.
810+
Expr *getTypeCheckedSequence() const;
807811

808812
/// getBody - Retrieve the body of the loop.
809813
BraceStmt *getBody() const { return Body; }

include/swift/Sema/ConstraintSystem.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ struct ForEachStmtInfo {
10111011
Type initType;
10121012

10131013
/// Implicit `$iterator = <sequence>.makeIterator()`
1014-
VarDecl *makeIteratorVar;
1014+
PatternBindingDecl *makeIteratorVar;
10151015

10161016
/// Implicit `$iterator.next()` call.
10171017
Expr *nextCall;
@@ -2434,7 +2434,7 @@ class SolutionApplicationTarget {
24342434
case Kind::forEachStmt:
24352435
auto *stmt = forEachStmt.stmt;
24362436
SourceLoc startLoc = stmt->getForLoc();
2437-
SourceLoc endLoc = stmt->getSequence()->getEndLoc();
2437+
SourceLoc endLoc = stmt->getParsedSequence()->getEndLoc();
24382438

24392439
if (auto *whereExpr = stmt->getWhere()) {
24402440
endLoc = whereExpr->getEndLoc();

lib/AST/ASTDumper.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1572,7 +1572,7 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
15721572
}
15731573
printRec(S->getPattern());
15741574
OS << '\n';
1575-
printRec(S->getSequence());
1575+
printRec(S->getParsedSequence());
15761576
OS << '\n';
15771577
if (S->getIteratorVar()) {
15781578
printRec(S->getIteratorVar());

lib/AST/ASTScopeCreation.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ void SwitchStmtScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
994994

995995
void ForEachStmtScope::expandAScopeThatDoesNotCreateANewInsertionPoint(
996996
ScopeCreator &scopeCreator) {
997-
scopeCreator.addToScopeTree(stmt->getSequence(), this);
997+
scopeCreator.addToScopeTree(stmt->getParsedSequence(), this);
998998

999999
// Add a child describing the scope of the pattern.
10001000
// In error cases such as:

lib/AST/ASTWalker.cpp

+22-17
Original file line numberDiff line numberDiff line change
@@ -1590,11 +1590,28 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) {
15901590

15911591
// The iterator decl is built directly on top of the sequence
15921592
// expression, so don't visit both.
1593-
if (Expr *Sequence = S->getSequence()) {
1594-
if ((Sequence = doIt(Sequence)))
1595-
S->setSequence(Sequence);
1596-
else
1597-
return nullptr;
1593+
//
1594+
// If for-in is already type-checked, the type-checked version
1595+
// of the sequence is going to be visited as part of `iteratorVar`.
1596+
if (S->getTypeCheckedSequence()) {
1597+
if (auto IteratorVar = S->getIteratorVar()) {
1598+
if (doIt(IteratorVar))
1599+
return nullptr;
1600+
}
1601+
1602+
if (auto NextCall = S->getNextCall()) {
1603+
if ((NextCall = doIt(NextCall)))
1604+
S->setNextCall(NextCall);
1605+
else
1606+
return nullptr;
1607+
}
1608+
} else {
1609+
if (Expr *Sequence = S->getParsedSequence()) {
1610+
if ((Sequence = doIt(Sequence)))
1611+
S->setParsedSequence(Sequence);
1612+
else
1613+
return nullptr;
1614+
}
15981615
}
15991616

16001617
if (Expr *Where = S->getWhere()) {
@@ -1611,18 +1628,6 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) {
16111628
return nullptr;
16121629
}
16131630

1614-
if (auto IteratorVar = S->getIteratorVar()) {
1615-
if (doIt(IteratorVar))
1616-
return nullptr;
1617-
}
1618-
1619-
if (auto NextCall = S->getNextCall()) {
1620-
if ((NextCall = doIt(NextCall)))
1621-
S->setNextCall(NextCall);
1622-
else
1623-
return nullptr;
1624-
}
1625-
16261631
if (Stmt *Body = S->getBody()) {
16271632
if ((Body = doIt(Body)))
16281633
S->setBody(cast<BraceStmt>(Body));

lib/AST/NameLookup.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -3110,7 +3110,7 @@ void FindLocalVal::visitForEachStmt(ForEachStmt *S) {
31103110
if (!isReferencePointInRange(S->getSourceRange()))
31113111
return;
31123112
visit(S->getBody());
3113-
if (!isReferencePointInRange(S->getSequence()->getSourceRange()))
3113+
if (!isReferencePointInRange(S->getParsedSequence()->getSourceRange()))
31143114
checkPattern(S->getPattern(), DeclVisibilityKind::LocalVariable);
31153115
}
31163116

lib/AST/Stmt.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ void ForEachStmt::setPattern(Pattern *p) {
301301
Pat->markOwnedByStatement(this);
302302
}
303303

304+
Expr *ForEachStmt::getTypeCheckedSequence() const {
305+
return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr;
306+
}
307+
304308
DoCatchStmt *DoCatchStmt::create(ASTContext &ctx, LabeledStmtInfo labelInfo,
305309
SourceLoc doLoc, Stmt *body,
306310
ArrayRef<CaseStmt *> catches,

lib/IDE/ExprContextAnalysis.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ class ExprContextAnalyzer {
10921092
break;
10931093
}
10941094
case StmtKind::ForEach:
1095-
if (auto SEQ = cast<ForEachStmt>(Parent)->getSequence()) {
1095+
if (auto SEQ = cast<ForEachStmt>(Parent)->getParsedSequence()) {
10961096
if (containsTarget(SEQ)) {
10971097
recordPossibleType(Context.getSequenceType());
10981098
}

lib/IDE/Formatting.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2246,7 +2246,7 @@ class FormatWalker : public ASTWalker {
22462246
if (Range.isValid() && overlapsTarget(Range))
22472247
return IndentContext {ForLoc, !OutdentChecker::hasOutdent(SM, P)};
22482248
}
2249-
if (auto *E = FS->getSequence()) {
2249+
if (auto *E = FS->getParsedSequence()) {
22502250
SourceRange Range = FS->getInLoc();
22512251
widenOrSet(Range, E->getSourceRange());
22522252
if (Range.isValid() && isTargetContext(Range)) {

lib/IDE/SyntaxModel.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -747,8 +747,8 @@ std::pair<bool, Stmt *> ModelASTWalker::walkToStmtPre(Stmt *S) {
747747
charSourceRangeFromSourceRange(SM, ElemRange));
748748
}
749749
}
750-
if (ForEachS->getSequence())
751-
addExprElem(SyntaxStructureElementKind::Expr, ForEachS->getSequence(),SN);
750+
if (auto *S = ForEachS->getParsedSequence())
751+
addExprElem(SyntaxStructureElementKind::Expr, S, SN);
752752
pushStructureNode(SN, S);
753753

754754
} else if (auto *WhileS = dyn_cast<WhileStmt>(S)) {

lib/SILGen/SILGenStmt.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1022,7 +1022,7 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
10221022
// Emit the 'iterator' variable that we'll be using for iteration.
10231023
LexicalScope OuterForScope(SGF, CleanupLocation(S));
10241024
{
1025-
SGF.emitPatternBinding(S->getIteratorVar()->getParentPatternBinding(),
1025+
SGF.emitPatternBinding(S->getIteratorVar(),
10261026
/*index=*/0);
10271027
}
10281028

@@ -1074,7 +1074,7 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) {
10741074
SILLocation loc = SILLocation(S);
10751075
RValue result = buildElementRValue(SGFContext(nextInit.get()));
10761076
if (!result.isInContext()) {
1077-
ArgumentSource(SILLocation(S->getSequence()),
1077+
ArgumentSource(SILLocation(S->getTypeCheckedSequence()),
10781078
std::move(result).ensurePlusOne(SGF, loc))
10791079
.forwardInto(SGF, nextInit.get());
10801080
}

lib/Sema/CSApply.cpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -8556,7 +8556,7 @@ static Optional<SolutionApplicationTarget> applySolutionToForEachStmt(
85568556
auto resultTarget = target;
85578557
auto &forEachStmtInfo = resultTarget.getForEachStmtInfo();
85588558
auto *stmt = target.getAsForEachStmt();
8559-
auto *parsedSequence = stmt->getSequence();
8559+
auto *parsedSequence = stmt->getParsedSequence();
85608560
bool isAsync = stmt->getAwaitLoc().isValid();
85618561

85628562
// Simplify the various types.
@@ -8572,21 +8572,22 @@ static Optional<SolutionApplicationTarget> applySolutionToForEachStmt(
85728572

85738573
// First, let's apply the solution to the sequence expression.
85748574
{
8575+
auto *makeIteratorVar = forEachStmtInfo.makeIteratorVar;
8576+
85758577
auto makeIteratorTarget =
8576-
*cs.getSolutionApplicationTarget(forEachStmtInfo.makeIteratorVar);
8578+
*cs.getSolutionApplicationTarget({makeIteratorVar, /*index=*/0});
85778579

85788580
auto rewrittenTarget = rewriteTarget(makeIteratorTarget);
85798581
if (!rewrittenTarget)
85808582
return None;
85818583

8582-
stmt->setSequence(rewrittenTarget->getAsExpr());
8583-
stmt->setIteratorVar(forEachStmtInfo.makeIteratorVar);
8584-
8585-
auto *PB = forEachStmtInfo.makeIteratorVar->getParentPatternBinding();
8584+
// Set type-checked initializer and mark it as such.
85868585
{
8587-
PB->setInit(/*index=*/0, stmt->getSequence());
8588-
PB->setInitializerChecked(/*index=*/0);
8586+
makeIteratorVar->setInit(/*index=*/0, rewrittenTarget->getAsExpr());
8587+
makeIteratorVar->setInitializerChecked(/*index=*/0);
85898588
}
8589+
8590+
stmt->setIteratorVar(makeIteratorVar);
85908591
}
85918592

85928593
// Now, `$iterator.next()` call.

lib/Sema/CSGen.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -3886,7 +3886,7 @@ generateForEachStmtConstraints(
38863886
auto forEachStmtInfo = target.getForEachStmtInfo();
38873887
ForEachStmt *stmt = target.getAsForEachStmt();
38883888
bool isAsync = stmt->getAwaitLoc().isValid();
3889-
auto *sequenceExpr = stmt->getSequence();
3889+
auto *sequenceExpr = stmt->getParsedSequence();
38903890
auto *dc = target.getDeclContext();
38913891
auto contextualLocator = cs.getConstraintLocator(
38923892
sequenceExpr, LocatorPathElt::ContextualType(CTP_ForEachSequence));
@@ -3940,7 +3940,7 @@ generateForEachStmtConstraints(
39403940
FreeTypeVariableBinding::Disallow))
39413941
return None;
39423942

3943-
forEachStmtInfo.makeIteratorVar = makeIteratorVar;
3943+
forEachStmtInfo.makeIteratorVar = PB;
39443944

39453945
// Type of sequence expression has to conform to Sequence protocol.
39463946
{
@@ -3951,7 +3951,7 @@ generateForEachStmtConstraints(
39513951
forEachStmtInfo.sequenceType = cs.getType(sequenceExpr);
39523952
}
39533953

3954-
cs.setSolutionApplicationTarget(makeIteratorVar, makeIteratorTarget);
3954+
cs.setSolutionApplicationTarget({PB, /*index=*/0}, makeIteratorTarget);
39553955
}
39563956

39573957
// Now, result type of `.makeIterator()` is used to form a call to

lib/Sema/MiscDiagnostics.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -3784,7 +3784,7 @@ static void checkStmtConditionTrailingClosure(ASTContext &ctx, const Stmt *S) {
37843784
} else if (auto SS = dyn_cast<SwitchStmt>(S)) {
37853785
checkStmtConditionTrailingClosure(ctx, SS->getSubjectExpr());
37863786
} else if (auto FES = dyn_cast<ForEachStmt>(S)) {
3787-
checkStmtConditionTrailingClosure(ctx, FES->getSequence());
3787+
checkStmtConditionTrailingClosure(ctx, FES->getParsedSequence());
37883788
checkStmtConditionTrailingClosure(ctx, FES->getWhere());
37893789
} else if (auto DCS = dyn_cast<DoCatchStmt>(S)) {
37903790
for (auto CS : DCS->getCatches())

lib/Sema/PCMacro.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class Instrumenter : InstrumenterBase {
224224

225225
// point at the for stmt, to look nice
226226
SourceLoc StartLoc = FES->getStartLoc();
227-
SourceLoc EndLoc = FES->getSequence()->getEndLoc();
227+
SourceLoc EndLoc = FES->getParsedSequence()->getEndLoc();
228228
// FIXME: get the 'end' of the for stmt
229229
// if (FD->getResultTypeRepr()) {
230230
// EndLoc = FD->getResultTypeSourceRange().End;

lib/Sema/PreCheckExpr.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -2126,7 +2126,7 @@ bool ConstraintSystem::preCheckTarget(SolutionApplicationTarget &target,
21262126
if (target.isForEachStmt()) {
21272127
auto *stmt = target.getAsForEachStmt();
21282128

2129-
auto *sequenceExpr = stmt->getSequence();
2129+
auto *sequenceExpr = stmt->getParsedSequence();
21302130
auto *whereExpr = stmt->getWhere();
21312131

21322132
hadErrors |= preCheckExpression(sequenceExpr, DC,
@@ -2141,7 +2141,7 @@ bool ConstraintSystem::preCheckTarget(SolutionApplicationTarget &target,
21412141

21422142
// Update sequence and where expressions to pre-checked versions.
21432143
if (!hadErrors) {
2144-
stmt->setSequence(sequenceExpr);
2144+
stmt->setParsedSequence(sequenceExpr);
21452145

21462146
if (whereExpr)
21472147
stmt->setWhere(whereExpr);

lib/Sema/TypeCheckConstraints.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ void constraints::performSyntacticDiagnosticsForTarget(
306306
auto *stmt = target.getAsForEachStmt();
307307

308308
// First emit diagnostics for the main expression.
309-
performSyntacticExprDiagnostics(stmt->getSequence(), dc, isExprStmt,
309+
performSyntacticExprDiagnostics(stmt->getTypeCheckedSequence(), dc,
310+
isExprStmt,
310311
disableExprAvailabilityChecking);
311312

312313
if (auto *whereExpr = stmt->getWhere())

0 commit comments

Comments
 (0)