Skip to content

Commit 7fd6f40

Browse files
committed
[mlir][python] Add custom constructor for memref load
The type can be inferred trivially, but it is currently done as string stitching between ODS and C++ and is not easily exposed to Python. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D111712
1 parent cc83c24 commit 7fd6f40

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
try:
6+
from ..ir import *
7+
from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
8+
except ImportError as e:
9+
raise RuntimeError("Error loading imports from extension module") from e
10+
11+
from typing import Optional, Sequence, Union
12+
13+
14+
class LoadOp:
15+
"""Specialization for the MemRef load operation."""
16+
17+
def __init__(self,
18+
memref: Union[Operation, OpView, Value],
19+
indices: Optional[Union[Operation, OpView,
20+
Sequence[Value]]] = None,
21+
*,
22+
loc=None,
23+
ip=None):
24+
"""Creates a memref load operation.
25+
26+
Args:
27+
memref: the buffer to load from.
28+
indices: the list of subscripts, may be empty for zero-dimensional
29+
buffers.
30+
loc: user-visible location of the operation.
31+
ip: insertion point.
32+
"""
33+
memref_resolved = _get_op_result_or_value(memref)
34+
indices_resolved = [] if indices is None else _get_op_results_or_values(
35+
indices)
36+
return_type = memref_resolved.type
37+
super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)

mlir/test/python/dialects/memref.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
def run(f):
99
print("\nTEST:", f.__name__)
1010
f()
11+
return f
1112

1213

1314
# CHECK-LABEL: TEST: testSubViewAccessors
15+
@run
1416
def testSubViewAccessors():
1517
ctx = Context()
1618
module = Module.parse(
@@ -52,4 +54,20 @@ def testSubViewAccessors():
5254
print(subview.strides[1])
5355

5456

55-
run(testSubViewAccessors)
57+
# CHECK-LABEL: TEST: testCustomBuidlers
58+
@run
59+
def testCustomBuidlers():
60+
with Context() as ctx, Location.unknown(ctx):
61+
module = Module.parse(r"""
62+
func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
63+
return
64+
}
65+
""")
66+
func = module.body.operations[0]
67+
func_body = func.regions[0].blocks[0]
68+
with InsertionPoint.at_block_terminator(func_body):
69+
memref.LoadOp(func.arguments[0], func.arguments[1:])
70+
71+
# CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
72+
# CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
73+
print(module)

0 commit comments

Comments
 (0)