Skip to content

Commit 5ec5ffc

Browse files
authored
Merge pull request #41570 from xedin/add-expr-pattern-handling-to-solver
[ConstraintSystem] Add support for expression patterns
2 parents 428892d + 31b356f commit 5ec5ffc

9 files changed

+115
-41
lines changed

include/swift/Sema/ConstraintLocator.h

+3
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ enum ContextualTypePurpose : uint8_t {
8585
CTP_ComposedPropertyWrapper, ///< Composed wrapper type expected to match
8686
///< former 'wrappedValue' type
8787

88+
CTP_ExprPattern, ///< `~=` operator application associated with expression
89+
/// pattern.
90+
8891
CTP_CannotFail, ///< Conversion can never fail. abort() if it does.
8992
};
9093

include/swift/Sema/ConstraintSystem.h

+46-1
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,7 @@ class SolutionApplicationTargetsKey {
968968
stmtCondElement,
969969
expr,
970970
stmt,
971+
pattern,
971972
patternBindingEntry,
972973
varDecl,
973974
};
@@ -982,6 +983,8 @@ class SolutionApplicationTargetsKey {
982983

983984
const Stmt *stmt;
984985

986+
const Pattern *pattern;
987+
985988
struct PatternBindingEntry {
986989
const PatternBindingDecl *patternBinding;
987990
unsigned index;
@@ -1011,6 +1014,11 @@ class SolutionApplicationTargetsKey {
10111014
storage.stmt = stmt;
10121015
}
10131016

1017+
SolutionApplicationTargetsKey(const Pattern *pattern) {
1018+
kind = Kind::pattern;
1019+
storage.pattern = pattern;
1020+
}
1021+
10141022
SolutionApplicationTargetsKey(
10151023
const PatternBindingDecl *patternBinding, unsigned index) {
10161024
kind = Kind::patternBindingEntry;
@@ -1042,6 +1050,9 @@ class SolutionApplicationTargetsKey {
10421050
case Kind::stmt:
10431051
return lhs.storage.stmt == rhs.storage.stmt;
10441052

1053+
case Kind::pattern:
1054+
return lhs.storage.pattern == rhs.storage.pattern;
1055+
10451056
case Kind::patternBindingEntry:
10461057
return (lhs.storage.patternBindingEntry.patternBinding
10471058
== rhs.storage.patternBindingEntry.patternBinding) &&
@@ -1083,6 +1094,11 @@ class SolutionApplicationTargetsKey {
10831094
DenseMapInfo<unsigned>::getHashValue(static_cast<unsigned>(kind)),
10841095
DenseMapInfo<void *>::getHashValue(storage.stmt));
10851096

1097+
case Kind::pattern:
1098+
return hash_combine(
1099+
DenseMapInfo<unsigned>::getHashValue(static_cast<unsigned>(kind)),
1100+
DenseMapInfo<void *>::getHashValue(storage.pattern));
1101+
10861102
case Kind::patternBindingEntry:
10871103
return hash_combine(
10881104
DenseMapInfo<unsigned>::getHashValue(static_cast<unsigned>(kind)),
@@ -1701,6 +1717,13 @@ class SolutionApplicationTarget {
17011717
ContextualTypePurpose contextualPurpose,
17021718
TypeLoc convertType, bool isDiscarded);
17031719

1720+
SolutionApplicationTarget(Expr *expr, DeclContext *dc, ExprPattern *pattern,
1721+
Type patternType)
1722+
: SolutionApplicationTarget(expr, dc, CTP_ExprPattern, patternType,
1723+
/*isDiscarded=*/false) {
1724+
setPattern(pattern);
1725+
}
1726+
17041727
SolutionApplicationTarget(AnyFunctionRef fn)
17051728
: SolutionApplicationTarget(fn, fn.getBody()) { }
17061729

@@ -1786,6 +1809,12 @@ class SolutionApplicationTarget {
17861809
static SolutionApplicationTarget forPropertyWrapperInitializer(
17871810
VarDecl *wrappedVar, DeclContext *dc, Expr *initializer);
17881811

1812+
static SolutionApplicationTarget forExprPattern(Expr *expr, DeclContext *dc,
1813+
ExprPattern *pattern,
1814+
Type patternTy) {
1815+
return {expr, dc, pattern, patternTy};
1816+
}
1817+
17891818
Expr *getAsExpr() const {
17901819
switch (kind) {
17911820
case Kind::expression:
@@ -1888,6 +1917,12 @@ class SolutionApplicationTarget {
18881917
return expression.pattern;
18891918
}
18901919

1920+
ExprPattern *getExprPattern() const {
1921+
assert(kind == Kind::expression);
1922+
assert(expression.contextualPurpose == CTP_ExprPattern);
1923+
return cast<ExprPattern>(expression.pattern);
1924+
}
1925+
18911926
/// For a pattern initialization target, retrieve the contextual pattern.
18921927
ContextualPattern getContextualPattern() const;
18931928

@@ -2008,7 +2043,8 @@ class SolutionApplicationTarget {
20082043
assert(kind == Kind::expression);
20092044
assert(expression.contextualPurpose == CTP_Initialization ||
20102045
expression.contextualPurpose == CTP_ForEachStmt ||
2011-
expression.contextualPurpose == CTP_ForEachSequence);
2046+
expression.contextualPurpose == CTP_ForEachSequence ||
2047+
expression.contextualPurpose == CTP_ExprPattern);
20122048
expression.pattern = pattern;
20132049
}
20142050

@@ -5107,6 +5143,15 @@ class ConstraintSystem {
51075143
= FreeTypeVariableBinding::Disallow);
51085144

51095145
public:
5146+
/// Pre-check the target, validating any types that occur in it
5147+
/// and folding sequence expressions.
5148+
///
5149+
/// \param replaceInvalidRefsWithErrors Indicates whether it's allowed
5150+
/// to replace any discovered invalid member references with `ErrorExpr`.
5151+
static bool preCheckTarget(SolutionApplicationTarget &target,
5152+
bool replaceInvalidRefsWithErrors,
5153+
bool leaveClosureBodiesUnchecked);
5154+
51105155
/// Pre-check the expression, validating any types that occur in the
51115156
/// expression and folding sequence expressions.
51125157
///

lib/Sema/CSApply.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -8936,6 +8936,7 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
89368936
case CTP_Unused:
89378937
case CTP_CaseStmt:
89388938
case CTP_ReturnStmt:
8939+
case CTP_ExprPattern:
89398940
case swift::CTP_ReturnSingleExpr:
89408941
case swift::CTP_YieldByValue:
89418942
case swift::CTP_YieldByReference:

lib/Sema/CSDiagnostics.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ Optional<Diag<Type, Type>> GenericArgumentsMismatchFailure::getDiagnosticFor(
710710
case CTP_YieldByReference:
711711
case CTP_CalleeResult:
712712
case CTP_EnumCaseRawValue:
713+
case CTP_ExprPattern:
713714
break;
714715
}
715716
return None;
@@ -2538,6 +2539,7 @@ getContextualNilDiagnostic(ContextualTypePurpose CTP) {
25382539
case CTP_YieldByReference:
25392540
case CTP_WrappedProperty:
25402541
case CTP_ComposedPropertyWrapper:
2542+
case CTP_ExprPattern:
25412543
return None;
25422544

25432545
case CTP_EnumCaseRawValue:
@@ -3302,6 +3304,7 @@ ContextualFailure::getDiagnosticFor(ContextualTypePurpose context,
33023304
case CTP_CannotFail:
33033305
case CTP_YieldByReference:
33043306
case CTP_CalleeResult:
3307+
case CTP_ExprPattern:
33053308
break;
33063309
}
33073310
return None;

lib/Sema/CSSimplify.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -12996,6 +12996,7 @@ void ConstraintSystem::addContextualConversionConstraint(
1299612996
case CTP_ForEachStmt:
1299712997
case CTP_WrappedProperty:
1299812998
case CTP_ComposedPropertyWrapper:
12999+
case CTP_ExprPattern:
1299913000
break;
1300013001
}
1300113002

lib/Sema/ConstraintSystem.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -5887,6 +5887,7 @@ bool SolutionApplicationTarget::contextualTypeIsOnlyAHint() const {
58875887
case CTP_WrappedProperty:
58885888
case CTP_ComposedPropertyWrapper:
58895889
case CTP_CannotFail:
5890+
case CTP_ExprPattern:
58905891
return false;
58915892
}
58925893
llvm_unreachable("invalid contextual type");

lib/Sema/PreCheckExpr.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -2110,6 +2110,40 @@ Expr *PreCheckExpression::simplifyTypeConstructionWithLiteralArg(Expr *E) {
21102110
: nullptr;
21112111
}
21122112

2113+
bool ConstraintSystem::preCheckTarget(SolutionApplicationTarget &target,
2114+
bool replaceInvalidRefsWithErrors,
2115+
bool leaveClosureBodiesUnchecked) {
2116+
auto *DC = target.getDeclContext();
2117+
2118+
bool hadErrors = false;
2119+
2120+
if (auto *expr = target.getAsExpr()) {
2121+
hadErrors |= preCheckExpression(expr, DC, replaceInvalidRefsWithErrors,
2122+
leaveClosureBodiesUnchecked);
2123+
// Even if the pre-check fails, expression still has to be re-set.
2124+
target.setExpr(expr);
2125+
}
2126+
2127+
if (target.isForEachStmt()) {
2128+
auto &info = target.getForEachStmtInfo();
2129+
2130+
if (info.whereExpr)
2131+
hadErrors |= preCheckExpression(info.whereExpr, DC,
2132+
/*replaceInvalidRefsWithErrors=*/true,
2133+
/*leaveClosureBodiesUnchecked=*/false);
2134+
2135+
// Update sequence and where expressions to pre-checked versions.
2136+
if (!hadErrors) {
2137+
info.stmt->setSequence(target.getAsExpr());
2138+
2139+
if (info.whereExpr)
2140+
info.stmt->setWhere(info.whereExpr);
2141+
}
2142+
}
2143+
2144+
return hadErrors;
2145+
}
2146+
21132147
/// Pre-check the expression, validating any types that occur in the
21142148
/// expression and folding sequence expressions.
21152149
bool ConstraintSystem::preCheckExpression(Expr *&expr, DeclContext *dc,

lib/Sema/TypeCheckCodeCompletion.cpp

+4-6
Original file line numberDiff line numberDiff line change
@@ -579,12 +579,10 @@ bool TypeChecker::typeCheckForCodeCompletion(
579579
if (needsPrecheck) {
580580
// First, pre-check the expression, validating any types that occur in the
581581
// expression and folding sequence expressions.
582-
auto failedPreCheck = ConstraintSystem::preCheckExpression(
583-
expr, DC,
584-
/*replaceInvalidRefsWithErrors=*/true,
585-
/*leaveClosureBodiesUnchecked=*/true);
586-
587-
target.setExpr(expr);
582+
auto failedPreCheck =
583+
ConstraintSystem::preCheckTarget(target,
584+
/*replaceInvalidRefsWithErrors=*/true,
585+
/*leaveClosureBodiesUnchecked=*/true);
588586

589587
if (failedPreCheck)
590588
return false;

lib/Sema/TypeCheckConstraints.cpp

+22-34
Original file line numberDiff line numberDiff line change
@@ -337,25 +337,23 @@ Type TypeChecker::typeCheckExpression(Expr *&expr, DeclContext *dc,
337337
}
338338

339339
Optional<SolutionApplicationTarget>
340-
TypeChecker::typeCheckExpression(
341-
SolutionApplicationTarget &target,
342-
TypeCheckExprOptions options) {
343-
Expr *expr = target.getAsExpr();
340+
TypeChecker::typeCheckExpression(SolutionApplicationTarget &target,
341+
TypeCheckExprOptions options) {
344342
DeclContext *dc = target.getDeclContext();
345343
auto &Context = dc->getASTContext();
346-
FrontendStatsTracer StatsTracer(Context.Stats,
347-
"typecheck-expr", expr);
348-
PrettyStackTraceExpr stackTrace(Context, "type-checking", expr);
344+
FrontendStatsTracer StatsTracer(Context.Stats, "typecheck-expr",
345+
target.getAsExpr());
346+
PrettyStackTraceExpr stackTrace(Context, "type-checking", target.getAsExpr());
349347

350348
// First, pre-check the expression, validating any types that occur in the
351349
// expression and folding sequence expressions.
352-
if (ConstraintSystem::preCheckExpression(
353-
expr, dc, /*replaceInvalidRefsWithErrors=*/true,
354-
options.contains(TypeCheckExprFlags::LeaveClosureBodyUnchecked))) {
355-
target.setExpr(expr);
350+
if (ConstraintSystem::preCheckTarget(
351+
target, /*replaceInvalidRefsWithErrors=*/true,
352+
options.contains(TypeCheckExprFlags::LeaveClosureBodyUnchecked))) {
356353
return None;
357354
}
358-
target.setExpr(expr);
355+
356+
auto *expr = target.getAsExpr();
359357

360358
// Check whether given expression has a code completion token which requires
361359
// special handling.
@@ -810,25 +808,6 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) {
810808
if (!sequenceProto)
811809
return failed();
812810

813-
// Precheck the sequence.
814-
Expr *sequence = stmt->getSequence();
815-
if (ConstraintSystem::preCheckExpression(
816-
sequence, dc, /*replaceInvalidRefsWithErrors=*/true,
817-
/*leaveClosureBodiesUnchecked=*/false))
818-
return failed();
819-
stmt->setSequence(sequence);
820-
821-
// Precheck the filtering condition.
822-
if (Expr *whereExpr = stmt->getWhere()) {
823-
if (ConstraintSystem::preCheckExpression(
824-
whereExpr, dc,
825-
/*replaceInvalidRefsWithErrors=*/true,
826-
/*leaveClosureBodiesUnchecked=*/false))
827-
return failed();
828-
829-
stmt->setWhere(whereExpr);
830-
}
831-
832811
auto target = SolutionApplicationTarget::forForEachStmt(
833812
stmt, sequenceProto, dc, /*bindPatternVarsOneWay=*/false);
834813
if (!typeCheckExpression(target))
@@ -913,13 +892,22 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
913892
/*Implicit=*/true);
914893
Expr *matchCall = BinaryExpr::create(Context, EP->getSubExpr(), matchOp,
915894
matchVarRef, /*implicit*/ true);
895+
896+
// Result of `~=` should always be a boolean.
897+
auto contextualTy = Context.getBoolDecl()->getDeclaredInterfaceType();
898+
auto target = SolutionApplicationTarget::forExprPattern(matchCall, DC, EP,
899+
contextualTy);
900+
916901
// Check the expression as a condition.
917-
bool hadError = typeCheckCondition(matchCall, DC);
902+
auto result = typeCheckExpression(target);
903+
if (!result)
904+
return true;
905+
918906
// Save the type-checked expression in the pattern.
919-
EP->setMatchExpr(matchCall);
907+
EP->setMatchExpr(result->getAsExpr());
920908
// Set the type on the pattern.
921909
EP->setType(rhsType);
922-
return hadError;
910+
return false;
923911
}
924912

925913
static Type replaceArchetypesWithTypeVariables(ConstraintSystem &cs,

0 commit comments

Comments
 (0)