Skip to content

Commit a7f8b7c

Browse files
committed
[mlir][python] Remove "Raw" OpView classes
The raw `OpView` classes are used to bypass the constructors of `OpView` subclasses, but having a separate class can create some confusing behaviour, e.g.: ``` op = MyOp(...) # fails, lhs is 'MyOp', rhs is '_MyOp' assert type(op) == type(op.operation.opview) ``` Instead we can use `__new__` to achieve the same thing without a separate class: ``` my_op = MyOp.__new__(MyOp) OpView.__init__(my_op, op) ``` Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D143830
1 parent c00f81c commit a7f8b7c

File tree

7 files changed

+46
-77
lines changed

7 files changed

+46
-77
lines changed

mlir/lib/Bindings/Python/Globals.h

+9-14
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ class PyGlobals {
7474
/// Raises an exception if the mapping already exists.
7575
/// This is intended to be called by implementation code.
7676
void registerOperationImpl(const std::string &operationName,
77-
pybind11::object pyClass,
78-
pybind11::object rawOpViewClass);
77+
pybind11::object pyClass);
7978

8079
/// Returns the custom Attribute builder for Attribute kind.
8180
std::optional<pybind11::function>
@@ -86,10 +85,11 @@ class PyGlobals {
8685
std::optional<pybind11::object>
8786
lookupDialectClass(const std::string &dialectNamespace);
8887

89-
/// Looks up a registered raw OpView class by operation name. Note that this
90-
/// may trigger a load of the dialect, which can arbitrarily re-enter.
88+
/// Looks up a registered operation class (deriving from OpView) by operation
89+
/// name. Note that this may trigger a load of the dialect, which can
90+
/// arbitrarily re-enter.
9191
std::optional<pybind11::object>
92-
lookupRawOpViewClass(llvm::StringRef operationName);
92+
lookupOperationClass(llvm::StringRef operationName);
9393

9494
private:
9595
static PyGlobals *instance;
@@ -99,21 +99,16 @@ class PyGlobals {
9999
llvm::StringMap<pybind11::object> dialectClassMap;
100100
/// Map of full operation name to external operation class object.
101101
llvm::StringMap<pybind11::object> operationClassMap;
102-
/// Map of operation name to custom subclass that directly initializes
103-
/// the OpView base class (bypassing the user class constructor).
104-
llvm::StringMap<pybind11::object> rawOpViewClassMap;
105102
/// Map of attribute ODS name to custom builder.
106103
llvm::StringMap<pybind11::object> attributeBuilderMap;
107104

108105
/// Set of dialect namespaces that we have attempted to import implementation
109106
/// modules for.
110107
llvm::StringSet<> loadedDialectModulesCache;
111-
/// Cache of operation name to custom OpView subclass that directly
112-
/// initializes the OpView base class (or an undefined object for negative
113-
/// lookup). This is maintained on loopup as a shadow of rawOpViewClassMap
114-
/// in order for repeat lookups of the OpView classes to only incur the cost
115-
/// of one hashtable lookup.
116-
llvm::StringMap<pybind11::object> rawOpViewClassMapCache;
108+
/// Cache of operation name to external operation class object. This is
109+
/// maintained on lookup as a shadow of operationClassMap in order for repeat
110+
/// lookups of the classes to only incur the cost of one hashtable lookup.
111+
llvm::StringMap<pybind11::object> operationClassMapCache;
117112
};
118113

119114
} // namespace python

mlir/lib/Bindings/Python/IRCore.cpp

+15-39
Original file line numberDiff line numberDiff line change
@@ -1339,10 +1339,10 @@ py::object PyOperation::createOpView() {
13391339
checkValid();
13401340
MlirIdentifier ident = mlirOperationGetName(get());
13411341
MlirStringRef identStr = mlirIdentifierStr(ident);
1342-
auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1342+
auto operationCls = PyGlobals::get().lookupOperationClass(
13431343
StringRef(identStr.data, identStr.length));
1344-
if (opViewClass)
1345-
return (*opViewClass)(getRef().getObject());
1344+
if (operationCls)
1345+
return PyOpView::constructDerived(*operationCls, *getRef().get());
13461346
return py::cast(PyOpView(getRef().getObject()));
13471347
}
13481348

@@ -1618,47 +1618,23 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
16181618
/*regions=*/*regions, location, maybeIp);
16191619
}
16201620

