Skip to content

Commit ae37591

Browse files
committed
AST: Add a count type field to PackExpansionType
Even if we can't spell them in source, we want to model expansions where the pattern does not depend on any pack type parameters, eg func f<C...: Collection>(_ c: C...) { let x = (c.count...) } Here, the type of 'x' is notionally 'Int * C.count'.
1 parent 67b23d7 commit ae37591

File tree

7 files changed

+98
-40
lines changed

7 files changed

+98
-40
lines changed

Diff for: include/swift/AST/Types.h

+22-11
Original file line numberDiff line numberDiff line change
@@ -6378,7 +6378,7 @@ class PackType final : public TypeBase, public llvm::FoldingSetNode,
63786378
}
63796379

63806380
public:
6381-
void Profile(llvm::FoldingSetNodeID &ID) {
6381+
void Profile(llvm::FoldingSetNodeID &ID) const {
63826382
Profile(ID, getElementTypes());
63836383
}
63846384
static void Profile(llvm::FoldingSetNodeID &ID, ArrayRef<Type> Elements);
@@ -6427,41 +6427,52 @@ class PackExpansionType : public TypeBase, public llvm::FoldingSetNode {
64276427
friend class ASTContext;
64286428

64296429
Type patternType;
6430+
Type countType;
64306431

64316432
public:
64326433
/// Create a pack expansion type from the given pattern type.
64336434
///
6434-
/// It is not required that the pattern type actually contain a reference to
6435-
/// a variadic generic parameter.
6436-
static PackExpansionType *get(Type pattern);
6435+
/// It is not required that \p pattern actually contain a reference to
6436+
/// a variadic generic parameter, but any variadic generic parameters
6437+
/// appearing in the pattern type must have the same count as \p countType.
6438+
///
6439+
/// As for \p countType itself, it must be a type sequence generic parameter
6440+
/// type, or a sequence archetype type.
6441+
static PackExpansionType *get(Type pattern, Type countType);
64376442

64386443
public:
64396444
/// Retrieves the pattern type of this pack expansion.
64406445
Type getPatternType() const { return patternType; }
64416446

6447+
/// Retrieves the count type of this pack expansion.
6448+
Type getCountType() const { return countType; }
6449+
64426450
public:
64436451
void Profile(llvm::FoldingSetNodeID &ID) {
6444-
Profile(ID, getPatternType());
6452+
Profile(ID, getPatternType(), getCountType());
64456453
}
64466454

6447-
static void Profile(llvm::FoldingSetNodeID &ID, Type patternType);
6455+
static void Profile(llvm::FoldingSetNodeID &ID,
6456+
Type patternType, Type countType);
64486457

64496458
// Implement isa/cast/dyncast/etc.
64506459
static bool classof(const TypeBase *T) {
64516460
return T->getKind() == TypeKind::PackExpansion;
64526461
}
64536462

64546463
private:
6455-
PackExpansionType(Type patternType, const ASTContext *CanCtx)
6456-
: TypeBase(TypeKind::PackExpansion, CanCtx,
6457-
patternType->getRecursiveProperties()), patternType(patternType) {
6458-
assert(patternType);
6459-
}
6464+
PackExpansionType(Type patternType, Type countType,
6465+
RecursiveTypeProperties properties,
6466+
const ASTContext *ctx);
64606467
};
64616468
BEGIN_CAN_TYPE_WRAPPER(PackExpansionType, Type)
64626469
CanType getPatternType() const {
64636470
return CanType(getPointer()->getPatternType());
64646471
}
6472+
6473+
CanType getCountType() const {
6474+
return CanType(getPointer()->getCountType());
6475+
}
64656476
END_CAN_TYPE_WRAPPER(PackExpansionType, Type)
64666477

64676478
/// getASTContext - Return the ASTContext that this type belongs to.

Diff for: lib/AST/ASTContext.cpp

+29-14
Original file line numberDiff line numberDiff line change
@@ -2897,33 +2897,48 @@ TupleTypeElt::TupleTypeElt(Type ty, Identifier name)
28972897
assert(!ty->is<InOutType>() && "Cannot have InOutType in a tuple");
28982898
}
28992899

