Skip to content

Commit c00f81c

Browse files
committed
[mlir][python] Allow running pass manager on any operation
`PassManager.run` is currently restricted to running on `builtin.module` ops, but this restriction doesn't exist on the C++ side. This updates it to take `ir.Operation/OpView` instead of `ir.Module`. Depends on D143354 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D143356
1 parent 6f5590c commit c00f81c

File tree

6 files changed

+13
-13
lines changed

6 files changed

+13
-13
lines changed

mlir/lib/Bindings/Python/Pass.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
116116
"ValueError if the pipeline can't be parsed.")
117117
.def(
118118
"run",
119-
[](PyPassManager &passManager, PyModule &module) {
119+
[](PyPassManager &passManager, PyOperationBase &op) {
120120
MlirLogicalResult status = mlirPassManagerRunOnOp(
121-
passManager.get(), mlirModuleGetOperation(module.get()));
121+
passManager.get(), op.getOperation().get());
122122
if (mlirLogicalResultIsFailure(status))
123123
throw SetPyError(PyExc_RuntimeError,
124124
"Failure while executing pass pipeline.");
125125
},
126-
py::arg("module"),
127-
"Run the pass manager on the provided module, throw a RuntimeError "
128-
"on failure.")
126+
py::arg("operation"),
127+
"Run the pass manager on the provided operation, throw a "
128+
"RuntimeError on failure.")
129129
.def(
130130
"__str__",
131131
[](PyPassManager &self) {

mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __call__(self, module: ir.Module):
2424

2525
def compile(self, module: ir.Module):
2626
"""Compiles the module by invoking the sparse copmiler pipeline."""
27-
passmanager.PassManager.parse(self.pipeline).run(module)
27+
passmanager.PassManager.parse(self.pipeline).run(module.operation)
2828

2929
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
3030
"""Wraps the module in a JIT execution engine."""

mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __call__(self, module: ir.Module):
2727

2828
def compile(self, module: ir.Module):
2929
"""Compiles the module by invoking the sparse copmiler pipeline."""
30-
passmanager.PassManager.parse(self.pipeline).run(module)
30+
passmanager.PassManager.parse(self.pipeline).run(module.operation)
3131

3232
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
3333
"""Wraps the module in a JIT execution engine."""

mlir/test/python/execution_engine.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def testInvalidModule():
6464
def lowerToLLVM(module):
6565
pm = PassManager.parse(
6666
"builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)")
67-
pm.run(module)
67+
pm.run(module.operation)
6868
return module
6969

7070

mlir/test/python/integration/dialects/linalg/opsrun.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def transform(module, boilerplate):
202202
pm.add("finalize-memref-to-llvm")
203203
pm.add("convert-func-to-llvm")
204204
pm.add("reconcile-unrealized-casts")
205-
pm.run(mod)
205+
pm.run(mod.operation)
206206
return mod
207207

208208

mlir/test/python/pass_manager.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import gc, sys
44
from mlir.ir import *
55
from mlir.passmanager import *
6+
from mlir.dialects.func import FuncOp
67

78
# Log everything to stderr and flush so that we have a unified stream to match
89
# errors/info emitted by MLIR to stderr.
@@ -120,11 +121,10 @@ def testInvalidNesting():
120121
# CHECK-LABEL: TEST: testRun
121122
def testRunPipeline():
122123
with Context():
123-
pm = PassManager.parse("builtin.module(print-op-stats{json=false})")
124-
module = Module.parse(r"""func.func @successfulParse() { return }""")
125-
pm.run(module)
124+
pm = PassManager.parse("any(print-op-stats{json=false})")
125+
func = FuncOp.parse(r"""func.func @successfulParse() { return }""")
126+
pm.run(func)
126127
# CHECK: Operations encountered:
127-
# CHECK: builtin.module , 1
128128
# CHECK: func.func , 1
129129
# CHECK: func.return , 1
130130
run(testRunPipeline)

0 commit comments

Comments
 (0)