Skip to content

Commit 58e7bf7

Browse files
committed
[mlir] Add isa/dyn_cast support for dialect interfaces
This matches the same API usage as attributes/ops/types. For example: ```c++ Dialect *dialect = ...; // Instead of this: if (auto *interface = dialect->getRegisteredInterface<DialectInlinerInterface>()) // You can do this: if (auto *interface = dyn_cast<DialectInlinerInterface>(dialect)) ``` Differential Revision: https://reviews.llvm.org/D117859
1 parent 51ed14d commit 58e7bf7

File tree

7 files changed

+62
-21
lines changed

7 files changed

+62
-21
lines changed

mlir/docs/Interfaces.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ or transformation without the need to determine the specific dialect subclass:
7777
7878
```c++
7979
Dialect *dialect = ...;
80-
if (DialectInlinerInterface *interface
81-
= dialect->getRegisteredInterface<DialectInlinerInterface>()) {
80+
if (DialectInlinerInterface *interface = dyn_cast<DialectInlinerInterface>(dialect)) {
8281
// The dialect has provided an implementation of this interface.
8382
...
8483
}

mlir/include/mlir/IR/Dialect.h

+48-1
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,58 @@ class DialectRegistry {
440440

441441
namespace llvm {
442442
/// Provide isa functionality for Dialects.
443-
template <typename T> struct isa_impl<T, ::mlir::Dialect> {
443+
template <typename T>
444+
struct isa_impl<T, ::mlir::Dialect,
445+
std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
444446
static inline bool doit(const ::mlir::Dialect &dialect) {
445447
return mlir::TypeID::get<T>() == dialect.getTypeID();
446448
}
447449
};
450+
template <typename T>
451+
struct isa_impl<
452+
T, ::mlir::Dialect,
453+
std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
454+
static inline bool doit(const ::mlir::Dialect &dialect) {
455+
return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
456+
}
457+
};
458+
template <typename T>
459+
struct cast_retty_impl<T, ::mlir::Dialect *> {
460+
using ret_type =
461+
std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T *,
462+
const T *>;
463+
};
464+
template <typename T>
465+
struct cast_retty_impl<T, ::mlir::Dialect> {
466+
using ret_type =
467+
std::conditional_t<std::is_base_of<::mlir::Dialect, T>::value, T &,
468+
const T &>;
469+
};
470+
471+
template <typename T>
472+
struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
473+
template <typename To>
474+
static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
475+
doitImpl(::mlir::Dialect &dialect) {
476+
return static_cast<To &>(dialect);
477+
}
478+
template <typename To>
479+
static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
480+
const To &>
481+
doitImpl(::mlir::Dialect &dialect) {
482+
return *dialect.getRegisteredInterface<To>();
483+
}
484+
485+
static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
486+
};
487+
template <class T>
488+
struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
489+
static auto doit(::mlir::Dialect *dialect) {
490+
return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
491+
*dialect);
492+
}
493+
};
494+
448495
} // namespace llvm
449496

450497
#endif

mlir/lib/Dialect/DLTI/DLTI.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ combineOneSpec(DataLayoutSpecInterface spec,
231231
// dialect is not loaded for some reason, use the default combinator
232232
// that conservatively accepts identical entries only.
233233
entriesForID[id] =
234-
dialect ? dialect->getRegisteredInterface<DataLayoutDialectInterface>()
235-
->combine(entriesForID[id], kvp.second)
234+
dialect ? cast<DataLayoutDialectInterface>(dialect)->combine(
235+
entriesForID[id], kvp.second)
236236
: DataLayoutDialectInterface::defaultCombine(entriesForID[id],
237237
kvp.second);
238238
if (!entriesForID[id])

mlir/lib/IR/BuiltinAttributes.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -1236,8 +1236,7 @@ bool OpaqueElementsAttr::decode(ElementsAttr &result) {
12361236
Dialect *dialect = getContext()->getLoadedDialect(getDialect());
12371237
if (!dialect)
12381238
return true;
1239-
auto *interface =
1240-
dialect->getRegisteredInterface<DialectDecodeAttributesInterface>();
1239+
auto *interface = llvm::dyn_cast<DialectDecodeAttributesInterface>(dialect);
12411240
if (!interface)
12421241
return true;
12431242
return failed(interface->decode(*this, result));

mlir/lib/IR/Operation.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
506506
if (!dialect)
507507
return failure();
508508

509-
auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
509+
auto *interface = dyn_cast<DialectFoldInterface>(dialect);
510510
if (!interface)
511511
return failure();
512512

mlir/lib/Interfaces/DataLayoutInterfaces.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
438438
if (!dialect)
439439
continue;
440440

441-
const auto *iface =
442-
dialect->getRegisteredInterface<DataLayoutDialectInterface>();
441+
const auto *iface = dyn_cast<DataLayoutDialectInterface>(dialect);
443442
if (!iface) {
444443
return emitError(loc)
445444
<< "the '" << dialect->getNamespace()

mlir/unittests/IR/DialectTest.cpp

+8-11
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,17 @@ TEST(Dialect, DelayedInterfaceRegistration) {
6868
MLIRContext context(registry);
6969

7070
// Load the TestDialect and check that the interface got registered for it.
71-
auto *testDialect = context.getOrLoadDialect<TestDialect>();
71+
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
7272
ASSERT_TRUE(testDialect != nullptr);
73-
auto *testDialectInterface =
74-
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
73+
auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
7574
EXPECT_TRUE(testDialectInterface != nullptr);
7675

7776
// Load the SecondTestDialect and check that the interface is not registered
7877
// for it.
79-
auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
78+
Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
8079
ASSERT_TRUE(secondTestDialect != nullptr);
8180
auto *secondTestDialectInterface =
82-
secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
81+
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
8382
EXPECT_TRUE(secondTestDialectInterface == nullptr);
8483

8584
// Use the same mechanism as for delayed registration but for an already
@@ -90,7 +89,7 @@ TEST(Dialect, DelayedInterfaceRegistration) {
9089
.addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
9190
context.appendDialectRegistry(secondRegistry);
9291
secondTestDialectInterface =
93-
secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
92+
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
9493
EXPECT_TRUE(secondTestDialectInterface != nullptr);
9594
}
9695

@@ -102,10 +101,9 @@ TEST(Dialect, RepeatedDelayedRegistration) {
102101
MLIRContext context(registry);
103102

104103
// Load the TestDialect and check that the interface got registered for it.
105-
auto *testDialect = context.getOrLoadDialect<TestDialect>();
104+
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
106105
ASSERT_TRUE(testDialect != nullptr);
107-
auto *testDialectInterface =
108-
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
106+
auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
109107
EXPECT_TRUE(testDialectInterface != nullptr);
110108

111109
// Try adding the same dialect interface again and check that we don't crash
@@ -114,8 +112,7 @@ TEST(Dialect, RepeatedDelayedRegistration) {
114112
secondRegistry.insert<TestDialect>();
115113
secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
116114
context.appendDialectRegistry(secondRegistry);
117-
testDialectInterface =
118-
testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
115+
testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
119116
EXPECT_TRUE(testDialectInterface != nullptr);
120117
}
121118

0 commit comments

Comments
 (0)