Skip to content

Commit 0829f5b

Browse files
authored
Merge pull request #39080 from jckarter/emit-closure-literal-at-context-abstraction-2
SILGen: Emit literal closures at the abstraction level of their context. [take 2]
2 parents eba9e5b + 43506a2 commit 0829f5b

34 files changed

+528
-158
lines changed

include/swift/SIL/AbstractionPattern.h

+26-1
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ class AbstractionPattern {
432432
const clang::ObjCMethodDecl *ObjCMethod;
433433
const clang::CXXMethodDecl *CXXMethod;
434434
const AbstractionPattern *OrigTupleElements;
435+
const void *RawTypePtr;
435436
};
436437
CanGenericSignature GenericSig;
437438

@@ -1274,7 +1275,7 @@ class AbstractionPattern {
12741275
/// pattern?
12751276
bool matchesTuple(CanTupleType substType);
12761277

1277-
bool isTuple() {
1278+
bool isTuple() const {
12781279
switch (getKind()) {
12791280
case Kind::Invalid:
12801281
llvm_unreachable("querying invalid abstraction pattern!");
@@ -1386,8 +1387,32 @@ class AbstractionPattern {
13861387
Lowering::TypeConverter &TC
13871388
) const;
13881389

1390+
/// How values are passed or returned according to this abstraction pattern.
1391+
enum CallingConventionKind {
1392+
// Value is passed or returned directly as a unit.
1393+
Direct,
1394+
// Value is passed or returned indirectly through memory.
1395+
Indirect,
1396+
// Value is a tuple that is destructured, and each element is considered
1397+
// independently.
1398+
Destructured,
1399+
};
1400+
1401+
/// If this abstraction pattern appears in function return position, how is
1402+
/// the corresponding value returned?
1403+
CallingConventionKind getResultConvention(TypeConverter &TC) const;
1404+
1405+
/// If this abstraction pattern appears in function parameter position, how
1406+
/// is the corresponding value passed?
1407+
CallingConventionKind getParameterConvention(TypeConverter &TC) const;
1408+
13891409
void dump() const LLVM_ATTRIBUTE_USED;
13901410
void print(raw_ostream &OS) const;
1411+
1412+
bool operator==(const AbstractionPattern &other) const;
1413+
bool operator!=(const AbstractionPattern &other) const {
1414+
return !(*this == other);
1415+
}
13911416
};
13921417

13931418
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &out,

include/swift/SIL/TypeLowering.h

+19
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,9 @@ class TypeConverter {
734734
///
735735
/// Second element is a ResilienceExpansion.
736736
llvm::DenseMap<std::pair<SILType, unsigned>, unsigned> TypeFields;
737+
738+
llvm::DenseMap<AbstractClosureExpr *, Optional<AbstractionPattern>>
739+
ClosureAbstractionPatterns;
737740

738741
CanAnyFunctionType makeConstantInterfaceType(SILDeclRef constant);
739742

@@ -1106,6 +1109,22 @@ class TypeConverter {
11061109
SILType enumType,
11071110
EnumElementDecl *elt);
11081111

1112+
/// Get the preferred abstraction pattern, if any, by which to lower a
1113+
/// declaration.
1114+
///
1115+
/// This can be set using \c setAbstractionPattern , but only before
1116+
/// the abstraction pattern is queried using this function. Once the
1117+
/// abstraction pattern has been asked for, it may not be changed.
1118+
Optional<AbstractionPattern> getConstantAbstractionPattern(SILDeclRef constant);
1119+
1120+
/// Set the preferred abstraction pattern for a closure.
1121+
///
1122+
/// The abstraction pattern can only be set before any calls to
1123+
/// \c getConstantAbstractionPattern on the same closure. It may not be
1124+
/// changed once it has been read.
1125+
void setAbstractionPattern(AbstractClosureExpr *closure,
1126+
AbstractionPattern pattern);
1127+
11091128
private:
11101129
CanType computeLoweredRValueType(TypeExpansionContext context,
11111130
AbstractionPattern origType,

lib/SIL/IR/AbstractionPattern.cpp

+152
Original file line numberDiff line numberDiff line change
@@ -1233,3 +1233,155 @@ AbstractionPattern AbstractionPattern::getAutoDiffDerivativeFunctionType(
12331233
llvm_unreachable("called on unsupported abstraction pattern kind");
12341234
}
12351235
}
1236+
1237+
AbstractionPattern::CallingConventionKind
1238+
AbstractionPattern::getResultConvention(TypeConverter &TC) const {
1239+
// Tuples should be destructured.
1240+
if (isTuple()) {
1241+
return Destructured;
1242+
}
1243+
switch (getKind()) {
1244+
case Kind::Opaque:
1245+
// Maximally abstracted values are always passed indirectly.
1246+
return Indirect;
1247+
1248+
case Kind::OpaqueFunction:
1249+
case Kind::OpaqueDerivativeFunction:
1250+
case Kind::PartialCurriedObjCMethodType:
1251+
case Kind::CurriedObjCMethodType:
1252+
case Kind::PartialCurriedCFunctionAsMethodType:
1253+
case Kind::CurriedCFunctionAsMethodType:
1254+
case Kind::CFunctionAsMethodType:
1255+
case Kind::ObjCMethodType:
1256+
case Kind::CXXMethodType:
1257+
case Kind::CurriedCXXMethodType:
1258+
case Kind::PartialCurriedCXXMethodType:
1259+
case Kind::CXXOperatorMethodType:
1260+
case Kind::CurriedCXXOperatorMethodType:
1261+
case Kind::PartialCurriedCXXOperatorMethodType:
1262+
// Function types are always passed directly
1263+
return Direct;
1264+
1265+
case Kind::ClangType:
1266+
case Kind::Type:
1267+
case Kind::Discard:
1268+
// Pass according to the formal type.
1269+
return SILType::isFormallyReturnedIndirectly(getType(),
1270+
TC,
1271+
getGenericSignatureOrNull())
1272+
? Indirect : Direct;
1273+
1274+
case Kind::Invalid:
1275+
case Kind::Tuple:
1276+
case Kind::ObjCCompletionHandlerArgumentsType:
1277+
llvm_unreachable("should not get here");
1278+
}
1279+
}
1280+
1281+
AbstractionPattern::CallingConventionKind
1282+
AbstractionPattern::getParameterConvention(TypeConverter &TC) const {
1283+
// Tuples should be destructured.
1284+
if (isTuple()) {
1285+
return Destructured;
1286+
}
1287+
switch (getKind()) {
1288+
case Kind::Opaque:
1289+
// Maximally abstracted values are always passed indirectly.
1290+
return Indirect;
1291+
1292+
case Kind::OpaqueFunction:
1293+
case Kind::OpaqueDerivativeFunction:
1294+
case Kind::PartialCurriedObjCMethodType:
1295+
case Kind::CurriedObjCMethodType:
1296+
case Kind::PartialCurriedCFunctionAsMethodType:
1297+
case Kind::CurriedCFunctionAsMethodType:
1298+
case Kind::CFunctionAsMethodType:
1299+
case Kind::ObjCMethodType:
1300+
case Kind::CXXMethodType:
1301+
case Kind::CurriedCXXMethodType:
1302+
case Kind::PartialCurriedCXXMethodType:
1303+
case Kind::CXXOperatorMethodType:
1304+
case Kind::CurriedCXXOperatorMethodType:
1305+
case Kind::PartialCurriedCXXOperatorMethodType:
1306+
// Function types are always passed directly
1307+
return Direct;
1308+
1309+
case Kind::ClangType:
1310+
case Kind::Type:
1311+
case Kind::Discard:
1312+
// Pass according to the formal type.
1313+
return SILType::isFormallyPassedIndirectly(getType(),
1314+
TC,
1315+
getGenericSignatureOrNull())
1316+
? Indirect : Direct;
1317+
1318+
case Kind::Invalid:
1319+
case Kind::Tuple:
1320+
case Kind::ObjCCompletionHandlerArgumentsType:
1321+
llvm_unreachable("should not get here");
1322+
}
1323+
}
1324+
1325+
bool
1326+
AbstractionPattern::operator==(const AbstractionPattern &other) const {
1327+
if (TheKind != other.TheKind)
1328+
return false;
1329+
1330+
switch (getKind()) {
1331+
case Kind::Opaque:
1332+
case Kind::Invalid:
1333+
case Kind::OpaqueFunction:
1334+
case Kind::OpaqueDerivativeFunction:
1335+
// No additional info to compare.
1336+
return true;
1337+
1338+
case Kind::Tuple:
1339+
if (getNumTupleElements() != other.getNumTupleElements()) {
1340+
return false;
1341+
}
1342+
for (unsigned i = 0; i < getNumTupleElements(); ++i) {
1343+
if (getTupleElementType(i) != other.getTupleElementType(i)) {
1344+
return false;
1345+
}
1346+
}
1347+
return true;
1348+
1349+
case Kind::Type:
1350+
case Kind::Discard:
1351+
return OrigType == other.OrigType
1352+
&& GenericSig == other.GenericSig;
1353+
1354+
case Kind::ClangType:
1355+
return OrigType == other.OrigType
1356+
&& GenericSig == other.GenericSig
1357+
&& ClangType == other.ClangType;
1358+
1359+
case Kind::ObjCCompletionHandlerArgumentsType:
1360+
case Kind::CFunctionAsMethodType:
1361+
case Kind::CurriedCFunctionAsMethodType:
1362+
case Kind::PartialCurriedCFunctionAsMethodType:
1363+
return OrigType == other.OrigType
1364+
&& GenericSig == other.GenericSig
1365+
&& ClangType == other.ClangType
1366+
&& OtherData == other.OtherData;
1367+
1368+
case Kind::ObjCMethodType:
1369+
case Kind::CurriedObjCMethodType:
1370+
case Kind::PartialCurriedObjCMethodType:
1371+
return OrigType == other.OrigType
1372+
&& GenericSig == other.GenericSig
1373+
&& ObjCMethod == other.ObjCMethod
1374+
&& OtherData == other.OtherData;
1375+
1376+
case Kind::CXXMethodType:
1377+
case Kind::CXXOperatorMethodType:
1378+
case Kind::CurriedCXXMethodType:
1379+
case Kind::CurriedCXXOperatorMethodType:
1380+
case Kind::PartialCurriedCXXMethodType:
1381+
case Kind::PartialCurriedCXXOperatorMethodType:
1382+
return OrigType == other.OrigType
1383+
&& GenericSig == other.GenericSig
1384+
&& CXXMethod == other.CXXMethod
1385+
&& OtherData == other.OtherData;
1386+
}
1387+
}

lib/SIL/IR/SILFunctionType.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -2072,7 +2072,8 @@ static CanSILFunctionType getSILFunctionType(
20722072
// for thick or polymorphic functions. We don't need to worry about
20732073
// non-opaque patterns because the type-checker forbids non-thick
20742074
// function types from having generic parameters or results.
2075-
if (origType.isTypeParameter() &&
2075+
if (!constant &&
2076+
origType.isTypeParameter() &&
20762077
substFnInterfaceType->getExtInfo().getSILRepresentation()
20772078
!= SILFunctionType::Representation::Thick &&
20782079
isa<FunctionType>(substFnInterfaceType)) {
@@ -3183,9 +3184,18 @@ static CanSILFunctionType getUncachedSILFunctionTypeForConstant(
31833184
auto proto = constant.getDecl()->getDeclContext()->getSelfProtocolDecl();
31843185
witnessMethodConformance = ProtocolConformanceRef(proto);
31853186
}
3187+
3188+
// Does this constant have a preferred abstraction pattern set?
3189+
AbstractionPattern origType = [&]{
3190+
if (auto abstraction = TC.getConstantAbstractionPattern(constant)) {
3191+
return *abstraction;
3192+
} else {
3193+
return AbstractionPattern(origLoweredInterfaceType);
3194+
}
3195+
}();
31863196

31873197
return ::getNativeSILFunctionType(
3188-
TC, context, AbstractionPattern(origLoweredInterfaceType),
3198+
TC, context, origType,
31893199
origLoweredInterfaceType, extInfoBuilder, constant, constant, None,
31903200
witnessMethodConformance);
31913201
}

lib/SIL/IR/TypeLowering.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -3491,6 +3491,28 @@ CanSILBoxType TypeConverter::getBoxTypeForEnumElement(
34913491
return boxTy;
34923492
}
34933493

3494+
Optional<AbstractionPattern>
3495+
TypeConverter::getConstantAbstractionPattern(SILDeclRef constant) {
3496+
if (auto closure = constant.getAbstractClosureExpr()) {
3497+
// Using operator[] here creates an entry in the map if one doesn't exist
3498+
// yet, marking the fact that the lack of abstraction pattern has been
3499+
// established and cannot be overridden by `setAbstractionPattern` later.
3500+
return ClosureAbstractionPatterns[closure];
3501+
}
3502+
return None;
3503+
}
3504+
3505+
void TypeConverter::setAbstractionPattern(AbstractClosureExpr *closure,
3506+
AbstractionPattern pattern) {
3507+
auto existing = ClosureAbstractionPatterns.find(closure);
3508+
if (existing != ClosureAbstractionPatterns.end()) {
3509+
assert(*existing->second == pattern
3510+
&& "closure shouldn't be emitted at different abstraction level contexts");
3511+
} else {
3512+
ClosureAbstractionPatterns[closure] = pattern;
3513+
}
3514+
}
3515+
34943516
static void countNumberOfInnerFields(unsigned &fieldsCount, TypeConverter &TC,
34953517
SILType Ty,
34963518
TypeExpansionContext expansion) {

lib/SILGen/Conversion.h

+23
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ class Conversion {
5959
static bool isBridgingKind(KindTy kind) {
6060
return kind <= LastBridgingKind;
6161
}
62+
63+
static bool isReabstractionKind(KindTy kind) {
64+
// Update if we end up with more kinds!
65+
return !isBridgingKind(kind);
66+
}
6267

6368
private:
6469
KindTy Kind;
@@ -139,6 +144,10 @@ class Conversion {
139144
bool isBridging() const {
140145
return isBridgingKind(getKind());
141146
}
147+
148+
bool isReabstraction() const {
149+
return isReabstractionKind(getKind());
150+
}
142151

143152
AbstractionPattern getReabstractionOrigType() const {
144153
return Types.get<ReabstractionTypes>(Kind).OrigType;
@@ -264,12 +273,21 @@ class ConvertingInitialization final : public Initialization {
264273
StateTy getState() const {
265274
return State;
266275
}
276+
277+
InitializationPtr OwnedSubInitialization;
267278

268279
public:
269280
ConvertingInitialization(Conversion conversion, SGFContext finalContext)
270281
: State(Uninitialized), TheConversion(conversion),
271282
FinalContext(finalContext) {}
272283

284+
ConvertingInitialization(Conversion conversion,
285+
InitializationPtr subInitialization)
286+
: State(Uninitialized), TheConversion(conversion),
287+
FinalContext(SGFContext(subInitialization.get())) {
288+
OwnedSubInitialization = std::move(subInitialization);
289+
}
290+
273291
/// Return the conversion to apply to the unconverted value.
274292
const Conversion &getConversion() const {
275293
return TheConversion;
@@ -328,11 +346,16 @@ class ConvertingInitialization final : public Initialization {
328346
ConvertingInitialization *getAsConversion() override {
329347
return this;
330348
}
349+
350+
// Get the abstraction pattern, if any, the value is converted to.
351+
Optional<AbstractionPattern> getAbstractionPattern() const override;
331352

332353
// Bookkeeping.
333354
void finishInitialization(SILGenFunction &SGF) override {
334355
assert(getState() == Initialized);
335356
State = Finished;
357+
if (OwnedSubInitialization)
358+
OwnedSubInitialization->finishInitialization(SGF);
336359
}
337360
};
338361

lib/SILGen/Initialization.h

+10
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ class Initialization {
161161
"uninitialized");
162162
}
163163

164+
/// The preferred abstraction pattern to initialize with.
165+
///
166+
/// Returning something other than None here gives expression emission the
167+
/// opportunity to generate the initial value directly at the proper
168+
/// abstraction level, avoiding the need for a conversion in some
169+
/// circumstances.
170+
virtual Optional<AbstractionPattern> getAbstractionPattern() const {
171+
return None;
172+
}
173+
164174
protected:
165175
bool EmitDebugValueOnInit = true;
166176

lib/SILGen/SGFContext.h

+8
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,14 @@ class SGFContext {
167167

168168
return SGFContext();
169169
}
170+
171+
/// Return the abstraction pattern of the context we're emitting into.
172+
Optional<AbstractionPattern> getAbstractionPattern() const {
173+
if (auto *init = getEmitInto()) {
174+
return init->getAbstractionPattern();
175+
}
176+
return None;
177+
}
170178
};
171179

172180
using ValueProducerRef =

0 commit comments

Comments
 (0)