Skip to content

Commit 275fe3a

Browse files
authored
[mlir][sparse] support complex type for sparse_tensor.print (#83934)
Includes an integration test exercising the new complex-type printing support.
1 parent 3e40c96 commit 275fe3a

File tree

2 files changed

+87
-74
lines changed

2 files changed

+87
-74
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

+15-1
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,21 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
692692
rewriter.setInsertionPointToStart(forOp.getBody());
693693
auto idx = forOp.getInductionVar();
694694
auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
695-
rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
695+
if (llvm::isa<ComplexType>(val.getType())) {
696+
// Since the vector dialect does not support complex types in any op,
697+
// we split those into (real, imag) pairs here.
698+
Value real = rewriter.create<complex::ReOp>(loc, val);
699+
Value imag = rewriter.create<complex::ImOp>(loc, val);
700+
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
701+
rewriter.create<vector::PrintOp>(loc, real,
702+
vector::PrintPunctuation::Comma);
703+
rewriter.create<vector::PrintOp>(loc, imag,
704+
vector::PrintPunctuation::Close);
705+
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
706+
} else {
707+
rewriter.create<vector::PrintOp>(loc, val,
708+
vector::PrintPunctuation::Comma);
709+
}
696710
rewriter.setInsertionPointAfter(forOp);
697711
// Close bracket and end of line.
698712
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir

+72-73
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
1111
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
1212
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
13-
// DEFINE: %{run_opts} = -e entry -entry-point-result=void
13+
// DEFINE: %{run_opts} = -e main -entry-point-result=void
1414
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
1515
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
1616
//
@@ -162,31 +162,8 @@ module {
162162
return %0 : tensor<?xf64, #SparseVector>
163163
}
164164

165-
func.func @dumpc(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
166-
%c0 = arith.constant 0 : index
167-
%c1 = arith.constant 1 : index
168-
%mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
169-
scf.for %i = %c0 to %d step %c1 {
170-
%v = memref.load %mem[%i] : memref<?xcomplex<f64>>
171-
%real = complex.re %v : complex<f64>
172-
%imag = complex.im %v : complex<f64>
173-
vector.print %real : f64
174-
vector.print %imag : f64
175-
}
176-
return
177-
}
178-
179-
func.func @dumpf(%arg0: tensor<?xf64, #SparseVector>) {
180-
%c0 = arith.constant 0 : index
181-
%d0 = arith.constant 0.0 : f64
182-
%values = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
183-
%0 = vector.transfer_read %values[%c0], %d0: memref<?xf64>, vector<3xf64>
184-
vector.print %0 : vector<3xf64>
185-
return
186-
}
187-
188165
// Driver method to call and verify complex kernels.
189-
func.func @entry() {
166+
func.func @main() {
190167
// Setup sparse vectors.
191168
%v1 = arith.constant sparse<
192169
[ [0], [28], [31] ],
@@ -217,54 +194,76 @@ module {
217194
//
218195
// Verify the results.
219196
//
220-
%d3 = arith.constant 3 : index
221-
%d4 = arith.constant 4 : index
222-
// CHECK: -5.13
223-
// CHECK-NEXT: 2
224-
// CHECK-NEXT: 1
225-
// CHECK-NEXT: 0
226-
// CHECK-NEXT: 1
227-
// CHECK-NEXT: 4
228-
// CHECK-NEXT: 8
229-
// CHECK-NEXT: 6
230-
call @dumpc(%0, %d4) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
231-
// CHECK-NEXT: 3.43887
232-
// CHECK-NEXT: 1.47097
233-
// CHECK-NEXT: 3.85374
234-
// CHECK-NEXT: -27.0168
235-
// CHECK-NEXT: -193.43
236-
// CHECK-NEXT: 57.2184
237-
call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
238-
// CHECK-NEXT: 0.433635
239-
// CHECK-NEXT: 2.30609
240-
// CHECK-NEXT: 2
241-
// CHECK-NEXT: 1
242-
// CHECK-NEXT: 2.53083
243-
// CHECK-NEXT: 1.18538
244-
call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
245-
// CHECK-NEXT: 0.761594
246-
// CHECK-NEXT: 0
247-
// CHECK-NEXT: -0.964028
248-
// CHECK-NEXT: 0
249-
// CHECK-NEXT: 0.995055
250-
// CHECK-NEXT: 0
251-
call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
252-
// CHECK-NEXT: -5.13
253-
// CHECK-NEXT: 2
254-
// CHECK-NEXT: 3
255-
// CHECK-NEXT: 4
256-
// CHECK-NEXT: 5
257-
// CHECK-NEXT: 6
258-
call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
259-
// CHECK-NEXT: -2.565
260-
// CHECK-NEXT: 1
261-
// CHECK-NEXT: 1.5
262-
// CHECK-NEXT: 2
263-
// CHECK-NEXT: 2.5
264-
// CHECK-NEXT: 3
265-
call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
266-
// CHECK-NEXT: ( 5.50608, 5, 7.81025 )
267-
call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
197+
// CHECK: ---- Sparse Tensor ----
198+
// CHECK-NEXT: nse = 4
199+
// CHECK-NEXT: dim = ( 32 )
200+
// CHECK-NEXT: lvl = ( 32 )
201+
// CHECK-NEXT: pos[0] : ( 0, 4,
202+
// CHECK-NEXT: crd[0] : ( 0, 1, 28, 31,
203+
// CHECK-NEXT: values : ( ( -5.13, 2 ), ( 1, 0 ), ( 1, 4 ), ( 8, 6 ),
204+
// CHECK-NEXT: ----
205+
//
206+
// CHECK-NEXT: ---- Sparse Tensor ----
207+
// CHECK-NEXT: nse = 3
208+
// CHECK-NEXT: dim = ( 32 )
209+
// CHECK-NEXT: lvl = ( 32 )
210+
// CHECK-NEXT: pos[0] : ( 0, 3,
211+
// CHECK-NEXT: crd[0] : ( 0, 28, 31,
212+
// CHECK-NEXT: values : ( ( 3.43887, 1.47097 ), ( 3.85374, -27.0168 ), ( -193.43, 57.2184 ),
213+
// CHECK-NEXT: ----
214+
//
215+
// CHECK-NEXT: ---- Sparse Tensor ----
216+
// CHECK-NEXT: nse = 3
217+
// CHECK-NEXT: dim = ( 32 )
218+
// CHECK-NEXT: lvl = ( 32 )
219+
// CHECK-NEXT: pos[0] : ( 0, 3,
220+
// CHECK-NEXT: crd[0] : ( 0, 28, 31,
221+
// CHECK-NEXT: values : ( ( 0.433635, 2.30609 ), ( 2, 1 ), ( 2.53083, 1.18538 ),
222+
// CHECK-NEXT: ----
223+
//
224+
// CHECK-NEXT: ---- Sparse Tensor ----
225+
// CHECK-NEXT: nse = 3
226+
// CHECK-NEXT: dim = ( 32 )
227+
// CHECK-NEXT: lvl = ( 32 )
228+
// CHECK-NEXT: pos[0] : ( 0, 3,
229+
// CHECK-NEXT: crd[0] : ( 1, 28, 31,
230+
// CHECK-NEXT: values : ( ( 0.761594, 0 ), ( -0.964028, 0 ), ( 0.995055, 0 ),
231+
// CHECK-NEXT: ----
232+
//
233+
// CHECK-NEXT: ---- Sparse Tensor ----
234+
// CHECK-NEXT: nse = 3
235+
// CHECK-NEXT: dim = ( 32 )
236+
// CHECK-NEXT: lvl = ( 32 )
237+
// CHECK-NEXT: pos[0] : ( 0, 3,
238+
// CHECK-NEXT: crd[0] : ( 0, 28, 31,
239+
// CHECK-NEXT: values : ( ( -5.13, 2 ), ( 3, 4 ), ( 5, 6 ),
240+
// CHECK-NEXT: ----
241+
//
242+
// CHECK-NEXT: ---- Sparse Tensor ----
243+
// CHECK-NEXT: nse = 3
244+
// CHECK-NEXT: dim = ( 32 )
245+
// CHECK-NEXT: lvl = ( 32 )
246+
// CHECK-NEXT: pos[0] : ( 0, 3,
247+
// CHECK-NEXT: crd[0] : ( 0, 28, 31,
248+
// CHECK-NEXT: values : ( ( -2.565, 1 ), ( 1.5, 2 ), ( 2.5, 3 ),
249+
// CHECK-NEXT: ----
250+
//
251+
// CHECK-NEXT: ---- Sparse Tensor ----
252+
// CHECK-NEXT: nse = 3
253+
// CHECK-NEXT: dim = ( 32 )
254+
// CHECK-NEXT: lvl = ( 32 )
255+
// CHECK-NEXT: pos[0] : ( 0, 3,
256+
// CHECK-NEXT: crd[0] : ( 0, 28, 31,
257+
// CHECK-NEXT: values : ( 5.50608, 5, 7.81025,
258+
// CHECK-NEXT: ----
259+
//
260+
sparse_tensor.print %0 : tensor<?xcomplex<f64>, #SparseVector>
261+
sparse_tensor.print %1 : tensor<?xcomplex<f64>, #SparseVector>
262+
sparse_tensor.print %2 : tensor<?xcomplex<f64>, #SparseVector>
263+
sparse_tensor.print %3 : tensor<?xcomplex<f64>, #SparseVector>
264+
sparse_tensor.print %4 : tensor<?xcomplex<f64>, #SparseVector>
265+
sparse_tensor.print %5 : tensor<?xcomplex<f64>, #SparseVector>
266+
sparse_tensor.print %6 : tensor<?xf64, #SparseVector>
268267

269268
// Release the resources.
270269
bufferization.dealloc_tensor %sv1 : tensor<?xcomplex<f64>, #SparseVector>

0 commit comments

Comments
 (0)