Skip to content

Commit fdc0e08

Browse files
committed
SILGen: Emit literal closures at the abstraction level of their context.
Literal closures are only ever directly referenced in the context of the expression they're written in, so it's wasteful to emit them at their fully-substituted calling convention and then reabstract them if they're passed directly to a generic function. Avoid this by saving the abstraction pattern of the context before emitting the closure, and then lowering its main entry point's calling convention at that level of abstraction. Generalize some of the prolog/epilog code to handle converting arguments and returns to the correct representation for a different abstraction level.
1 parent 52e852a commit fdc0e08

34 files changed

+532
-159
lines changed

include/swift/SIL/AbstractionPattern.h

+26-1
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ class AbstractionPattern {
432432
const clang::ObjCMethodDecl *ObjCMethod;
433433
const clang::CXXMethodDecl *CXXMethod;
434434
const AbstractionPattern *OrigTupleElements;
435+
const void *RawTypePtr;
435436
};
436437
CanGenericSignature GenericSig;
437438

@@ -1274,7 +1275,7 @@ class AbstractionPattern {
12741275
/// pattern?
12751276
bool matchesTuple(CanTupleType substType);
12761277

1277-
bool isTuple() {
1278+
bool isTuple() const {
12781279
switch (getKind()) {
12791280
case Kind::Invalid:
12801281
llvm_unreachable("querying invalid abstraction pattern!");
@@ -1386,8 +1387,32 @@ class AbstractionPattern {
13861387
Lowering::TypeConverter &TC
13871388
) const;
13881389

1390+
/// How values are passed or returned according to this abstraction pattern.
1391+
enum CallingConventionKind {
1392+
// Value is passed or returned directly as a unit.
1393+
Direct,
1394+
// Value is passed or returned indirectly through memory.
1395+
Indirect,
1396+
// Value is a tuple that is destructured, and each element is considered
1397+
// independently.
1398+
Destructured,
1399+
};
1400+
1401+
/// If this abstraction pattern appears in function return position, how is
1402+
/// the corresponding value returned?
1403+
CallingConventionKind getResultConvention(TypeConverter &TC) const;
1404+
1405+
/// If this abstraction pattern appears in function parameter position, how
1406+
/// is the corresponding value passed?
1407+
CallingConventionKind getParameterConvention(TypeConverter &TC) const;
1408+
13891409
void dump() const LLVM_ATTRIBUTE_USED;
13901410
void print(raw_ostream &OS) const;
1411+
1412+
bool operator==(const AbstractionPattern &other) const;
1413+
bool operator!=(const AbstractionPattern &other) const {
1414+
return !(*this == other);
1415+
}
13911416
};
13921417

13931418
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &out,

include/swift/SIL/TypeLowering.h

+19
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,9 @@ class TypeConverter {
734734
///
735735
/// Second element is a ResilienceExpansion.
736736
llvm::DenseMap<std::pair<SILType, unsigned>, unsigned> TypeFields;
737+
738+
llvm::DenseMap<AbstractClosureExpr *, Optional<AbstractionPattern>>
739+
ClosureAbstractionPatterns;
737740

738741
CanAnyFunctionType makeConstantInterfaceType(SILDeclRef constant);
739742

@@ -1106,6 +1109,22 @@ class TypeConverter {
11061109
SILType enumType,
11071110
EnumElementDecl *elt);
11081111

1112+
/// Get the preferred abstraction pattern, if any, by which to lower a
1113+
/// declaration.
1114+
///
1115+
/// This can be set using \c setAbstractionPattern , but only before
1116+
/// the abstraction pattern is queried using this function. Once the
1117+
/// abstraction pattern has been asked for, it may not be changed.
1118+
Optional<AbstractionPattern> getConstantAbstractionPattern(SILDeclRef constant);
1119+
1120+
/// Set the preferred abstraction pattern for a closure.
1121+
///
1122+
/// The abstraction pattern can only be set before any calls to
1123+
/// \c getConstantAbstractionPattern on the same closure. It may not be
1124+
/// changed once it has been read.
1125+
void setAbstractionPattern(AbstractClosureExpr *closure,
1126+
AbstractionPattern pattern);
1127+
11091128
private:
11101129
CanType computeLoweredRValueType(TypeExpansionContext context,
11111130
AbstractionPattern origType,

lib/SIL/IR/AbstractionPattern.cpp

+152
Original file line numberDiff line numberDiff line change
@@ -1233,3 +1233,155 @@ AbstractionPattern AbstractionPattern::getAutoDiffDerivativeFunctionType(
12331233
llvm_unreachable("called on unsupported abstraction pattern kind");
12341234
}
12351235
}
1236+
1237+
AbstractionPattern::CallingConventionKind
1238+
AbstractionPattern::getResultConvention(TypeConverter &TC) const {
1239+
// Tuples should be destructured.
1240+
if (isTuple()) {
1241+
return Destructured;
1242+
}
1243+
switch (getKind()) {
1244+
case Kind::Opaque:
1245+
// Maximally abstracted values are always passed indirectly.
1246+
return Indirect;
1247+
1248+
case Kind::OpaqueFunction:
1249+
case Kind::OpaqueDerivativeFunction:
1250+
case Kind::PartialCurriedObjCMethodType:
1251+
case Kind::CurriedObjCMethodType:
1252+
case Kind::PartialCurriedCFunctionAsMethodType:
1253+
case Kind::CurriedCFunctionAsMethodType:
1254+
case Kind::CFunctionAsMethodType:
1255+
case Kind::ObjCMethodType:
1256+
case Kind::CXXMethodType:
1257+
case Kind::CurriedCXXMethodType:
1258+
case Kind::PartialCurriedCXXMethodType:
1259+
case Kind::CXXOperatorMethodType:
1260+
case Kind::CurriedCXXOperatorMethodType:
1261+
case Kind::PartialCurriedCXXOperatorMethodType:
1262+
// Function types are always passed directly
1263+
return Direct;
1264+
1265+
case Kind::ClangType:
1266+
case Kind::Type:
1267+
case Kind::Discard:
1268+
// Pass according to the formal type.
1269+
return SILType::isFormallyReturnedIndirectly(getType(),
1270+
TC,
1271+
getGenericSignatureOrNull())
1272+
? Indirect : Direct;
1273+
1274+
case Kind::Invalid:
1275+
case Kind::Tuple:
1276+
case Kind::ObjCCompletionHandlerArgumentsType:
1277+
llvm_unreachable("should not get here");
1278+
}
1279+
}
1280+
1281+
AbstractionPattern::CallingConventionKind
1282+
AbstractionPattern::getParameterConvention(TypeConverter &TC) const {
1283+
// Tuples should be destructured.
1284+
if (isTuple()) {
1285+
return Destructured;
1286+
}
1287+
switch (getKind()) {
1288+
case Kind::Opaque:
1289+
// Maximally abstracted values are always passed indirectly.
1290+
return Indirect;
1291+
1292+
case Kind::OpaqueFunction:
1293+
case Kind::OpaqueDerivativeFunction:
1294+
case Kind::PartialCurriedObjCMethodType:
1295+
case Kind::CurriedObjCMethodType:
1296+
case Kind::PartialCurriedCFunctionAsMethodType:
1297+
case Kind::CurriedCFunctionAsMethodType:
1298+
case Kind::CFunctionAsMethodType:
1299+
case Kind::ObjCMethodType:
1300+
case Kind::CXXMethodType:
1301+
case Kind::CurriedCXXMethodType:
1302+
case Kind::PartialCurriedCXXMethodType:
1303+
case Kind::CXXOperatorMethodType:
1304+
case Kind::CurriedCXXOperatorMethodType:
1305+
case Kind::PartialCurriedCXXOperatorMethodType:
1306+
// Function types are always passed directly
1307+
return Direct;
1308+
1309+
case Kind::ClangType:
1310+
case Kind::Type:
1311+
case Kind::Discard:
1312+
// Pass according to the formal type.
1313+
return SILType::isFormallyPassedIndirectly(getType(),
1314+
TC,
1315+
getGenericSignatureOrNull())
1316+
? Indirect : Direct;
1317+
1318+
case Kind::Invalid:
1319+
case Kind::Tuple:
1320+
case Kind::ObjCCompletionHandlerArgumentsType:
1321+
llvm_unreachable("should not get here");
1322+
}
1323+
}
1324+
1325+
bool
1326+
AbstractionPattern::operator==(const AbstractionPattern &other) const {
1327+
if (TheKind != other.TheKind)
1328+
return false;
1329+
1330+
switch (getKind()) {
1331+
case Kind::Opaque:
1332+
case Kind::Invalid:
1333+
case Kind::OpaqueFunction:
1334+
case Kind::OpaqueDerivativeFunction:
1335+
// No additional info to compare.
1336+
return true;
1337+
1338+
case Kind::Tuple:
1339+
if (getNumTupleElements() != other.getNumTupleElements()) {
1340+
return false;
1341+
}
1342+
for (unsigned i = 0; i < getNumTupleElements(); ++i) {
1343+
if (getTupleElementType(i) != other.getTupleElementType(i)) {
1344+
return false;
1345+
}
1346+
}
1347+
return true;
1348+
1349+
case Kind::Type:
1350+
case Kind::Discard:
1351+
return OrigType == other.OrigType
1352+
&& GenericSig == other.GenericSig;
1353+
1354+
case Kind::ClangType:
1355+
return OrigType == other.OrigType
1356+
&& GenericSig == other.GenericSig
1357+
&& ClangType == other.ClangType;
1358+
1359+
case Kind::ObjCCompletionHandlerArgumentsType:
1360+
case Kind::CFunctionAsMethodType:
1361+
case Kind::CurriedCFunctionAsMethodType:
1362+
case Kind::PartialCurriedCFunctionAsMethodType:
1363+
return OrigType == other.OrigType
1364+
&& GenericSig == other.GenericSig
1365+
&& ClangType == other.ClangType
1366+
&& OtherData == other.OtherData;
1367+
1368+
case Kind::ObjCMethodType:
1369+
case Kind::CurriedObjCMethodType:
1370+
case Kind::PartialCurriedObjCMethodType:
1371+
return OrigType == other.OrigType
1372+
&& GenericSig == other.GenericSig
1373+
&& ObjCMethod == other.ObjCMethod
1374+
&& OtherData == other.OtherData;
1375+
1376+
case Kind::CXXMethodType:
1377+
case Kind::CXXOperatorMethodType:
1378+
case Kind::CurriedCXXMethodType:
1379+
case Kind::CurriedCXXOperatorMethodType:
1380+
case Kind::PartialCurriedCXXMethodType:
1381+
case Kind::PartialCurriedCXXOperatorMethodType:
1382+
return OrigType == other.OrigType
1383+
&& GenericSig == other.GenericSig
1384+
&& CXXMethod == other.CXXMethod
1385+
&& OtherData == other.OtherData;
1386+
}
1387+
}

lib/SIL/IR/SILFunctionType.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -2072,7 +2072,8 @@ static CanSILFunctionType getSILFunctionType(
20722072
// for thick or polymorphic functions. We don't need to worry about
20732073
// non-opaque patterns because the type-checker forbids non-thick
20742074
// function types from having generic parameters or results.
2075-
if (origType.isTypeParameter() &&
2075+
if (!constant &&
2076+
origType.isTypeParameter() &&
20762077
substFnInterfaceType->getExtInfo().getSILRepresentation()
20772078
!= SILFunctionType::Representation::Thick &&
20782079
isa<FunctionType>(substFnInterfaceType)) {
@@ -3183,9 +3184,18 @@ static CanSILFunctionType getUncachedSILFunctionTypeForConstant(
31833184
auto proto = constant.getDecl()->getDeclContext()->getSelfProtocolDecl();
31843185
witnessMethodConformance = ProtocolConformanceRef(proto);
31853186
}
3187+
3188+
// Does this constant have a preferred abstraction pattern set?
3189+
AbstractionPattern origType = [&]{
3190+
if (auto abstraction = TC.getConstantAbstractionPattern(constant)) {
3191+
return *abstraction;
3192+
} else {
3193+
return AbstractionPattern(origLoweredInterfaceType);
3194+
}
3195+
}();
31863196

31873197
return ::getNativeSILFunctionType(
3188-
TC, context, AbstractionPattern(origLoweredInterfaceType),
3198+
TC, context, origType,
31893199
origLoweredInterfaceType, extInfoBuilder, constant, constant, None,
31903200
witnessMethodConformance);
31913201
}

lib/SIL/IR/TypeLowering.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -3491,6 +3491,28 @@ CanSILBoxType TypeConverter::getBoxTypeForEnumElement(
34913491
return boxTy;
34923492
}
34933493

3494+
Optional<AbstractionPattern>
3495+
TypeConverter::getConstantAbstractionPattern(SILDeclRef constant) {
3496+
if (auto closure = constant.getAbstractClosureExpr()) {
3497+
// Using operator[] here creates an entry in the map if one doesn't exist
3498+
// yet, marking the fact that the lack of abstraction pattern has been
3499+
// established and cannot be overridden by `setAbstractionPattern` later.
3500+
return ClosureAbstractionPatterns[closure];
3501+
}
3502+
return None;
3503+
}
3504+
3505+
void TypeConverter::setAbstractionPattern(AbstractClosureExpr *closure,
3506+
AbstractionPattern pattern) {
3507+
auto existing = ClosureAbstractionPatterns.find(closure);
3508+
if (existing != ClosureAbstractionPatterns.end()) {
3509+
assert(*existing->second == pattern
3510+
&& "closure shouldn't be emitted at different abstraction level contexts");
3511+
} else {
3512+
ClosureAbstractionPatterns[closure] = pattern;
3513+
}
3514+
}
3515+
34943516
static void countNumberOfInnerFields(unsigned &fieldsCount, TypeConverter &TC,
34953517
SILType Ty,
34963518
TypeExpansionContext expansion) {

lib/SILGen/Conversion.h

+23
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ class Conversion {
5959
static bool isBridgingKind(KindTy kind) {
6060
return kind <= LastBridgingKind;
6161
}
62+
63+
static bool isReabstractionKind(KindTy kind) {
64+
// Update if we end up with more kinds!
65+
return !isBridgingKind(kind);
66+
}
6267

6368
private:
6469
KindTy Kind;
@@ -139,6 +144,10 @@ class Conversion {
139144
bool isBridging() const {
140145
return isBridgingKind(getKind());
141146
}
147+
148+
bool isReabstraction() const {
149+
return isReabstractionKind(getKind());
150+
}
142151

143152
AbstractionPattern getReabstractionOrigType() const {
144153
return Types.get<ReabstractionTypes>(Kind).OrigType;
@@ -264,12 +273,21 @@ class ConvertingInitialization final : public Initialization {
264273
StateTy getState() const {
265274
return State;
266275
}
276+
277+
InitializationPtr OwnedSubInitialization;
267278

268279
public:
269280
ConvertingInitialization(Conversion conversion, SGFContext finalContext)
270281
: State(Uninitialized), TheConversion(conversion),
271282
FinalContext(finalContext) {}
272283

284+
ConvertingInitialization(Conversion conversion,
285+
InitializationPtr subInitialization)
286+
: State(Uninitialized), TheConversion(conversion),
287+
FinalContext(SGFContext(subInitialization.get())) {
288+
OwnedSubInitialization = std::move(subInitialization);
289+
}
290+
273291
/// Return the conversion to apply to the unconverted value.
274292
const Conversion &getConversion() const {
275293
return TheConversion;
@@ -328,11 +346,16 @@ class ConvertingInitialization final : public Initialization {
328346
ConvertingInitialization *getAsConversion() override {
329347
return this;
330348
}
349+
350+
// Get the abstraction pattern, if any, the value is converted to.
351+
Optional<AbstractionPattern> getAbstractionPattern() const override;
331352

332353
// Bookkeeping.
333354
void finishInitialization(SILGenFunction &SGF) override {
334355
assert(getState() == Initialized);
335356
State = Finished;
357+
if (OwnedSubInitialization)
358+
OwnedSubInitialization->finishInitialization(SGF);
336359
}
337360
};
338361

lib/SILGen/Initialization.h

+10
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ class Initialization {
161161
"uninitialized");
162162
}
163163

164+
/// The preferred abstraction pattern to initialize with.
165+
///
166+
/// Returning something other than None here gives expression emission the
167+
/// opportunity to generate the initial value directly at the proper
168+
/// abstraction level, avoiding the need for a conversion in some
169+
/// circumstances.
170+
virtual Optional<AbstractionPattern> getAbstractionPattern() const {
171+
return None;
172+
}
173+
164174
protected:
165175
bool EmitDebugValueOnInit = true;
166176

lib/SILGen/SGFContext.h

+8
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,14 @@ class SGFContext {
167167

168168
return SGFContext();
169169
}
170+
171+
/// Return the abstraction pattern of the context we're emitting into.
172+
Optional<AbstractionPattern> getAbstractionPattern() const {
173+
if (auto *init = getEmitInto()) {
174+
return init->getAbstractionPattern();
175+
}
176+
return None;
177+
}
170178
};
171179

172180
using ValueProducerRef =

0 commit comments

Comments
 (0)