Skip to content

Commit 1ab3efa

Browse files
committed
[mlir][python] Add fused location
1 parent 6bcf1f9 commit 1ab3efa

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ static const char kContextGetCallSiteLocationDocstring[] =
4747
static const char kContextGetFileLocationDocstring[] =
4848
R"(Gets a Location representing a file, line and column)";
4949

50+
static const char kContextGetFusedLocationDocstring[] =
51+
R"(Gets a Location representing a fused location with optional metadata)";
52+
5053
static const char kContextGetNameLocationDocString[] =
5154
R"(Gets a Location representing a named location with optional child location)";
5255

@@ -2197,6 +2200,23 @@ void mlir::python::populateIRCore(py::module &m) {
21972200
},
21982201
py::arg("filename"), py::arg("line"), py::arg("col"),
21992202
py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2203+
.def_static(
2204+
"fused",
2205+
[](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata,
2206+
DefaultingPyMlirContext context) {
2207+
if (pyLocations.empty())
2208+
throw py::value_error("No locations provided");
2209+
llvm::SmallVector<MlirLocation, 4> locations;
2210+
locations.reserve(pyLocations.size());
2211+
for (auto &pyLocation : pyLocations)
2212+
locations.push_back(pyLocation.get());
2213+
MlirLocation location = mlirLocationFusedGet(
2214+
context->get(), locations.size(), locations.data(),
2215+
metadata ? metadata->get() : MlirAttribute{0});
2216+
return PyLocation(context->getRef(), location);
2217+
},
2218+
py::arg("locations"), py::arg("metadata") = py::none(),
2219+
py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
22002220
.def_static(
22012221
"name",
22022222
[](std::string name, llvm::Optional<PyLocation> childLoc,

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,8 @@ class Location:
658658
@staticmethod
659659
def file(filename: str, line: int, col: int, context: Optional["Context"] = None) -> "Location": ...
660660
@staticmethod
661+
def fused(locations: Sequence["Location"], metadata: Optional["Attribute"] = None, context: Optional["Context"] = None) -> "Location": ...
662+
@staticmethod
661663
def name(name: str, childLoc: Optional["Location"] = None, context: Optional["Context"] = None) -> "Location": ...
662664
@staticmethod
663665
def unknown(context: Optional["Context"] = None) -> Any: ...

mlir/test/python/ir/location.py

+21
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,27 @@ def testCallSite():
7575
run(testCallSite)
7676

7777

78+
# CHECK-LABEL: TEST: testFused
79+
def testFused():
80+
with Context() as ctx:
81+
loc = Location.fused(
82+
[Location.name("apple"), Location.name("banana")])
83+
attr = Attribute.parse('"sauteed"')
84+
loc_attr = Location.fused([Location.name("carrot"),
85+
Location.name("potatoes")], attr)
86+
ctx = None
87+
# CHECK: file str: loc(fused["apple", "banana"])
88+
print("file str:", str(loc))
89+
# CHECK: file repr: loc(fused["apple", "banana"])
90+
print("file repr:", repr(loc))
91+
# CHECK: file str: loc(fused<"sauteed">["carrot", "potatoes"])
92+
print("file str:", str(loc_attr))
93+
# CHECK: file repr: loc(fused<"sauteed">["carrot", "potatoes"])
94+
print("file repr:", repr(loc_attr))
95+
96+
run(testFused)
97+
98+
7899
# CHECK-LABEL: TEST: testLocationCapsule
79100
def testLocationCapsule():
80101
with Context() as ctx:

0 commit comments

Comments
 (0)