"""

from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
+from enum import Enum

from mlir import ir as _ir

@@ -133,18 +134,31 @@ def __repr__(self):
    return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"


+class OperandKind(Enum):
+  InputTensor = 0
+  Scalar = 1
+  OutputTensor = 2
+  Attribute = 3
+
+
class OperandDef:
-  """Definition of a Tensor or Scalar operand passed to an operation."""
+  """Definition of an operand passed to an operation.
+
+  Keep the meta information of Tensor, Scalar, and Attribute operands and
+  provide the shared registration functionality.
+  """

-  def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef],
-               scalar: bool, output: bool):
+  def __init__(self,
+               kind: OperandKind,
+               type_var: TypeVar,
+               size_exprs: Optional[Sequence[AffineExprDef]] = None):
    if not isinstance(type_var, TypeVar):
-      raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}")
+      raise ValueError(
+          f"OperandDef requires a TypeVar but got {repr(type_var)}")
    self.owner = None  # type: Optional["LinalgOpDef"]
    self.type_var = type_var
-    self.shape = shape
-    self.scalar = scalar
-    self.output = output
+    self.size_exprs = size_exprs
+    self.kind = kind
    self.name = None  # type: Optional[str]
    self.registered_index = -1  # type: int

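For orientation (not part of the patch), a minimal sketch of the refactored, kind-based constructor; the type variable `T` and the symbols `S.M`, `S.N`, `S.K` are assumed to be the usual OpDSL shorthands defined elsewhere in the lang package.

# Sketch only: T and S.* are assumed shorthands, not defined in this hunk.
lhs = OperandDef(OperandKind.InputTensor, T, size_exprs=(S.M, S.K))
acc = OperandDef(OperandKind.Scalar, T)
out = OperandDef(OperandKind.OutputTensor, T, size_exprs=(S.M, S.N))
print(out)  # None:OperandDef(kind=OutputTensor, type=..., size_exprs=(...)) until attached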
@@ -159,25 +173,26 @@ def __hash__(self):
    return hash(id(self))

  def __repr__(self):
-    output = "OUTPUT " if self.output else ""
-    scalar = "SCALAR " if self.scalar else ""
-    return (f"{self.name}:OperandDef({output}{scalar}"
-            f"{repr(self.type_var)}, shape={self.shape})")
+    return (f"{self.name}:OperandDef(kind={self.kind.name}, "
+            f"type={repr(self.type_var)}, size_exprs={self.size_exprs})")


class TensorDef:
  """Tensor operand definition.

  Tensor operands are indexed using the associated indexing_map when forwarded
  to the body of the structured op. A unique name identifies the tensor operands
-  and an index determines their position in the operation's parameter list.
+  and an index determines their position in the operation's parameter list. A
+  tensor definition takes a type, a shape, and an optional flag to mark output
+  tensors.
  """

  def __init__(self,
               type_var: TypeVar,
               *shape: AffineExprDef,
               output: bool = False):
-    self.operand_def = OperandDef(type_var, shape, False, output)
+    kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
+    self.operand_def = OperandDef(kind, type_var, size_exprs=shape)

  def __getitem__(self, dims) -> TensorUse:
    assert self.operand_def.owner, "TensorDef is not attached to an op"
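A usage sketch under the same assumed shorthands: the `output` flag now simply selects the operand kind.

# Sketch only: T and S.* are assumed shorthands.
A = TensorDef(T, S.M, S.K)               # registers as OperandKind.InputTensor
C = TensorDef(T, S.M, S.N, output=True)  # registers as OperandKind.OutputTensor
assert C.operand_def.kind == OperandKind.OutputTensor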
@@ -221,7 +236,7 @@ class ScalarDef(TensorExpression):
  """

  def __init__(self, type_var: TypeVar):
-    self.operand_def = OperandDef(type_var, (), True, False)
+    self.operand_def = OperandDef(OperandKind.Scalar, type_var)

  @property
  def scalar_name(self) -> str:
@@ -233,6 +248,22 @@ def to_scalar_expression(self) -> ScalarExpression:
    return ScalarArg(self.scalar_name).expr()


+class AttributeDef:
+  """Index Attribute definition.
+
+  Index attributes provide a way to define and set symbols that can be used in
+  indexing expressions. Every attribute specifies a tuple of symbols that at
+  compile-time are replaced by integer values.
+  """
+  yaml_tag = "!LinalgAttributeDef"
+
+  def __init__(self, *sizes: SymbolDef):
+    if any(not isinstance(size, SymbolDef) for size in sizes):
+      raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got "
+                       f"{type(sizes)}")
+    self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
+
+
class Comprehension:
  """Represents a single comprehension."""

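A sketch of declaring an index attribute, for example convolution strides; `S.SH` and `S.SW` are assumed symbol shorthands and the enclosing op definition is omitted.

# Sketch only: strides as an index attribute; its symbols are replaced by
# concrete i64 values when the op is instantiated.
strides = AttributeDef(S.SH, S.SW)
assert strides.operand_def.kind == OperandKind.Attribute
assert strides.operand_def.type_var is I64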
@@ -303,7 +334,7 @@ class ReduceFnType:
  def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
    """Initializes the ReduceFn with a primitive function and dims."""
    if not isinstance(operator, PrimFnType):
-      raise ValueError(f"Reduce expected a Prim operator. Got: {operator}")
+      raise ValueError(f"Reduce expected a Prim operator but got {operator}")
    self.operator = operator
    self.reduce_dims = tuple(reduce_dims)

@@ -353,7 +384,7 @@ def __init__(self, value: Any):
      self.value = str(
          _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
    else:
-      raise ValueError(f"const requires int or float. Got: {type(value)}")
+      raise ValueError(f"const requires int or float but got {type(value)}")

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarConst(self.value).expr()
@@ -475,21 +506,22 @@ def __init__(self,
    self.comprehensions = list()  # type: List[Comprehension]
    self._affine_state = AffineBuildState()

-  @property
-  def outputs(self) -> Sequence[OperandDef]:
-    return [
-        operand for operand in self.registered_operands.values()
-        if operand.output
-    ]
-
  def add_operand(self, name: str, operand: OperandDef):
    """Registers an operand."""
    if name in self.registered_operands:
      raise ValueError(f"The operand {name} is already registered "
                       f"to {self.registered_operands['name']}")
-    if not operand.output and self.outputs:
-      raise ValueError(f"The operand {name} is an input registered after "
-                       f"the output {self.outputs[-1]}")
+    # Ensure output tensors are registered after input tensors and scalars,
+    # and attributes are registered after all other operand types.
+    registered_kinds = [
+        operand.kind.value for operand in self.registered_operands.values()
+    ]
+    if registered_kinds:
+      maximum = max(registered_kinds)
+      if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
+        raise ValueError(
+            f"The operand {name} of kind {operand.kind.name} is registered "
+            f"after an operand of kind {OperandKind(maximum).name}")
    operand.attach(len(self.registered_operands), name, self)
    self.registered_operands[name] = operand

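A behavior sketch of the generalized ordering check; the LinalgOpDef constructor arguments (here just a name) are assumed.

# Sketch only: inputs and scalars may interleave, but any operand whose kind
# sorts below an already-registered output tensor or attribute is rejected.
op = LinalgOpDef("matmul")
op.add_operand("A", OperandDef(OperandKind.InputTensor, T, size_exprs=(S.M, S.K)))
op.add_operand("C", OperandDef(OperandKind.OutputTensor, T, size_exprs=(S.M, S.N)))
op.add_operand("B", OperandDef(OperandKind.InputTensor, T, size_exprs=(S.K, S.N)))
# -> ValueError: The operand B of kind InputTensor is registered after an
#    operand of kind OutputTensor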