1621+
pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
1622+
const PyOperation &operation) {
1623+
// TODO: pybind11 2.6 supports a more direct form.
1624+
// Upgrade many years from now.
1625+
// auto opViewType = py::type::of<PyOpView>();
1626+
py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1627+
py::object instance = cls.attr("__new__")(cls);
1628+
opViewType.attr("__init__")(instance, operation);
1629+
return instance;
1630+
}
1631+
16211632
PyOpView::PyOpView(const py::object &operationObject)
16221633
// Casting through the PyOperationBase base-class and then back to the
16231634
// Operation lets us accept any PyOperationBase subclass.
16241635
: operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
16251636
operationObject(operation.getRef().getObject()) {}
16261637

1627-
py::object PyOpView::createRawSubclass(const py::object &userClass) {
1628-
// This is... a little gross. The typical pattern is to have a pure python
1629-
// class that extends OpView like:
1630-
// class AddFOp(_cext.ir.OpView):
1631-
// def __init__(self, loc, lhs, rhs):
1632-
// operation = loc.context.create_operation(
1633-
// "addf", lhs, rhs, results=[lhs.type])
1634-
// super().__init__(operation)
1635-
//
1636-
// I.e. The goal of the user facing type is to provide a nice constructor
1637-
// that has complete freedom for the op under construction. This is at odds
1638-
// with our other desire to sometimes create this object by just passing an
1639-
// operation (to initialize the base class). We could do *arg and **kwargs
1640-
// munging to try to make it work, but instead, we synthesize a new class
1641-
// on the fly which extends this user class (AddFOp in this example) and
1642-
// *give it* the base class's __init__ method, thus bypassing the
1643-
// intermediate subclass's __init__ method entirely. While slightly,
1644-
// underhanded, this is safe/legal because the type hierarchy has not changed
1645-
// (we just added a new leaf) and we aren't mucking around with __new__.
1646-
// Typically, this new class will be stored on the original as "_Raw" and will
1647-
// be used for casts and other things that need a variant of the class that
1648-
// is initialized purely from an operation.
1649-
py::object parentMetaclass =
1650-
py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1651-
py::dict attributes;
1652-
// TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1653-
// now.
1654-
// auto opViewType = py::type::of<PyOpView>();
1655-
auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1656-
attributes["__init__"] = opViewType.attr("__init__");
1657-
py::str origName = userClass.attr("__name__");
1658-
py::str newName = py::str("_") + origName;
1659-
return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1660-
}
1661-
16621638
//------------------------------------------------------------------------------
16631639
// PyInsertionPoint.
16641640
//------------------------------------------------------------------------------
@@ -2863,7 +2839,7 @@ void mlir::python::populateIRCore(py::module &m) {
28632839
throw py::value_error(
28642840
"Expected a '" + clsOpName + "' op, got: '" +
28652841
std::string(parsedOpName.data, parsedOpName.length) + "'");
2866-
return cls.attr("_Raw")(parsed.getObject());
2842+
return PyOpView::constructDerived(cls, *parsed.get());
28672843
},
28682844
py::arg("cls"), py::arg("source"), py::kw_only(),
28692845
py::arg("source_name") = "", py::arg("context") = py::none(),

mlir/lib/Bindings/Python/IRModule.cpp

+9-11
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,14 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
8484
}
8585

8686
void PyGlobals::registerOperationImpl(const std::string &operationName,
87-
py::object pyClass,
88-
py::object rawOpViewClass) {
87+
py::object pyClass) {
8988
py::object &found = operationClassMap[operationName];
9089
if (found) {
9190
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
9291
operationName +
9392
"' is already registered.");
9493
}
9594
found = std::move(pyClass);
96-
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
9795
}
9896

9997
std::optional<py::function>
@@ -130,10 +128,10 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
130128
}
131129