2900-
PackExpansionType *PackExpansionType::get(Type patternTy) {
2901-
assert(patternTy && "Missing pattern type in expansion");
2900+
PackExpansionType::PackExpansionType(Type patternType, Type countType,
2901+
RecursiveTypeProperties properties,
2902+
const ASTContext *canCtx)
2903+
: TypeBase(TypeKind::PackExpansion, canCtx, properties),
2904+
patternType(patternType), countType(countType) {
2905+
assert(countType->is<TypeVariableType>() ||
2906+
countType->is<SequenceArchetypeType>() ||
2907+
countType->castTo<GenericTypeParamType>()->isTypeSequence());
2908+
}
2909+
2910+
PackExpansionType *PackExpansionType::get(Type patternType, Type countType) {
2911+
auto properties = patternType->getRecursiveProperties();
2912+
properties |= countType->getRecursiveProperties();
29022913

2903-
auto properties = patternTy->getRecursiveProperties();
29042914
auto arena = getArena(properties);
29052915

2906-
auto &context = patternTy->getASTContext();
2916+
auto &context = patternType->getASTContext();
29072917
llvm::FoldingSetNodeID id;
2908-
PackExpansionType::Profile(id, patternTy);
2918+
PackExpansionType::Profile(id, patternType, countType);
29092919

29102920
void *insertPos;
29112921
if (PackExpansionType *expType =
2912-
context.getImpl()
2913-
.getArena(arena)
2914-
.PackExpansionTypes.FindNodeOrInsertPos(id, insertPos))
2922+
context.getImpl().getArena(arena)
2923+
.PackExpansionTypes.FindNodeOrInsertPos(id, insertPos))
29152924
return expType;
29162925

2917-
const ASTContext *canCtx = patternTy->isCanonical() ? &context : nullptr;
2918-
PackExpansionType *expansionTy = new (context, AllocationArena::Permanent)
2919-
PackExpansionType(patternTy, canCtx);
2920-
context.getImpl().getArena(arena).PackExpansionTypes.InsertNode(expansionTy,
2926+
const ASTContext *canCtx =
2927+
(patternType->isCanonical() && countType->isCanonical())
2928+
? &context : nullptr;
2929+
PackExpansionType *expansionType =
2930+
new (context, arena) PackExpansionType(patternType, countType, properties,
2931+
canCtx);
2932+
context.getImpl().getArena(arena).PackExpansionTypes.InsertNode(expansionType,
29212933
insertPos);
2922-
return expansionTy;
2934+
return expansionType;
29232935
}
29242936

2925-
void PackExpansionType::Profile(llvm::FoldingSetNodeID &ID, Type patternType) {
2937+
void PackExpansionType::Profile(llvm::FoldingSetNodeID &ID,
2938+
Type patternType,
2939+
Type countType) {
29262940
ID.AddPointer(patternType.getPointer());
2941+
ID.AddPointer(countType.getPointer());
29272942
}
29282943

29292944
PackType *PackType::getEmpty(const ASTContext &C) {

Diff for: lib/AST/ASTDumper.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -3686,6 +3686,7 @@ namespace {
36863686
void visitPackExpansionType(PackExpansionType *T, StringRef label) {
36873687
printCommon(label, "pack_expansion_type");
36883688
printField("pattern", T->getPatternType());
3689+
printField("count", T->getCountType());
36893690
PrintWithColorRAII(OS, ParenthesisColor) << ')';
36903691
}
36913692

Diff for: lib/AST/Type.cpp

+14-10
Original file line numberDiff line numberDiff line change
@@ -1593,8 +1593,9 @@ CanType TypeBase::computeCanonicalType() {
15931593

15941594
case TypeKind::PackExpansion: {
15951595
auto *expansion = cast<PackExpansionType>(this);
1596-
auto pattern = expansion->getPatternType()->getCanonicalType();
1597-
Result = PackExpansionType::get(pattern);
1596+
auto patternType = expansion->getPatternType()->getCanonicalType();
1597+
auto countType = expansion->getCountType()->getCanonicalType();
1598+
Result = PackExpansionType::get(patternType, countType);
15981599
break;
15991600
}
16001601

@@ -5535,12 +5536,9 @@ case TypeKind::Id:
55355536
return remap;
55365537
}
55375538

5538-
if (input->is<TypeVariableType>()) {
5539-
if (auto *PT = (*remap)->getAs<PackType>()) {
5540-
maxArity = std::max(maxArity, PT->getNumElements());
5541-
cache.insert({input, PT});
5542-
}
5543-
} else if (input->isTypeSequenceParameter()) {
5539+
if (input->is<TypeVariableType>() ||
5540+
input->isTypeSequenceParameter() ||
5541+
input->is<SequenceArchetypeType>()) {
55445542
if (auto *PT = (*remap)->getAs<PackType>()) {
55455543
maxArity = std::max(maxArity, PT->getNumElements());
55465544
cache.insert({input, PT});
@@ -5563,7 +5561,13 @@ case TypeKind::Id:
55635561
if (!transformedPat)
55645562
return Type();
55655563

5566-
if (transformedPat.getPointer() == expand->getPatternType().getPointer())
5564+
Type transformedCount =
5565+
expand->getCountType().transformWithPosition(pos, gather);
5566+
if (!transformedCount)
5567+
return Type();
5568+
5569+
if (transformedPat.getPointer() == expand->getPatternType().getPointer() &&
5570+
transformedCount.getPointer() == expand->getCountType().getPointer())
55675571
return *this;
55685572

55695573
llvm::DenseMap<TypeBase *, PackType *> expansions;
@@ -5573,7 +5577,7 @@ case TypeKind::Id:
55735577
// If we didn't find any expansions, either the caller wasn't interested
55745578
// in expanding this pack, or something has gone wrong. Leave off the
55755579
// expansion and return the transformed type.
5576-
return PackExpansionType::get(transformedPat);
5580+
return PackExpansionType::get(transformedPat, transformedCount);
55775581
}
55785582

55795583
SmallVector<Type, 8> elts;

Diff for: lib/Sema/TypeCheckDecl.cpp

+14-2
Original file line numberDiff line numberDiff line change
@@ -2227,10 +2227,22 @@ static Type validateParameterType(ParamDecl *decl) {
22272227
}
22282228

22292229
if (decl->isVariadic()) {
2230+
// Find the first type sequence parameter and use that as the count type.
2231+
Type countTy;
2232+
(void) Ty.findIf([&](Type t) -> bool {
2233+
if (auto *paramTy = t->getAs<GenericTypeParamType>()) {
2234+
if (paramTy->isTypeSequence()) {
2235+
countTy = paramTy;
2236+
return true;
2237+
}
2238+
}
2239+
2240+
return false;
2241+
});
22302242
// Handle the monovariadic/polyvariadic interface type split.
2231-
if (Ty->hasTypeSequence()) {
2243+
if (countTy) {
22322244
// Polyvariadic types (T...) for <T...> resolve to pack expansions.
2233-
Ty = PackExpansionType::get(Ty);
2245+
Ty = PackExpansionType::get(Ty, countTy);
22342246
} else {
22352247
// Monovariadic types (T...) for <T> resolve to [T].
22362248
Ty = VariadicSequenceType::get(Ty);

Diff for: lib/Sema/TypeCheckType.cpp

+17-2
Original file line numberDiff line numberDiff line change
@@ -4018,13 +4018,28 @@ NeverNullType TypeResolver::resolveTupleType(TupleTypeRepr *repr,
40184018
if (patternTy->hasError())
40194019
complained = true;
40204020

4021+
// Find the first type sequence parameter and use that as the count type.
4022+
Type countTy;
4023+
(void) patternTy.get().findIf([&](Type t) -> bool {
4024+
if (auto *paramTy = t->getAs<GenericTypeParamType>()) {
4025+
if (paramTy->isTypeSequence()) {
4026+
countTy = paramTy;
4027+
return true;
4028+
}
4029+
}
4030+
4031+
return false;
4032+
});
4033+
40214034
// If there's no reference to a variadic generic parameter, complain
40224035
// - the pack won't actually expand to anything meaningful.
4023-
if (!patternTy->hasTypeSequence())
4036+
if (!countTy) {
40244037
diagnose(repr->getLoc(), diag::expansion_not_variadic, patternTy)
40254038
.highlight(repr->getParens());
4039+
return ErrorType::get(getASTContext());
4040+
}
40264041

4027-
return PackExpansionType::get(patternTy);
4042+
return PackExpansionType::get(patternTy, countTy);
40284043
} else {
40294044
// Variadic tuples are not permitted.
40304045
//

Diff for: test/Constraints/type_sequence.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func bindPrefixAndSuffix() {
6565

6666
func invalidPacks() {
6767
func monovariadic1() -> (String...) {} // expected-error {{cannot create expansion with non-variadic type 'String'}}
68-
func monovariadic2<T>() -> (T...) {} // expected-error 2 {{cannot create expansion with non-variadic type 'T'}}
68+
func monovariadic2<T>() -> (T...) {} // expected-error {{cannot create expansion with non-variadic type 'T'}}
6969
func monovariadic3<T, U>() -> (T, U...) {} // expected-error {{cannot create a variadic tuple}}
7070
}
7171

0 commit comments

Comments
 (0)