Skip to content

Commit 15757ea

Browse files
author
gysit
committed
[mlir][OpDSL] Add TypeFn class.
This revision introduces a the `TypeFn` class that similar to the `PrimFn` class contains an extensible set of type conversion functions. Having the same mechanism for both type conversion functions and arithmetic functions improves code consistency. Additionally, having an explicit function class and function name is a prerequisite to specify a conversion or arithmetic function via attribute. In a follow up commits, we will introduce function attributes to make OpDSL operations more generic. In particular, the goal is to handle signed and unsigned computation in one operations. Today, there is a linalg.matmul and a linalg.matmul_unsigned. The commit implements the following changes: - Introduce the class of type conversion functions `TypeFn` - Replace the hardwired cast and cast_unsigned ops by the `TypeFn` counterparts - Adapt the python and C++ code generation paths to support the new cast operations Example: ``` cast(U, A[D.m, D.k]) ``` changes to ``` TypeFn.cast(U, A[D.m, D.k]) ``` Depends On D115237 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D115239
1 parent babad7c commit 15757ea

File tree

17 files changed

+473
-386
lines changed

17 files changed

+473
-386
lines changed

mlir/docs/Dialects/Linalg/OpDSL.md

+30-15
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def matmul(A=TensorDef(T1, S.M, S.K),
5656
"""
5757
domain(D.m, D.n, D.k)
5858
implements(ContractionOpInterface)
59-
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
59+
C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])
6060
```
6161

6262
Here we have a simple type polymorphic contraction that takes arguments `A` and
@@ -159,8 +159,8 @@ def pooling_poly(
159159
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
160160
strides=IndexAttrDef(S.SH, S.SW),
161161
dilations=IndexAttrDef(S.DH, S.DW)):
162-
O[D.n, D.oh, D.ow, D.c] += \
163-
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
162+
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U,
163+
I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
164164
```
165165

166166
The pooling operation does not access the shape-only tensor `K`. Instead, the
@@ -192,10 +192,18 @@ Reduction functions can appear as the outer-most function on the RHS:
192192
* `ReduceFn.mul`
193193
* `ReduceFn.max`
194194

195+
Additionally, type conversion functions cast an operand to a target type:
196+
197+
* `TypeFn.cast(TypeVar, operand)`
198+
* `TypeFn.cast_unsigned(TypeVar, operand)`
199+
200+
As the integer types are signless, signedness is implement by different
201+
functions that treat integers as signed (`TypeFn.cast`) or unsigned
202+
(`TypeFn.cast_unsigned`) values.
203+
195204
There are also special forms:
196205

197-
* `cast(TypeVar, operand)` casts the `operand` to the target type `TypeVar`.
198-
* `const(TypeVar, value)` returns a constant value of type `TypeVar`.
206+
* `const(value)` returns a constant value.
199207
* `index(dim)` returns the iteration index in the given dimension `dim`.
200208

201209
## Types
@@ -206,18 +214,25 @@ output types of constructed ops. An exception are predefined types such as
206214
computations with a type that is independent of the input and output types. For
207215
example, parts of floating point computation may require double precision
208216
arithmetic despite all inputs and outputs being single precision values.
209-
Assignment expressions with no `cast` calls will generally require uniform types
210-
throughout and will fail to verify if violated. The presence of a `cast` allows
211-
for a limited form of numeric type conversion between element types that can be
212-
derived from inputs and outputs (and in the future, attributes). `cast` calls
213-
with a `TypeVar` first argument are emitted as `symbolic_cast` primitives in the
214-
YAML definition.
217+
Assignment expressions with no `TypeFn.cast` calls will generally require
218+
uniform types throughout and will fail to verify if violated. The presence of a
219+
`TypeFn.cast` or `TypeFn.cast_unsigned` allows for a limited form of numeric
220+
type conversion between element types that can be derived from inputs and
221+
outputs (and in the future, attributes). `TypeFn.cast` calls with a `TypeVar`
222+
first argument are emitted as `type_fn` primitives in the YAML definition.
215223

216224
Casting will perform `int<->float` and `index->int` type conversions and will
217-
perform any necessary extension or truncation within type family. Note that
218-
presently, any integer type is assumed to be signed for the purpose of
219-
determining how to extend or truncate. Supporting unsigned integer types is left
220-
for future work.
225+
perform any necessary extension or truncation within the type family. The
226+
integer types themselves are signless and signedness is implemented by
227+
functions/operations. The `TypeFn.cast` function treats all integers as signed,
228+
while `TypeFn.cast_unsigned` treats them as unsigned.
229+
230+
The following examples illustrate the lowering of signed and unsigned functions:
231+
232+
* cast(I32 -> I64) -> `arith.ExtSIOp`
233+
* cast(F32 -> I32) -> `arith.FPToSIOp`
234+
* cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
235+
* cast_unsigned(F32 -> I32) -> `arith.FPToUIOp`
221236

222237
Not all functions are applicable for all numeric types, and on mismatch, op
223238
verification will fail.

0 commit comments

Comments
 (0)