
Commit a14057d

[mlir][sparse] Add more complex operations.
Support complex operations sqrt, expm1, and tanh. Add tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126393
1 parent 338e76f commit a14057d
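
For illustration, the kernels this change lets the sparsifier handle apply a unary complex op inside a linalg.generic over a sparse operand. Below is a minimal sketch using the newly recognized complex.expm1 (the one new op that, per the test's TODO, could not yet be run end-to-end because complex-to-standard lowering lacked expm1 at the time). The #SparseVector and #trait_op1 definitions are assumptions mirroring the ones in the test file, which this diff does not show:

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

#trait_op1 = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a (in)
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = OP a(i)"
}

func.func @complex_expm1(%arga: tensor<?xcomplex<f64>, #SparseVector>)
                        -> tensor<?xcomplex<f64>, #SparseVector> {
  %c0 = arith.constant 0 : index
  %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
  %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
  %0 = linalg.generic #trait_op1
     ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
      outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
      ^bb(%a: complex<f64>, %x: complex<f64>):
        // The sparsifier tags this op as kExpm1C and rebuilds it in
        // the generated loops over stored entries.
        %1 = complex.expm1 %a : complex<f64>
        linalg.yield %1 : complex<f64>
  } -> tensor<?xcomplex<f64>, #SparseVector>
  return %0 : tensor<?xcomplex<f64>, #SparseVector>
}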

3 files changed: +113 -5 lines changed
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+3)

@@ -35,12 +35,15 @@ enum Kind {
   kCeilF,
   kFloorF,
   kSqrtF,
+  kSqrtC,
   kExpm1F,
+  kExpm1C,
   kLog1pF,
   kLog1pC,
   kSinF,
   kSinC,
   kTanhF,
+  kTanhC,
   kNegF,
   kNegC,
   kNegI,
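
For reference, the new C-suffixed kinds tag the complex-dialect counterparts of the math-dialect ops handled by the existing F-suffixed kinds. In IR form (shown here for complex<f64>, as in the test):

%0 = complex.sqrt %a : complex<f64>
%1 = complex.expm1 %a : complex<f64>
%2 = complex.tanh %a : complex<f64>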

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+27)

@@ -41,12 +41,15 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -284,12 +287,15 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -360,8 +366,10 @@ static const char *kindToOpSymbol(Kind kind) {
   case kFloorF:
     return "floor";
   case kSqrtF:
+  case kSqrtC:
     return "sqrt";
   case kExpm1F:
+  case kExpm1C:
     return "expm1";
   case kLog1pF:
   case kLog1pC:
@@ -370,6 +378,7 @@ static const char *kindToOpSymbol(Kind kind) {
   case kSinC:
     return "sin";
   case kTanhF:
+  case kTanhC:
     return "tanh";
   case kNegF:
   case kNegC:
@@ -449,10 +458,13 @@ void Merger::dumpExp(unsigned e) const {
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kSinF:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegI:
   case kTruncF:
@@ -555,12 +567,15 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
   case kCRe:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -785,8 +800,12 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(kFloorF, e);
       if (isa<math::SqrtOp>(def))
         return addExp(kSqrtF, e);
+      if (isa<complex::SqrtOp>(def))
+        return addExp(kSqrtC, e);
       if (isa<math::ExpM1Op>(def))
         return addExp(kExpm1F, e);
+      if (isa<complex::Expm1Op>(def))
+        return addExp(kExpm1C, e);
       if (isa<math::Log1pOp>(def))
         return addExp(kLog1pF, e);
       if (isa<complex::Log1pOp>(def))
@@ -797,6 +816,8 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
         return addExp(kSinC, e);
       if (isa<math::TanhOp>(def))
         return addExp(kTanhF, e);
+      if (isa<complex::TanhOp>(def))
+        return addExp(kTanhC, e);
       if (isa<arith::NegFOp>(def))
         return addExp(kNegF, e); // no negi in std
       if (isa<complex::NegOp>(def))
@@ -952,8 +973,12 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<math::FloorOp>(loc, v0);
   case kSqrtF:
     return rewriter.create<math::SqrtOp>(loc, v0);
+  case kSqrtC:
+    return rewriter.create<complex::SqrtOp>(loc, v0);
   case kExpm1F:
     return rewriter.create<math::ExpM1Op>(loc, v0);
+  case kExpm1C:
+    return rewriter.create<complex::Expm1Op>(loc, v0);
   case kLog1pF:
     return rewriter.create<math::Log1pOp>(loc, v0);
   case kLog1pC:
@@ -964,6 +989,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
     return rewriter.create<complex::SinOp>(loc, v0);
   case kTanhF:
     return rewriter.create<math::TanhOp>(loc, v0);
+  case kTanhC:
+    return rewriter.create<complex::TanhOp>(loc, v0);
   case kNegF:
     return rewriter.create<arith::NegFOp>(loc, v0);
   case kNegC:

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir (+83 -5)

@@ -59,6 +59,54 @@ module {
     return %0 : tensor<?xcomplex<f64>, #SparseVector>
   }
 
+  func.func @complex_sqrt(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                         -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.sqrt %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @complex_tanh(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                         -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.tanh %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @clog1p_expm1(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                         -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.log1p %a : complex<f64>
+          // TODO(bixia): Enable this line after adding complex.expm1 to
+          // complex to standard lowering.
+          // %2 = complex.expm1 %1 : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
   func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
@@ -131,9 +179,15 @@ module {
                      tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
     %1 = call @csin(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %2 = call @cdiv(%sv1)
+    %2 = call @complex_sqrt(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %3 = call @complex_tanh(%sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %4 = call @clog1p_expm1(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %3 = call @cabs(%sv1)
+    %5 = call @cdiv(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %6 = call @cabs(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
 
     //
@@ -157,23 +211,47 @@ module {
     // CHECK-NEXT: -193.43
     // CHECK-NEXT: 57.2184
     call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 0.433635
+    // CHECK-NEXT: 2.30609
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2.53083
+    // CHECK-NEXT: 1.18538
+    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 0.761594
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: -0.964028
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0.995055
+    // CHECK-NEXT: 0
+    call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 1.52361
+    // CHECK-NEXT: 2.69061
+    // CHECK-NEXT: 1.73287
+    // CHECK-NEXT: 0.785398
+    // CHECK-NEXT: 2.13833
+    // CHECK-NEXT: 0.785398
+    call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: -2.565
     // CHECK-NEXT: 1
     // CHECK-NEXT: 1.5
     // CHECK-NEXT: 2
     // CHECK-NEXT: 2.5
     // CHECK-NEXT: 3
-    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
-    call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %sv2 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
-    sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %3 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %4 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %5 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %6 : tensor<?xf64, #SparseVector>
     return
   }
 }
