Skip to content

Commit 31f888e

Browse files
author
Tobias Gysi
committed
[mlir][linalg][python] Add attribute support to the OpDSL.
Extend the OpDSL with index attributes. After tensors and scalars, index attributes are the third operand type. An index attribute represents a compile-time constant that is limited to index expressions. A use cases are the strides and dilations defined by convolution and pooling operations. The patch only updates the OpDSL. The C++ yaml codegen is updated by a followup patch. Differential Revision: https://reviews.llvm.org/D104711
1 parent e76c008 commit 31f888e

File tree

13 files changed

+422
-134
lines changed

13 files changed

+422
-134
lines changed

mlir/include/mlir-c/AffineMap.h

+7
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, intptr_t numResults);
169169
MLIR_CAPI_EXPORTED MlirAffineMap
170170
mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults);
171171

172+
/// Apply AffineExpr::replace(`map`) to each of the results and return a new
173+
/// new AffineMap with the new results and the specified number of dims and
174+
/// symbols.
175+
MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapReplace(
176+
MlirAffineMap affineMap, MlirAffineExpr expression,
177+
MlirAffineExpr replacement, intptr_t numResultDims, intptr_t numResultSyms);
178+
172179
/// Returns the simplified affine map resulting from dropping the symbols that
173180
/// do not appear in any of the individual maps in `affineMaps`.
174181
/// Asserts that all maps in `affineMaps` are normalized to the same number of

mlir/lib/Bindings/Python/IRAffine.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,14 @@ void mlir::python::populateIRAffine(py::module &m) {
654654
mlirAffineMapGetMinorSubMap(self, nResults);
655655
return PyAffineMap(self.getContext(), affineMap);
656656
})
657+
.def("replace",
658+
[](PyAffineMap &self, PyAffineExpr &expression,
659+
PyAffineExpr &replacement, intptr_t numResultDims,
660+
intptr_t numResultSyms) {
661+
MlirAffineMap affineMap = mlirAffineMapReplace(
662+
self, expression, replacement, numResultDims, numResultSyms);
663+
return PyAffineMap(self.getContext(), affineMap);
664+
})
657665
.def_property_readonly(
658666
"is_permutation",
659667
[](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })

mlir/lib/CAPI/IR/AffineMap.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,15 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap,
138138
return wrap(unwrap(affineMap).getMinorSubMap(numResults));
139139
}
140140

141+
MlirAffineMap mlirAffineMapReplace(MlirAffineMap affineMap,
142+
MlirAffineExpr expression,
143+
MlirAffineExpr replacement,
144+
intptr_t numResultDims,
145+
intptr_t numResultSyms) {
146+
return wrap(unwrap(affineMap).replace(unwrap(expression), unwrap(replacement),
147+
numResultDims, numResultSyms));
148+
}
149+
141150
void mlirAffineMapCompressUnusedSymbols(
142151
MlirAffineMap *affineMaps, intptr_t size, void *result,
143152
void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) {

mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

+58-26
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
12+
from enum import Enum
1213

1314
from mlir import ir as _ir
1415

@@ -133,18 +134,31 @@ def __repr__(self):
133134
return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
134135

135136

137+
class OperandKind(Enum):
138+
InputTensor = 0
139+
Scalar = 1
140+
OutputTensor = 2
141+
Attribute = 3
142+
143+
136144
class OperandDef:
137-
"""Definition of a Tensor or Scalar operand passed to an operation."""
145+
"""Definition of an operand passed to an operation.
146+
147+
Keep the meta information of Tensor, Scalar, and Attribute operands and
148+
provide the shared registration functionality.
149+
"""
138150

139-
def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef],
140-
scalar: bool, output: bool):
151+
def __init__(self,
152+
kind: OperandKind,
153+
type_var: TypeVar,
154+
size_exprs: Optional[Sequence[AffineExprDef]] = None):
141155
if not isinstance(type_var, TypeVar):
142-
raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}")
156+
raise ValueError(
157+
f"OperandDef requires a TypeVar but got {repr(type_var)}")
143158
self.owner = None # type: Optional["LinalgOpDef"]
144159
self.type_var = type_var
145-
self.shape = shape
146-
self.scalar = scalar
147-
self.output = output
160+
self.size_exprs = size_exprs
161+
self.kind = kind
148162
self.name = None # type: Optional[str]
149163
self.registered_index = -1 # type: int
150164

@@ -159,25 +173,26 @@ def __hash__(self):
159173
return hash(id(self))
160174

