Skip to content

Commit f431d38

Browse files
committed
Make Python MLIR Operation not iterable
The current behavior is conveniently allowing to iterate on the regions of an operation implicitly by exposing an operation as Iterable. However this is also error prone and code that may intend to iterate on the results or the operands could end up "working" apparently instead of throwing a runtime error. The lack of static type checking in Python contributes to the ambiguity here, it seems safer to not do this and require and explicit qualification to iterate (`op.results`, `op.regions`, ...). Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D111697
1 parent 4c8ea90 commit f431d38

File tree

6 files changed

+24
-8
lines changed

6 files changed

+24
-8
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

-4
Original file line numberDiff line numberDiff line change
@@ -2152,10 +2152,6 @@ void mlir::python::populateIRCore(py::module &m) {
21522152
},
21532153
"Returns the source location the operation was defined or derived "
21542154
"from.")
2155-
.def("__iter__",
2156-
[](PyOperationBase &self) {
2157-
return PyRegionIterator(self.getOperation().getRef());
2158-
})
21592155
.def(
21602156
"__str__",
21612157
[](PyOperationBase &self) {

mlir/python/mlir/dialects/_builtin_ops_ext.py

+9
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,17 @@ def decorator(f):
195195
# Coerce return values, add ReturnOp and rewrite func type.
196196
if return_values is None:
197197
return_values = []
198+
elif isinstance(return_values, tuple):
199+
return_values = list(return_values)
198200
elif isinstance(return_values, Value):
201+
# Returning a single value is fine, coerce it into a list.
199202
return_values = [return_values]
203+
elif isinstance(return_values, OpView):
204+
# Returning a single operation is fine, coerce its results a list.
205+
return_values = return_values.operation.results
206+
elif isinstance(return_values, Operation):
207+
# Returning a single operation is fine, coerce its results a list.
208+
return_values = return_values.results
200209
else:
201210
return_values = list(return_values)
202211
std.ReturnOp(return_values)

mlir/python/mlir/dialects/_ods_common.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def get_default_loc_context(location=None):
124124

125125

126126
def get_op_result_or_value(
127-
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]
127+
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList]
128128
) -> _cext.ir.Value:
129129
"""Returns the given value or the single result of the given op.
130130
@@ -136,6 +136,8 @@ def get_op_result_or_value(
136136
return arg.operation.result
137137
elif isinstance(arg, _cext.ir.Operation):
138138
return arg.result
139+
elif isinstance(arg, _cext.ir.OpResultList):
140+
return arg[0]
139141
else:
140142
assert isinstance(arg, _cext.ir.Value)
141143
return arg

mlir/test/python/dialects/builtin.py

+9
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def run(f):
1515
@run
1616
def testFromPyFunc():
1717
with Context() as ctx, Location.unknown() as loc:
18+
ctx.allow_unregistered_dialects = True
1819
m = builtin.ModuleOp()
1920
f32 = F32Type.get()
2021
f64 = F64Type.get()
@@ -51,6 +52,14 @@ def call_unary(a):
5152
def call_binary(a, b):
5253
return binary_return(a, b)
5354

55+
# We expect coercion of a single result operation to a returned value.
56+
# CHECK-LABEL: func @single_result_op
57+
# CHECK: %0 = "custom.op1"() : () -> f32
58+
# CHECK: return %0 : f32
59+
@builtin.FuncOp.from_py_func()
60+
def single_result_op():
61+
return Operation.create("custom.op1", results=[f32])
62+
5463
# CHECK-LABEL: func @call_none
5564
# CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
5665
# CHECK: return

mlir/test/python/dialects/math.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def emit_sqrt(arg):
1919
return mlir_math.SqrtOp(arg)
2020

2121
# CHECK-LABEL: func @emit_sqrt(
22-
# CHECK-SAME: %[[ARG:.*]]: f32) {
22+
# CHECK-SAME: %[[ARG:.*]]: f32) -> f32 {
2323
# CHECK: math.sqrt %[[ARG]] : f32
2424
# CHECK: return
2525
# CHECK: }

mlir/test/python/ir/operation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def testTraverseOpRegionBlockIterators():
4040
print(f".verify = {module.operation.verify()}")
4141

4242
# Get the regions and blocks from the default collections.
43-
default_regions = list(op)
43+
default_regions = list(op.regions)
4444
default_blocks = list(default_regions[0])
4545
# They should compare equal regardless of how obtained.
4646
assert default_regions == regions
@@ -53,7 +53,7 @@ def testTraverseOpRegionBlockIterators():
5353
assert default_operations == operations
5454

5555
def walk_operations(indent, op):
56-
for i, region in enumerate(op):
56+
for i, region in enumerate(op.regions):
5757
print(f"{indent}REGION {i}:")
5858
for j, block in enumerate(region):
5959
print(f"{indent} BLOCK {j}:")

0 commit comments

Comments
 (0)