Skip to content

Commit 9c5982e

Browse files
committed
[mlir] support recursive types in type conversion infra
MLIR supports recursive types but they could not be handled by the conversion infrastructure directly as it would result in infinite recursion in `convertType` for elemental types. Support this case by keeping the "call stack" of nested type conversions in the TypeConverter class and by passing it as an optional argument to the individual conversion callback. The callback can then check if a specific type is present on the stack more than once to detect and handle the recursive case. This approach is preferred to the alternative approach of having a separate callback dedicated to handling only the recursive case as the latter was observed to introduce ~3% time overhead on a 50MB IR file even if it did not contain recursive types. This approach is also preferred to keeping a local stack in type converters that need to handle recursive types as that would compose poorly in case of out-of-tree or cross-project extensions. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D113579
1 parent 774f783 commit 9c5982e

File tree

5 files changed

+109
-20
lines changed

5 files changed

+109
-20
lines changed

mlir/docs/DialectConversion.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,14 @@ class TypeConverter {
307307
/// existing value are expected to be removed during conversion. If
308308
/// `llvm::None` is returned, the converter is allowed to try another
309309
/// conversion function to perform the conversion.
310+
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
311+
/// - This form represents a 1-N type conversion supporting recursive
312+
/// types. The first two arguments and the return value are the same as
313+
/// for the regular 1-N form. The third argument is contains is the
314+
/// "call stack" of the recursive conversion: it contains the list of
315+
/// types currently being converted, with the current type being the
316+
/// last one. If it is present more than once in the list, the
317+
/// conversion concerns a recursive type.
310318
/// Note: When attempting to convert a type, e.g. via 'convertType', the
311319
/// mostly recently added conversions will be invoked first.
312320
template <typename FnT,

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ class TypeConverter {
101101
/// existing value are expected to be removed during conversion. If
102102
/// `llvm::None` is returned, the converter is allowed to try another
103103
/// conversion function to perform the conversion.
104+
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
105+
/// - This form represents a 1-N type conversion supporting recursive
106+
/// types. The first two arguments and the return value are the same as
107+
/// for the regular 1-N form. The third argument is contains is the
108+
/// "call stack" of the recursive conversion: it contains the list of
109+
/// types currently being converted, with the current type being the
110+
/// last one. If it is present more than once in the list, the
111+
/// conversion concerns a recursive type.
104112
/// Note: When attempting to convert a type, e.g. via 'convertType', the
105113
/// mostly recently added conversions will be invoked first.
106114
template <typename FnT, typename T = typename llvm::function_traits<
@@ -221,8 +229,8 @@ class TypeConverter {
221229
/// The signature of the callback used to convert a type. If the new set of
222230
/// types is empty, the type is removed and any usages of the existing value
223231
/// are expected to be removed during conversion.
224-
using ConversionCallbackFn =
225-
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
232+
using ConversionCallbackFn = std::function<Optional<LogicalResult>(
233+
Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
226234

227235
/// The signature of the callback used to materialize a conversion.
228236
using MaterializationCallbackFn =
@@ -240,28 +248,44 @@ class TypeConverter {
240248
template <typename T, typename FnT>
241249
std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
242250
wrapCallback(FnT &&callback) {
243-
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
244-
T type, SmallVectorImpl<Type> &results) {
245-
if (Optional<Type> resultOpt = callback(type)) {
246-
bool wasSuccess = static_cast<bool>(resultOpt.getValue());
247-
if (wasSuccess)
248-
results.push_back(resultOpt.getValue());
249-
return Optional<LogicalResult>(success(wasSuccess));
250-
}
251-
return Optional<LogicalResult>();
252-
});
253-
}
254-
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)`
251+
return wrapCallback<T>(
252+
[callback = std::forward<FnT>(callback)](
253+
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
254+
if (Optional<Type> resultOpt = callback(type)) {
255+
bool wasSuccess = static_cast<bool>(resultOpt.getValue());
256+
if (wasSuccess)
257+
results.push_back(resultOpt.getValue());
258+
return Optional<LogicalResult>(success(wasSuccess));
259+
}
260+
return Optional<LogicalResult>();
261+
});
262+
}
263+
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
264+
/// &)`
255265
template <typename T, typename FnT>
256-
std::enable_if_t<!llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
266+
std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &>::value,
267+
ConversionCallbackFn>
268+
wrapCallback(FnT &&callback) {
269+
return wrapCallback<T>(
270+
[callback = std::forward<FnT>(callback)](
271+
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
272+
return callback(type, results);
273+
});
274+
}
275+
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
276+
/// &, ArrayRef<Type>)`.
277+
template <typename T, typename FnT>
278+
std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &,
279+
ArrayRef<Type>>::value,
280+
ConversionCallbackFn>
257281
wrapCallback(FnT &&callback) {
258282
return [callback = std::forward<FnT>(callback)](
259-
Type type,
260-
SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
283+
Type type, SmallVectorImpl<Type> &results,
284+
ArrayRef<Type> callStack) -> Optional<LogicalResult> {
261285
T derivedType = type.dyn_cast<T>();
262286
if (!derivedType)
263287
return llvm::None;
264-
return callback(derivedType, results);
288+
return callback(derivedType, results, callStack);
265289
};
266290
}
267291

@@ -300,6 +324,10 @@ class TypeConverter {
300324
DenseMap<Type, Type> cachedDirectConversions;
301325
/// This cache stores the successful 1->N conversions, where N != 1.
302326
DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
327+
328+
/// Stores the types that are being converted in the case when convertType
329+
/// is being called recursively to convert nested types.
330+
SmallVector<Type, 2> conversionCallStack;
303331
};
304332

305333
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/FunctionSupport.h"
1515
#include "mlir/Rewrite/PatternApplicator.h"
1616
#include "mlir/Transforms/Utils.h"
17+
#include "llvm/ADT/ScopeExit.h"
1718
#include "llvm/ADT/SetVector.h"
1819
#include "llvm/ADT/SmallPtrSet.h"
1920
#include "llvm/Support/Debug.h"
@@ -2931,8 +2932,12 @@ LogicalResult TypeConverter::convertType(Type t,
29312932
// Walk the added converters in reverse order to apply the most recently
29322933
// registered first.
29332934
size_t currentCount = results.size();
2935+
conversionCallStack.push_back(t);
2936+
auto popConversionCallStack =
2937+
llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
29342938
for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
2935-
if (Optional<LogicalResult> result = converter(t, results)) {
2939+
if (Optional<LogicalResult> result =
2940+
converter(t, results, conversionCallStack)) {
29362941
if (!succeeded(*result)) {
29372942
cachedDirectConversions.try_emplace(t, nullptr);
29382943
return failure();

mlir/test/Transforms/test-legalize-type-conversion.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,12 @@ func @test_signature_conversion_no_converter() {
112112
}) : () -> ()
113113
return
114114
}
115+
116+
// -----
117+
118+
// CHECK-LABEL: @recursive_type_conversion
119+
func @recursive_type_conversion() {
120+
// CHECK: !test.test_rec<outer_converted_type, smpla>
121+
"test.type_producer"() : () -> !test.test_rec<something, test_rec<something>>
122+
return
123+
}

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "TestDialect.h"
10+
#include "TestTypes.h"
1011
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1112
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1213
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
@@ -924,10 +925,16 @@ struct TestTypeConversionProducer
924925
matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
925926
ConversionPatternRewriter &rewriter) const final {
926927
Type resultType = op.getType();
928+
Type convertedType = getTypeConverter()
929+
? getTypeConverter()->convertType(resultType)
930+
: resultType;
927931
if (resultType.isa<FloatType>())
928932
resultType = rewriter.getF64Type();
929933
else if (resultType.isInteger(16))
930934
resultType = rewriter.getIntegerType(64);
935+
else if (resultType.isa<test::TestRecursiveType>() &&
936+
convertedType != resultType)
937+
resultType = convertedType;
931938
else
932939
return failure();
933940

@@ -1035,6 +1042,35 @@ struct TestTypeConversionDriver
10351042
// Drop all integer types.
10361043
return success();
10371044
});
1045+
converter.addConversion(
1046+
// Convert a recursive self-referring type into a non-self-referring
1047+
// type named "outer_converted_type" that contains a SimpleAType.
1048+
[&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
1049+
ArrayRef<Type> callStack) -> Optional<LogicalResult> {
1050+
// If the type is already converted, return it to indicate that it is
1051+
// legal.
1052+
if (type.getName() == "outer_converted_type") {
1053+
results.push_back(type);
1054+
return success();
1055+
}
1056+
1057+
// If the type is on the call stack more than once (it is there at
1058+
// least once because of the _current_ call, which is always the last
1059+
// element on the stack), we've hit the recursive case. Just return
1060+
// SimpleAType here to create a non-recursive type as a result.
1061+
if (llvm::is_contained(callStack.drop_back(), type)) {
1062+
results.push_back(test::SimpleAType::get(type.getContext()));
1063+
return success();
1064+
}
1065+
1066+
// Convert the body recursively.
1067+
auto result = test::TestRecursiveType::get(type.getContext(),
1068+
"outer_converted_type");
1069+
if (failed(result.setBody(converter.convertType(type.getBody()))))
1070+
return failure();
1071+
results.push_back(result);
1072+
return success();
1073+
});
10381074

10391075
/// Add the legal set of type materializations.
10401076
converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1059,7 +1095,10 @@ struct TestTypeConversionDriver
10591095
// Initialize the conversion target.
10601096
mlir::ConversionTarget target(getContext());
10611097
target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
1062-
return op.getType().isF64() || op.getType().isInteger(64);
1098+
auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
1099+
return op.getType().isF64() || op.getType().isInteger(64) ||
1100+
(recursiveType &&
1101+
recursiveType.getName() == "outer_converted_type");
10631102
});
10641103
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
10651104
return converter.isSignatureLegal(op.getType()) &&

0 commit comments

Comments
 (0)