Skip to content

Commit 8d615a2

Browse files
author
Peiming Liu
committed
[mlir][sparse] fix crash on sparse_tensor.foreach operation on tensors with complex<T> elements.
Reviewed By: aartbik, bixia Differential Revision: https://reviews.llvm.org/D138223
1 parent 48dbf35 commit 8d615a2

File tree

4 files changed

+98
-38
lines changed

4 files changed

+98
-38
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -1024,3 +1024,36 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
10241024
;
10251025
return op;
10261026
}
1027+
1028+
void sparse_tensor::foreachInSparseConstant(
1029+
Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
1030+
function_ref<void(ArrayRef<Value>, Value)> callback) {
1031+
int64_t rank = attr.getType().getRank();
1032+
// Foreach on constant.
1033+
DenseElementsAttr indicesAttr = attr.getIndices();
1034+
DenseElementsAttr valuesAttr = attr.getValues();
1035+
1036+
SmallVector<Value> coords;
1037+
for (int i = 0, e = valuesAttr.size(); i < e; i++) {
1038+
coords.clear();
1039+
for (int j = 0; j < rank; j++) {
1040+
auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
1041+
auto coord =
1042+
rewriter.create<arith::ConstantIndexOp>(loc, coordAttr.getInt());
1043+
// Remaps coordinates.
1044+
coords.push_back(coord);
1045+
}
1046+
Value val;
1047+
if (attr.getElementType().isa<ComplexType>()) {
1048+
auto valAttr = valuesAttr.getValues<ArrayAttr>()[i];
1049+
val = rewriter.create<complex::ConstantOp>(loc, attr.getElementType(),
1050+
valAttr);
1051+
} else {
1052+
auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
1053+
// Remaps value.
1054+
val = rewriter.create<arith::ConstantOp>(loc, valAttr);
1055+
}
1056+
assert(val);
1057+
callback(coords, val);
1058+
}
1059+
}

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

+23-3
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,26 @@ void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
183183
/// Scans to top of generated loop.
184184
Operation *getTop(Operation *op);
185185

186+
/// Iterate over a sparse constant, generates constantOp for value and indices.
187+
/// E.g.,
188+
/// sparse<[ [0], [28], [31] ],
189+
/// [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] >
190+
/// =>
191+
/// %c1 = arith.constant 0
192+
/// %v1 = complex.constant (5.13, 2.0)
193+
/// callback({%c1}, %v1)
194+
///
195+
/// %c2 = arith.constant 28
196+
/// %v2 = complex.constant (3.0, 4.0)
197+
/// callback({%c2}, %v2)
198+
///
199+
/// %c3 = arith.constant 31
200+
/// %v3 = complex.constant (5.0, 6.0)
201+
/// callback({%c3}, %v3)
202+
void foreachInSparseConstant(
203+
Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
204+
function_ref<void(ArrayRef<Value>, Value)> callback);
205+
186206
//===----------------------------------------------------------------------===//
187207
// Inlined constant generators.
188208
//
@@ -197,9 +217,9 @@ Operation *getTop(Operation *op);
197217
//===----------------------------------------------------------------------===//
198218

199219
/// Generates a 0-valued constant of the given type. In addition to
200-
/// the scalar types (`ComplexType`, ``FloatType`, `IndexType`, `IntegerType`),
201-
/// this also works for `RankedTensorType` and `VectorType` (for which it
202-
/// generates a constant `DenseElementsAttr` of zeros).
220+
/// the scalar types (`ComplexType`, ``FloatType`, `IndexType`,
221+
/// `IntegerType`), this also works for `RankedTensorType` and `VectorType`
222+
/// (for which it generates a constant `DenseElementsAttr` of zeros).
203223
inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
204224
if (auto ctp = tp.dyn_cast<ComplexType>()) {
205225
auto zeroe = builder.getZeroAttr(ctp.getElementType());

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

+30-30
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,35 @@ static void getDynamicSizes(RankedTensorType tp,
170170
}
171171
}
172172

173+
static LogicalResult genForeachOnSparseConstant(ForeachOp op,
174+
RewriterBase &rewriter,
175+
SparseElementsAttr attr) {
176+
auto loc = op.getLoc();
177+
SmallVector<Value> reduc = op.getInitArgs();
178+
179+
// Foreach on constant.
180+
foreachInSparseConstant(
181+
loc, rewriter, attr,
182+
[&reduc, &rewriter, op](ArrayRef<Value> coords, Value v) mutable {
183+
SmallVector<Value> args;
184+
args.append(coords.begin(), coords.end());
185+
args.push_back(v);
186+
args.append(reduc);
187+
// Clones the foreach op to get a copy of the loop body.
188+
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
189+
assert(args.size() == cloned.getBody()->getNumArguments());
190+
Operation *yield = cloned.getBody()->getTerminator();
191+
rewriter.mergeBlockBefore(cloned.getBody(), op, args);
192+
// clean up
193+
rewriter.eraseOp(cloned);
194+
reduc = yield->getOperands();
195+
rewriter.eraseOp(yield);
196+
});
197+
198+
rewriter.replaceOp(op, reduc);
199+
return success();
200+
}
201+
173202
//===---------------------------------------------------------------------===//
174203
// The actual sparse tensor rewriting rules.
175204
//===---------------------------------------------------------------------===//
@@ -752,36 +781,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
752781
// rule.
753782
if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
754783
if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
755-
// Foreach on constant.
756-
DenseElementsAttr indicesAttr = attr.getIndices();
757-
DenseElementsAttr valuesAttr = attr.getValues();
758-
759-
SmallVector<Value> args;
760-
for (int i = 0, e = valuesAttr.size(); i < e; i++) {
761-
auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
762-
for (int j = 0; j < rank; j++) {
763-
auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
764-
auto coord = rewriter.create<arith::ConstantIndexOp>(
765-
loc, coordAttr.getInt());
766-
// Remaps coordinates.
767-
args.push_back(coord);
768-
}
769-
// Remaps value.
770-
auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
771-
args.push_back(val);
772-
// Remaps iteration args.
773-
args.append(reduc);
774-
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
775-
Operation *yield = cloned.getBody()->getTerminator();
776-
rewriter.mergeBlockBefore(cloned.getBody(), op, args);
777-
// clean up
778-
args.clear();
779-
rewriter.eraseOp(cloned);
780-
reduc = yield->getOperands();
781-
rewriter.eraseOp(yield);
782-
}
783-
rewriter.replaceOp(op, reduc);
784-
return success();
784+
return genForeachOnSparseConstant(op, rewriter, attr);
785785
}
786786
}
787787

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir

+12-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
// RUN: mlir-opt %s --sparse-compiler | \
2-
// RUN: mlir-cpu-runner \
3-
// RUN: -e entry -entry-point-result=void \
4-
// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
5-
// RUN: FileCheck %s
1+
// DEFINE: %{option} = enable-runtime-library=true
2+
// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
3+
// DEFINE: mlir-cpu-runner \
4+
// DEFINE: -e entry -entry-point-result=void \
5+
// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
6+
// DEFINE: FileCheck %s
7+
//
8+
// RUN: %{command}
9+
//
10+
// Do the same run, but now with direct IR generation.
11+
// REDEFINE: %{option} = enable-runtime-library=false
12+
// RUN: %{command}
613

714
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
815

0 commit comments

Comments
 (0)