161175
def __repr__(self):
162-
output = "OUTPUT " if self.output else ""
163-
scalar = "SCALAR " if self.scalar else ""
164-
return (f"{self.name}:OperandDef({output}{scalar}"
165-
f"{repr(self.type_var)}, shape={self.shape})")
176+
return (f"{self.name}:OperandDef(kind={self.kind.name}, "
177+
f"type={repr(self.type_var)}, size_exprs={self.size_exprs})")
166178

167179

168180
class TensorDef:
169181
"""Tensor operand definition.
170182
171183
Tensor operands are indexed using the associated indexing_map when forwarded
172184
to the body of the structured op. A unique name identifies the tensor operands
173-
and an index determines their position in the operation's parameter list.
185+
and an index determines their position in the operation's parameter list. A
186+
tensor definition takes type, a shape, and an optional flag to mark output
187+
tensors.
174188
"""
175189

176190
def __init__(self,
177191
type_var: TypeVar,
178192
*shape: AffineExprDef,
179193
output: bool = False):
180-
self.operand_def = OperandDef(type_var, shape, False, output)
194+
kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
195+
self.operand_def = OperandDef(kind, type_var, size_exprs=shape)
181196

182197
def __getitem__(self, dims) -> TensorUse:
183198
assert self.operand_def.owner, "TensorDef is not attached to an op"
@@ -221,7 +236,7 @@ class ScalarDef(TensorExpression):
221236
"""
222237

223238
def __init__(self, type_var: TypeVar):
224-
self.operand_def = OperandDef(type_var, (), True, False)
239+
self.operand_def = OperandDef(OperandKind.Scalar, type_var)
225240

226241
@property
227242
def scalar_name(self) -> str:
@@ -233,6 +248,22 @@ def to_scalar_expression(self) -> ScalarExpression:
233248
return ScalarArg(self.scalar_name).expr()
234249

235250

251+
class AttributeDef:
252+
"""Index Attribute definition.
253+
254+
Index attributes provide a way to define and set symbols that can be used in
255+
indexing expressions. Every attribute specifies a tuple of symbols that at
256+
compile-time are replaced by integer values.
257+
"""
258+
yaml_tag = "!LinalgAttributeDef"
259+
260+
def __init__(self, *sizes: SymbolDef):
261+
if any(not isinstance(size, SymbolDef) for size in sizes):
262+
raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got "
263+
f"{type(sizes)}")
264+
self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
265+
266+
236267
class Comprehension:
237268
"""Represents a single comprehension."""
238269

@@ -303,7 +334,7 @@ class ReduceFnType:
303334
def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
304335
"""Initializes the ReduceFn with a primitive function and dims."""
305336
if not isinstance(operator, PrimFnType):
306-
raise ValueError(f"Reduce expected a Prim operator. Got: {operator}")
337+
raise ValueError(f"Reduce expected a Prim operator but got {operator}")
307338
self.operator = operator
308339
self.reduce_dims = tuple(reduce_dims)
309340

@@ -353,7 +384,7 @@ def __init__(self, value: Any):
353384
self.value = str(
354385
_ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
355386
else:
356-
raise ValueError(f"const requires int or float. Got: {type(value)}")
387+
raise ValueError(f"const requires int or float but got {type(value)}")
357388

358389
def to_scalar_expression(self) -> ScalarExpression:
359390
return ScalarConst(self.value).expr()
@@ -475,21 +506,22 @@ def __init__(self,
475506
self.comprehensions = list() # type: List[Comprehension]
476507
self._affine_state = AffineBuildState()
477508

478-
@property
479-
def outputs(self) -> Sequence[OperandDef]:
480-
return [
481-
operand for operand in self.registered_operands.values()
482-
if operand.output
483-
]
484-
485509
def add_operand(self, name: str, operand: OperandDef):
486510
"""Registers an operand."""
487511
if name in self.registered_operands:
488512
raise ValueError(f"The operand {name} is already registered "
489513
f"to {self.registered_operands['name']}")
490-
if not operand.output and self.outputs:
491-
raise ValueError(f"The operand {name} is an input registered after "
492-
f"the output {self.outputs[-1]}")
514+
# Ensure output tensors are registered after input tensors and scalars and
515+
# attributes are registered after all other operand types.
516+
registered_kinds = [
517+
operand.kind.value for operand in self.registered_operands.values()
518+
]
519+
if registered_kinds:
520+
maximum = max(registered_kinds)
521+
if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
522+
raise ValueError(
523+
f"The operand {name} of kind {operand.kind.name} is registered "
524+
f"after an operand of kind {OperandKind(maximum).name}")
493525
operand.attach(len(self.registered_operands), name, self)
494526
self.registered_operands[name] = operand
495527

0 commit comments

Comments
 (0)