132130
std::optional<pybind11::object>
133-
PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
131+
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
134132
{
135-
auto foundIt = rawOpViewClassMapCache.find(operationName);
136-
if (foundIt != rawOpViewClassMapCache.end()) {
133+
auto foundIt = operationClassMapCache.find(operationName);
134+
if (foundIt != operationClassMapCache.end()) {
137135
if (foundIt->second.is_none())
138136
return std::nullopt;
139137
assert(foundIt->second && "py::object is defined");
@@ -148,22 +146,22 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
148146

149147
// Attempt to find from the canonical map and cache.
150148
{
151-
auto foundIt = rawOpViewClassMap.find(operationName);
152-
if (foundIt != rawOpViewClassMap.end()) {
149+
auto foundIt = operationClassMap.find(operationName);
150+
if (foundIt != operationClassMap.end()) {
153151
if (foundIt->second.is_none())
154152
return std::nullopt;
155153
assert(foundIt->second && "py::object is defined");
156154
// Positive cache.
157-
rawOpViewClassMapCache[operationName] = foundIt->second;
155+
operationClassMapCache[operationName] = foundIt->second;
158156
return foundIt->second;
159157
}
160158
// Negative cache.
161-
rawOpViewClassMap[operationName] = py::none();
159+
operationClassMap[operationName] = py::none();
162160
return std::nullopt;
163161
}
164162
}
165163

166164
void PyGlobals::clearImportCache() {
167165
loadedDialectModulesCache.clear();
168-
rawOpViewClassMapCache.clear();
166+
operationClassMapCache.clear();
169167
}

mlir/lib/Bindings/Python/IRModule.h

+10-2
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,6 @@ class PyOpView : public PyOperationBase {
654654
PyOpView(const pybind11::object &operationObject);
655655
PyOperation &getOperation() override { return operation; }
656656

657-
static pybind11::object createRawSubclass(const pybind11::object &userClass);
658-
659657
pybind11::object getOperationObject() { return operationObject; }
660658

661659
static pybind11::object
@@ -666,6 +664,16 @@ class PyOpView : public PyOperationBase {
666664
std::optional<int> regions, DefaultingPyLocation location,
667665
const pybind11::object &maybeIp);
668666

667+
/// Construct an instance of a class deriving from OpView, bypassing its
668+
/// `__init__` method. The derived class will typically define a constructor
669+
/// that provides a convenient builder, but we need to side-step this when
670+
/// constructing an `OpView` for an already-built operation.
671+
///
672+
/// The caller is responsible for verifying that `operation` is a valid
673+
/// operation to construct `cls` with.
674+
static pybind11::object constructDerived(const pybind11::object &cls,
675+
const PyOperation &operation);
676+
669677
private:
670678
PyOperation &operation; // For efficient, cast-free access from C++
671679
pybind11::object operationObject; // Holds the reference.

mlir/lib/Bindings/Python/MainModule.cpp

+1-9
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ PYBIND11_MODULE(_mlir, m) {
4141
"Testing hook for directly registering a dialect")
4242
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
4343
py::arg("operation_name"), py::arg("operation_class"),
44-
py::arg("raw_opview_class"),
4544
"Testing hook for directly registering an operation");
4645

4746
// Aside from making the globals accessible to python, having python manage
@@ -68,18 +67,11 @@ PYBIND11_MODULE(_mlir, m) {
6867
[dialectClass](py::object opClass) -> py::object {
6968
std::string operationName =
7069
opClass.attr("OPERATION_NAME").cast<std::string>();
71-
auto rawSubclass = PyOpView::createRawSubclass(opClass);
72-
PyGlobals::get().registerOperationImpl(operationName, opClass,
73-
rawSubclass);
70+
PyGlobals::get().registerOperationImpl(operationName, opClass);
7471

7572
// Dict-stuff the new opClass by name onto the dialect class.
7673
py::object opClassName = opClass.attr("__name__");
7774
dialectClass.attr(opClassName) = opClass;
78-
79-
// Now create a special "Raw" subclass that passes through
80-
// construction to the OpView parent (bypasses the intermediate
81-
// child's __init__).
82-
opClass.attr("_Raw") = rawSubclass;
8375
return opClass;
8476
});
8577
},

mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ globals: "_Globals"
55
class _Globals:
66
dialect_search_modules: List[str]
77
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ...
8-
def _register_operation_impl(self, operation_name: str, operation_class: type, raw_opview_class: type) -> None: ...
8+
def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
99
def append_dialect_search_prefix(self, module_name: str) -> None: ...
1010

1111
def register_dialect(dialect_class: type) -> object: ...

mlir/test/python/ir/operation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def testKnownOpView():
620620
# addf should map to a known OpView class in the arithmetic dialect.
621621
# We know the OpView for it defines an 'lhs' attribute.
622622
addf = module.body.operations[2]
623-
# CHECK: <mlir.dialects._arith_ops_gen._AddFOp object
623+
# CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
624624
print(repr(addf))
625625
# CHECK: "custom.f32"()
626626
print(addf.lhs)

0 commit comments

Comments
 (0)