@@ -59,6 +59,54 @@ module {
     return %0 : tensor<?xcomplex<f64>, #SparseVector>
   }
 
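+  // Elementwise complex square root. Because sqrt(0) = 0, the sparse
+  // compiler only has to apply the operation to the stored (nonzero) entries.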
+  func.func @complex_sqrt(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.sqrt %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
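+  // Same elementwise pattern for complex.tanh, which also maps zero to zero.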
+  func.func @complex_tanh(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.tanh %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
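+  // complex.log1p (and, once lowering support lands, complex.expm1) likewise
+  // preserves zero, so the same one-input trait #trait_op1 applies.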
+  func.func @clog1p_expm1(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.log1p %a : complex<f64>
+          // TODO(bixia): Enable this line once complex.expm1 is added to the
+          // complex-to-standard lowering.
+          // %2 = complex.expm1 %1 : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
   func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
@@ -131,9 +179,15 @@ module {
        tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
     %1 = call @csin(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %2 = call @cdiv(%sv1)
+    %2 = call @complex_sqrt(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %3 = call @complex_tanh(%sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %4 = call @clog1p_expm1(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %3 = call @cabs(%sv1)
+    %5 = call @cdiv(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %6 = call @cabs(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
 
     //
@@ -157,23 +211,47 @@ module {
     // CHECK-NEXT: -193.43
     // CHECK-NEXT: 57.2184
     call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
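+    // %2: elementwise complex square roots of %sv1, dumped as re/im pairs
+    // (for instance, sqrt(3+4i) = 2+1i, matching the 2 and 1 below).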
+    // CHECK-NEXT: 0.433635
+    // CHECK-NEXT: 2.30609
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2.53083
+    // CHECK-NEXT: 1.18538
+    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
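+    // %3: tanh keeps the purely real entries of %sv2 real (tanh(1) = 0.761594),
+    // hence the zero imaginary parts below.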
+    // CHECK-NEXT: 0.761594
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: -0.964028
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0.995055
+    // CHECK-NEXT: 0
+    call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
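+    // %4: complex.log1p results only (expm1 is still disabled, see the TODO above).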
+    // CHECK-NEXT: 1.52361
+    // CHECK-NEXT: 2.69061
+    // CHECK-NEXT: 1.73287
+    // CHECK-NEXT: 0.785398
+    // CHECK-NEXT: 2.13833
+    // CHECK-NEXT: 0.785398
+    call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: -2.565
     // CHECK-NEXT: 1
     // CHECK-NEXT: 1.5
     // CHECK-NEXT: 2
     // CHECK-NEXT: 2.5
     // CHECK-NEXT: 3
-    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
-    call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %sv2 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
-    sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %3 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %4 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %5 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %6 : tensor<?xf64, #SparseVector>
     return
   }
 }
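Note: the new kernels reference `#SparseVector` and `#trait_op1`, which are defined near the top of the test file and therefore do not appear in this diff. As a rough sketch of what those definitions look like (hedged: the exact attribute spelling depends on the sparse_tensor dialect version; the names are taken from the excerpt above):

```mlir
// Sketch only, not part of this diff. The encoding marks the vector's single
// dimension as compressed; the trait drives a one-input, one-output
// elementwise linalg.generic over index i.
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

#trait_op1 = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a (in)
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = OP a(i)"
}
```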