Skip to content

Commit 066f253

Browse files
authored
Merge pull request #70196 from simanerush/nested-pack-iteration
[SE-0408] Enable nested iteration
2 parents 917eacb + 0b167b5 commit 066f253

11 files changed

+86
-15
lines changed

include/swift/Sema/ConstraintSystem.h

+9
Original file line numberDiff line numberDiff line change
@@ -1543,6 +1543,10 @@ class Solution {
15431543
llvm::MapVector<PackElementExpr *, PackExpansionExpr *>
15441544
PackEnvironments;
15451545

1546+
/// The outer pack element generic environment to use when dealing with nested
1547+
/// pack iteration (see \c getPackElementEnvironment).
1548+
llvm::SmallVector<GenericEnvironment *> PackElementGenericEnvironments;
1549+
15461550
/// The locators of \c Defaultable constraints whose defaults were used.
15471551
llvm::DenseSet<ConstraintLocator *> DefaultedConstraints;
15481552

@@ -2344,6 +2348,8 @@ class ConstraintSystem {
23442348
llvm::SmallMapVector<PackElementExpr *, PackExpansionExpr *, 2>
23452349
PackEnvironments;
23462350

2351+
llvm::SmallVector<GenericEnvironment *, 4> PackElementGenericEnvironments;
2352+
23472353
/// The set of functions that have been transformed by a result builder.
23482354
llvm::MapVector<AnyFunctionRef, AppliedBuilderTransform>
23492355
resultBuilderTransformed;
@@ -2833,6 +2839,9 @@ class ConstraintSystem {
28332839
/// The length of \c PackEnvironments.
28342840
unsigned numPackEnvironments;
28352841

2842+
/// The length of \c PackElementGenericEnvironments.
2843+
unsigned numPackElementGenericEnvironments;
2844+
28362845
/// The length of \c DefaultedConstraints.
28372846
unsigned numDefaultedConstraints;
28382847

include/swift/Sema/SyntacticElementTarget.h

+13-4
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class SyntacticElementTarget {
160160
DeclContext *dc;
161161
Pattern *pattern;
162162
bool ignoreWhereClause;
163+
GenericEnvironment *packElementEnv;
163164
ForEachStmtInfo info;
164165
} forEachStmt;
165166

@@ -239,11 +240,13 @@ class SyntacticElementTarget {
239240
}
240241

241242
SyntacticElementTarget(ForEachStmt *stmt, DeclContext *dc,
242-
bool ignoreWhereClause)
243+
bool ignoreWhereClause,
244+
GenericEnvironment *packElementEnv)
243245
: kind(Kind::forEachStmt) {
244246
forEachStmt.stmt = stmt;
245247
forEachStmt.dc = dc;
246248
forEachStmt.ignoreWhereClause = ignoreWhereClause;
249+
forEachStmt.packElementEnv = packElementEnv;
247250
}
248251

249252
/// Form a target for the initialization of a pattern from an expression.
@@ -259,9 +262,10 @@ class SyntacticElementTarget {
259262
unsigned patternBindingIndex, bool bindPatternVarsOneWay);
260263

261264
/// Form a target for a for-in loop.
262-
static SyntacticElementTarget forForEachStmt(ForEachStmt *stmt,
263-
DeclContext *dc,
264-
bool ignoreWhereClause = false);
265+
static SyntacticElementTarget
266+
forForEachStmt(ForEachStmt *stmt, DeclContext *dc,
267+
bool ignoreWhereClause = false,
268+
GenericEnvironment *packElementEnv = nullptr);
265269

266270
/// Form a target for a property with an attached property wrapper that is
267271
/// initialized out-of-line.
@@ -536,6 +540,11 @@ class SyntacticElementTarget {
536540
return forEachStmt.ignoreWhereClause;
537541
}
538542

543+
GenericEnvironment *getPackElementEnv() const {
544+
assert(isForEachStmt());
545+
return forEachStmt.packElementEnv;
546+
}
547+
539548
const ForEachStmtInfo &getForEachStmtInfo() const {
540549
assert(isForEachStmt());
541550
return forEachStmt.info;

lib/AST/GenericEnvironment.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -734,10 +734,11 @@ GenericEnvironment::mapElementTypeIntoPackContext(Type type) const {
734734

735735
type = type->mapTypeOutOfContext();
736736

737+
auto interfaceType = element->getInterfaceType();
738+
737739
llvm::SmallDenseMap<GenericParamKey, GenericTypeParamType *>
738740
packParamForElement;
739-
auto elementDepth =
740-
sig.getInnermostGenericParams().front()->getDepth() + 1;
741+
auto elementDepth = interfaceType->getRootGenericParam()->getDepth();
741742

742743
for (auto *genericParam : sig.getGenericParams()) {
743744
if (!genericParam->isParameterPack())
@@ -792,6 +793,8 @@ Type BuildForwardingSubstitutions::operator()(SubstitutableType *type) const {
792793
auto param = type->castTo<GenericTypeParamType>();
793794
if (!param->isParameterPack())
794795
return resultType;
796+
if (resultType->is<PackType>())
797+
return resultType;
795798
return PackType::getSingletonPackExpansion(resultType);
796799
}
797800
return Type();

lib/Sema/CSGen.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -4929,6 +4929,12 @@ bool ConstraintSystem::generateConstraints(
49294929
}
49304930

49314931
case SyntacticElementTarget::Kind::forEachStmt: {
4932+
4933+
// Cache the outer generic environment, if it exists.
4934+
if (target.getPackElementEnv()) {
4935+
PackElementGenericEnvironments.push_back(target.getPackElementEnv());
4936+
}
4937+
49324938
// For a for-each statement, generate constraints for the pattern, where
49334939
// clause, and sequence traversal.
49344940
auto resultTarget = generateForEachStmtConstraints(*this, target);

lib/Sema/CSSolver.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ Solution ConstraintSystem::finalize() {
247247
for (const auto &packEnv : PackEnvironments)
248248
solution.PackEnvironments.insert(packEnv);
249249

250+
for (const auto &packEltGenericEnv : PackElementGenericEnvironments)
251+
solution.PackElementGenericEnvironments.push_back(packEltGenericEnv);
252+
250253
return solution;
251254
}
252255

@@ -316,6 +319,12 @@ void ConstraintSystem::applySolution(const Solution &solution) {
316319
PackEnvironments.insert(packEnvironment);
317320
}
318321

322+
// Register the solutions's pack element generic environments.
323+
for (auto &packElementGenericEnvironment :
324+
solution.PackElementGenericEnvironments) {
325+
PackElementGenericEnvironments.push_back(packElementGenericEnvironment);
326+
}
327+
319328
// Register the defaulted type variables.
320329
DefaultedConstraints.insert(solution.DefaultedConstraints.begin(),
321330
solution.DefaultedConstraints.end());
@@ -647,6 +656,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
647656
numOpenedPackExpansionTypes = cs.OpenedPackExpansionTypes.size();
648657
numPackExpansionEnvironments = cs.PackExpansionEnvironments.size();
649658
numPackEnvironments = cs.PackEnvironments.size();
659+
numPackElementGenericEnvironments = cs.PackElementGenericEnvironments.size();
650660
numDefaultedConstraints = cs.DefaultedConstraints.size();
651661
numAddedNodeTypes = cs.addedNodeTypes.size();
652662
numAddedKeyPathComponentTypes = cs.addedKeyPathComponentTypes.size();
@@ -736,6 +746,10 @@ ConstraintSystem::SolverScope::~SolverScope() {
736746
// Remove any pack environments.
737747
truncate(cs.PackEnvironments, numPackEnvironments);
738748

749+
// Remove any pack element generic environments.
750+
truncate(cs.PackElementGenericEnvironments,
751+
numPackElementGenericEnvironments);
752+
739753
// Remove any defaulted type variables.
740754
truncate(cs.DefaultedConstraints, numDefaultedConstraints);
741755

lib/Sema/ConstraintSystem.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -792,9 +792,11 @@ ConstraintSystem::getPackElementEnvironment(ConstraintLocator *locator,
792792
shapeClass->mapTypeOutOfContext()->getCanonicalType());
793793

794794
auto &ctx = getASTContext();
795+
auto *contextEnv = PackElementGenericEnvironments.empty()
796+
? DC->getGenericEnvironmentOfContext()
797+
: PackElementGenericEnvironments.back();
795798
auto elementSig = ctx.getOpenedElementSignature(
796-
DC->getGenericSignatureOfContext().getCanonicalSignature(), shapeParam);
797-
auto *contextEnv = DC->getGenericEnvironmentOfContext();
799+
contextEnv->getGenericSignature().getCanonicalSignature(), shapeParam);
798800
auto contextSubs = contextEnv->getForwardingSubstitutionMap();
799801
return GenericEnvironment::forOpenedElement(elementSig, uuidAndShape.first,
800802
shapeParam, contextSubs);
@@ -4413,6 +4415,7 @@ size_t Solution::getTotalMemory() const {
44134415
OpenedPackExpansionTypes.getMemorySize() +
44144416
PackExpansionEnvironments.getMemorySize() +
44154417
size_in_bytes(PackEnvironments) +
4418+
PackElementGenericEnvironments.size() +
44164419
(DefaultedConstraints.size() * sizeof(void *)) +
44174420
ImplicitCallAsFunctionRoots.getMemorySize() +
44184421
nodeTypes.getMemorySize() +

lib/Sema/SyntacticElementTarget.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,9 @@ SyntacticElementTarget SyntacticElementTarget::forInitialization(
181181

182182
SyntacticElementTarget
183183
SyntacticElementTarget::forForEachStmt(ForEachStmt *stmt, DeclContext *dc,
184-
bool ignoreWhereClause) {
185-
SyntacticElementTarget target(stmt, dc, ignoreWhereClause);
184+
bool ignoreWhereClause,
185+
GenericEnvironment *packElementEnv) {
186+
SyntacticElementTarget target(stmt, dc, ignoreWhereClause, packElementEnv);
186187
return target;
187188
}
188189

lib/Sema/TypeCheckConstraints.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,8 @@ bool TypeChecker::typeCheckPatternBinding(PatternBindingDecl *PBD,
906906
return hadError;
907907
}
908908

909-
bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
909+
bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt,
910+
GenericEnvironment *packElementEnv) {
910911
auto &Context = dc->getASTContext();
911912
FrontendStatsTracer statsTracer(Context.Stats, "typecheck-for-each", stmt);
912913
PrettyStackTraceStmt stackTrace(Context, "type-checking-for-each", stmt);
@@ -922,7 +923,8 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
922923
return true;
923924
};
924925

925-
auto target = SyntacticElementTarget::forForEachStmt(stmt, dc);
926+
auto target = SyntacticElementTarget::forForEachStmt(
927+
stmt, dc, /*ignoreWhereClause=*/false, packElementEnv);
926928
if (!typeCheckTarget(target))
927929
return failed();
928930

lib/Sema/TypeCheckStmt.cpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,8 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10001000

10011001
StmtChecker(DeclContext *DC) : Ctx(DC->getASTContext()), DC(DC) { }
10021002

1003+
llvm::SmallVector<GenericEnvironment *, 4> genericSigStack;
1004+
10031005
//===--------------------------------------------------------------------===//
10041006
// Helper Functions.
10051007
//===--------------------------------------------------------------------===//
@@ -1434,17 +1436,28 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
14341436
}
14351437

14361438
Stmt *visitForEachStmt(ForEachStmt *S) {
1437-
if (TypeChecker::typeCheckForEachBinding(DC, S))
1439+
GenericEnvironment *genericSignature =
1440+
genericSigStack.empty() ? nullptr : genericSigStack.back();
1441+
1442+
if (TypeChecker::typeCheckForEachBinding(DC, S, genericSignature))
14381443
return nullptr;
14391444

14401445
// Type-check the body of the loop.
14411446
auto sourceFile = DC->getParentSourceFile();
14421447
checkLabeledStmtShadowing(getASTContext(), sourceFile, S);
14431448

14441449
BraceStmt *Body = S->getBody();
1450+
1451+
if (auto packExpansion =
1452+
dyn_cast<PackExpansionExpr>(S->getParsedSequence()))
1453+
genericSigStack.push_back(packExpansion->getGenericEnvironment());
1454+
14451455
typeCheckStmt(Body);
14461456
S->setBody(Body);
1447-
1457+
1458+
if (isa<PackExpansionExpr>(S->getParsedSequence()))
1459+
genericSigStack.pop_back();
1460+
14481461
return S;
14491462
}
14501463

lib/Sema/TypeChecker.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,8 @@ bool typeCheckPatternBinding(PatternBindingDecl *PBD, unsigned patternNumber,
748748
/// Type-check a for-each loop's pattern binding and sequence together.
749749
///
750750
/// \returns true if a failure occurred.
751-
bool typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt);
751+
bool typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt,
752+
GenericEnvironment *packElementEnv);
752753

753754
/// Compute the set of captures for the given function or closure.
754755
void computeCaptures(AnyFunctionRef AFR);

test/stmt/foreach.swift

+10
Original file line numberDiff line numberDiff line change
@@ -330,4 +330,14 @@ do {
330330
// expected-error@-1 {{'where' clause in pack iteration is not supported}}
331331
}
332332
}
333+
334+
func nested<each T, each U>(value: repeat each T, value1: repeat each U) {
335+
for e1 in repeat each value {
336+
for _ in [] {}
337+
for e2 in repeat each value1 {
338+
let y = e1 // Ok
339+
}
340+
let x = e1 // Ok
341+
}
342+
}
333343
}

0 commit comments

Comments
 (0)