Skip to content

Commit 255a690

Browse files
committed
[mlir][python] Provide more convenient constructors for std.CallOp
The new constructor relies on type-based dynamic dispatch and allows one to construct call operations given an object representing a FuncOp or its name as a string, as opposed to requiring an explicitly constructed attribute. Depends On D110947 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110948
1 parent 3a3a09f commit 255a690

File tree

4 files changed

+120
-9
lines changed

4 files changed

+120
-9
lines changed

mlir/python/mlir/dialects/_builtin_ops_ext.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
try:
6-
from typing import Optional, Sequence
6+
from typing import Optional, Sequence, Union
77

88
import inspect
99

@@ -82,8 +82,8 @@ def visibility(self):
8282
return self.attributes["sym_visibility"]
8383

8484
@property
85-
def name(self):
86-
return self.attributes["sym_name"]
85+
def name(self) -> StringAttr:
86+
return StringAttr(self.attributes["sym_name"])
8787

8888
@property
8989
def entry_block(self):
@@ -104,11 +104,15 @@ def add_entry_block(self):
104104

105105
@property
106106
def arg_attrs(self):
107-
return self.attributes[ARGUMENT_ATTRIBUTE_NAME]
107+
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
108108

109109
@arg_attrs.setter
110-
def arg_attrs(self, attribute: ArrayAttr):
111-
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
110+
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
111+
if isinstance(attribute, ArrayAttr):
112+
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
113+
else:
114+
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
115+
attribute, context=self.context)
112116

113117
@property
114118
def arguments(self):

mlir/python/mlir/dialects/_std_ops_ext.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,73 @@ def literal_value(self) -> Union[int, float]:
6969
return FloatAttr(self.value).value
7070
else:
7171
raise ValueError("only integer and float constants have literal values")
72+
73+
74+
class CallOp:
75+
"""Specialization for the call op class."""
76+
77+
def __init__(self,
78+
calleeOrResults: Union[FuncOp, List[Type]],
79+
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
80+
arguments: Optional[List] = None,
81+
*,
82+
loc=None,
83+
ip=None):
84+
"""Creates an call operation.
85+
86+
The constructor accepts three different forms:
87+
88+
1. A function op to be called followed by a list of arguments.
89+
2. A list of result types, followed by the name of the function to be
90+
called as string, following by a list of arguments.
91+
3. A list of result types, followed by the name of the function to be
92+
called as symbol reference attribute, followed by a list of arguments.
93+
94+
For example
95+
96+
f = builtin.FuncOp("foo", ...)
97+
std.CallOp(f, [args])
98+
std.CallOp([result_types], "foo", [args])
99+
100+
In all cases, the location and insertion point may be specified as keyword
101+
arguments if not provided by the surrounding context managers.
102+
"""
103+
104+
# TODO: consider supporting constructor "overloads", e.g., through a custom
105+
# or pybind-provided metaclass.
106+
if isinstance(calleeOrResults, FuncOp):
107+
if not isinstance(argumentsOrCallee, list):
108+
raise ValueError(
109+
"when constructing a call to a function, expected " +
110+
"the second argument to be a list of call arguments, " +
111+
f"got {type(argumentsOrCallee)}")
112+
if arguments is not None:
113+
raise ValueError("unexpected third argument when constructing a call" +
114+
"to a function")
115+
116+
super().__init__(
117+
calleeOrResults.type.results,
118+
FlatSymbolRefAttr.get(
119+
calleeOrResults.name.value,
120+
context=_get_default_loc_context(loc)),
121+
argumentsOrCallee,
122+
loc=loc,
123+
ip=ip)
124+
return
125+
126+
if isinstance(argumentsOrCallee, list):
127+
raise ValueError("when constructing a call to a function by name, " +
128+
"expected the second argument to be a string or a " +
129+
f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")
130+
131+
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
132+
super().__init__(
133+
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
134+
elif isinstance(argumentsOrCallee, str):
135+
super().__init__(
136+
calleeOrResults,
137+
FlatSymbolRefAttr.get(
138+
argumentsOrCallee, context=_get_default_loc_context(loc)),
139+
arguments,
140+
loc=loc,
141+
ip=ip)

mlir/test/python/dialects/builtin.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def testFuncArgumentAccess():
171171
f32 = F32Type.get()
172172
f64 = F64Type.get()
173173
with InsertionPoint(module.body):
174-
func = builtin.FuncOp("some_func", ([f32, f32], [f64, f64]))
174+
func = builtin.FuncOp("some_func", ([f32, f32], [f32, f32]))
175175
with InsertionPoint(func.add_entry_block()):
176176
std.ReturnOp(func.arguments)
177177
func.arg_attrs = ArrayAttr.get([
@@ -186,6 +186,14 @@ def testFuncArgumentAccess():
186186
DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
187187
])
188188

189+
other = builtin.FuncOp("other_func", ([f32, f32], []))
190+
with InsertionPoint(other.add_entry_block()):
191+
std.ReturnOp([])
192+
other.arg_attrs = [
193+
DictAttr.get({"foo": StringAttr.get("qux")}),
194+
DictAttr.get()
195+
]
196+
189197
# CHECK: [{baz, foo = "bar"}, {qux = []}]
190198
print(func.arg_attrs)
191199

@@ -195,7 +203,11 @@ def testFuncArgumentAccess():
195203
# CHECK: func @some_func(
196204
# CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
197205
# CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
198-
# CHECK: f64 {res1 = 4.200000e+01 : f32},
199-
# CHECK: f64 {res2 = 2.560000e+02 : f64})
206+
# CHECK: f32 {res1 = 4.200000e+01 : f32},
207+
# CHECK: f32 {res2 = 2.560000e+02 : f64})
200208
# CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
209+
#
210+
# CHECK: func @other_func(
211+
# CHECK: %{{.*}}: f32 {foo = "qux"},
212+
# CHECK: %{{.*}}: f32)
201213
print(module)

mlir/test/python/dialects/std.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
from mlir.ir import *
4+
from mlir.dialects import builtin
45
from mlir.dialects import std
56

67

@@ -62,3 +63,27 @@ def testConstantIndexOp():
6263
print(c1.literal_value)
6364

6465
# CHECK: = constant 10 : index
66+
67+
# CHECK-LABEL: TEST: testFunctionCalls
68+
@constructAndPrintInModule
69+
def testFunctionCalls():
70+
foo = builtin.FuncOp("foo", ([], []))
71+
bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
72+
qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
73+
74+
with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
75+
std.CallOp(foo, [])
76+
std.CallOp([IndexType.get()], "bar", [])
77+
std.CallOp([F32Type.get()], FlatSymbolRefAttr.get("qux"), [])
78+
std.ReturnOp([])
79+
80+
# CHECK: func @foo()
81+
# CHECK: func @bar() -> index
82+
# CHECK: func @qux() -> f32
83+
# CHECK: func @caller() {
84+
# CHECK: call @foo() : () -> ()
85+
# CHECK: %0 = call @bar() : () -> index
86+
# CHECK: %1 = call @qux() : () -> f32
87+
# CHECK: return
88+
# CHECK: }
89+

0 commit comments

Comments
 (0)