-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
Copy pathsgbtrf.c
240 lines (202 loc) · 7.24 KB
/
sgbtrf.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#include "relapack.h"
#include "stdlib.h"
/* Forward declaration of the recursive kernel defined further down in this file.
 * Parameter order, as given by the definition:
 *   m, n, kl, ku, Ab, ldAb, ipiv, Workl, ldWorkl, Worku, ldWorku, info */
static void RELAPACK_sgbtrf_rec(const blasint *, const blasint *, const blasint *,
const blasint *, float *, const blasint *, blasint *, float *, const blasint *, float *,
const blasint *, blasint *);
/** SGBTRF computes an LU factorization of a real m-by-n band matrix A using partial pivoting with row interchanges.
 *
 * This routine is functionally equivalent to LAPACK's sgbtrf.
 * For details on its interface, see
 * http://www.netlib.org/lapack/explore-html/d5/d72/sgbtrf_8f.html
 *
 * Invalid arguments set *info to the negated argument position and are
 * reported through xerbla. If m or n is 0 the routine returns immediately.
 * If ldAb is 1, or if the temporary work space cannot be allocated, the
 * factorization is delegated to the unblocked LAPACK routine sgbtf2.
 * */
void RELAPACK_sgbtrf(
    const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
    float *Ab, const blasint *ldAb, blasint *ipiv,
    blasint *info
) {
    // Check arguments
    *info = 0;
    if (*m < 0)
        *info = -1;
    else if (*n < 0)
        *info = -2;
    else if (*kl < 0)
        *info = -3;
    else if (*ku < 0)
        *info = -4;
    else if (*ldAb < 2 * *kl + *ku + 1)
        *info = -6;
    if (*info) {
        const blasint minfo = -*info;
        LAPACK(xerbla)("SGBTRF", &minfo, strlen("SGBTRF"));
        return;
    }

    // Quick return for empty matrices
    if (*m == 0 || *n == 0) return;

    // With ldAb == 1 the unskewed leading dimension (ldAb - 1) would be 0,
    // so use the unblocked routine directly
    if (*ldAb == 1) {
        LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
        return;
    }

    // Constant
    const float ZERO[] = { 0. };

    // Result upper band width
    const blasint kv = *ku + *kl;

    // Unskew A: address the band storage with leading dimension ldAb - 1,
    // starting kv entries into Ab
    const blasint ldA[] = { *ldAb - 1 };
    float *const A = Ab + kv;

    // Zero upper diagonal fill-in elements
    blasint i, j;
    for (j = 0; j < *n; j++) {
        float *const A_j = A + *ldA * j;
        for (i = MAX(0, j - kv); i < j - *ku; i++)
            A_j[i] = 0.;
    }

    // Work space dimensions
    const blasint n1 = SREC_SPLIT(*n);
    const blasint mWorkl = abs( (kv > n1) ? MAX(1, *m - *kl) : kv );
    const blasint nWorkl = abs( (kv > n1) ? n1 : kv );
    const blasint mWorku = abs( (*kl > n1) ? n1 : *kl );
    const blasint nWorku = abs( (*kl > n1) ? MAX(0, *n - *kl) : *kl );

    // Allocate work space; byte counts are computed in size_t so the product
    // of two dimensions cannot overflow blasint
    const size_t sizeWorkl = (size_t)mWorkl * (size_t)nWorkl * sizeof(float);
    const size_t sizeWorku = (size_t)mWorku * (size_t)nWorku * sizeof(float);
    float *Workl = malloc(sizeWorkl);
    float *Worku = malloc(sizeWorku);

    // malloc may legitimately return NULL for a zero-size request, so only a
    // NULL result for a non-empty buffer counts as failure. In that case fall
    // back to the unblocked routine, which needs no work space.
    if ((sizeWorkl != 0 && Workl == NULL) || (sizeWorku != 0 && Worku == NULL)) {
        free(Workl);
        free(Worku);
        LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
        return;
    }

    // Clear the work space
    LAPACK(slaset)("L", &mWorkl, &nWorkl, ZERO, ZERO, Workl, &mWorkl);
    LAPACK(slaset)("U", &mWorku, &nWorku, ZERO, ZERO, Worku, &mWorku);

    // Recursive kernel
    RELAPACK_sgbtrf_rec(m, n, kl, ku, Ab, ldAb, ipiv, Workl, &mWorkl, Worku, &mWorku, info);

    // Free work space
    free(Workl);
    free(Worku);
}
/** sgbtrf's recursive compute kernel
 *
 * Factors the leading n1 columns recursively, applies the resulting pivots and
 * updates to the remaining columns, then recurses on the trailing block.
 * Problems with n <= MAX(CROSSOVER_SGBTRF, 1), n > kl or ldAb == 1 are handed
 * to the unblocked LAPACK routine sgbtf2.
 *
 * m, n, kl, ku, Ab, ldAb, ipiv and info are used as in sgbtrf / sgbtf2.
 * Workl (leading dimension *ldWorkl) and Worku (leading dimension *ldWorku)
 * are caller-provided scratch buffers: Workl receives a copy of the upper
 * triangle of A_BLb, Worku a copy of the lower triangle of A_TRr.
 *
 * On a nonzero *info from the left recursion the routine returns early; a
 * nonzero *info from the trailing recursion is offset by n1.
 * */
static void RELAPACK_sgbtrf_rec(
    const blasint *m, const blasint *n, const blasint *kl, const blasint *ku,
    float *Ab, const blasint *ldAb, blasint *ipiv,
    float *Workl, const blasint *ldWorkl, float *Worku, const blasint *ldWorku,
    blasint *info
) {
    if (*m == 0 || *n == 0) return;

    if ( *n <= MAX(CROSSOVER_SGBTRF, 1) || *n > *kl || *ldAb == 1) {
        // Unblocked
        LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
        return;
    }

    // Constants
    const float ONE[]    = { 1. };
    const float MONE[]   = { -1. };
    const blasint iONE[] = { 1 };

    // Loop iterators
    blasint i, j;

    // Output upper band width
    const blasint kv = *ku + *kl;

    // Unskew A: address the band storage with leading dimension ldAb - 1,
    // starting kv entries into Ab
    const blasint ldA[] = { *ldAb - 1 };
    float *const A = Ab + kv;

    // Splitting
    const blasint n1  = MIN(SREC_SPLIT(*n), *kl);
    const blasint n2  = *n - n1;
    const blasint m1  = MIN(n1, *m);
    const blasint m2  = *m - m1;
    const blasint mn1 = MIN(m1, n1);
    const blasint mn2 = MIN(m2, n2);

    // Ab_L *
    //      Ab_BR
    float *const Ab_L  = Ab;
    float *const Ab_BR = Ab + *ldAb * n1;

    // A_L A_R
    float *const A_L = A;
    float *const A_R = A + *ldA * n1;

    // A_TL A_TR
    // A_BL A_BR
    float *const A_TL = A;
    float *const A_TR = A + *ldA * n1;
    float *const A_BL = A             + m1;
    float *const A_BR = A + *ldA * n1 + m1;

    // ipiv_T
    // ipiv_B
    blasint *const ipiv_T = ipiv;
    blasint *const ipiv_B = ipiv + n1;

    // Banded splitting
    const blasint n21 = MIN(n2, kv - n1);
    const blasint n22 = MIN(n2 - n21, n1);
    const blasint m21 = MIN(m2, *kl - m1);
    const blasint m22 = MIN(m2 - m21, m1);

    //   n1 n21  n22
    // m *  A_Rl A_Rr
    float *const A_Rl = A_R;
    float *const A_Rr = A_R + *ldA * n21;

    //     n1    n21    n22
    // m1  *     A_TRl  A_TRr
    // m21 A_BLt A_BRtl A_BRtr
    // m22 A_BLb A_BRbl A_BRbr
    float *const A_TRl  = A_TR;
    float *const A_TRr  = A_TR + *ldA * n21;
    float *const A_BLt  = A_BL;
    float *const A_BLb  = A_BL               + m21;
    float *const A_BRtl = A_BR;
    float *const A_BRtr = A_BR + *ldA * n21;
    float *const A_BRbl = A_BR               + m21;
    float *const A_BRbr = A_BR + *ldA * n21 + m21;

    // recursion(Ab_L, ipiv_T)
    RELAPACK_sgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
    if (*info) return;

    // Workl = A_BLb
    LAPACK(slacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);

    // partially redo swaps in A_L
    // (rows at index >= kl are taken from the copy in Workl)
    for (i = 0; i < mn1; i++) {
        const blasint ip = ipiv_T[i] - 1;
        if (ip != i) {
            if (ip < *kl)
                BLAS(sswap)(&i, A_L + i, ldA, A_L + ip, ldA);
            else
                BLAS(sswap)(&i, A_L + i, ldA, Workl + ip - *kl, ldWorkl);
        }
    }

    // apply pivots to A_Rl
    LAPACK(slaswp)(&n21, A_Rl, ldA, iONE, &mn1, ipiv_T, iONE);

    // apply pivots to A_Rr columnwise
    for (j = 0; j < n22; j++) {
        float *const A_Rrj = A_Rr + *ldA * j;
        for (i = j; i < mn1; i++) {
            const blasint ip = ipiv_T[i] - 1;
            if (ip != i) {
                // swap entries i and ip within column j; both the read and
                // the write of row ip must use the column pointer A_Rrj
                const float tmp = A_Rrj[i];
                A_Rrj[i]  = A_Rrj[ip];
                A_Rrj[ip] = tmp;
            }
        }
    }

    // A_TRl = A_TL \ A_TRl
    BLAS(strsm)("L", "L", "N", "U", &m1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
    // Worku = A_TRr
    LAPACK(slacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
    // Worku = A_TL \ Worku
    BLAS(strsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
    // A_TRr = Worku
    LAPACK(slacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
    // A_BRtl = A_BRtl - A_BLt * A_TRl
    BLAS(sgemm)("N", "N", &m21, &n21, &n1, MONE, A_BLt, ldA, A_TRl, ldA, ONE, A_BRtl, ldA);
    // A_BRbl = A_BRbl - Workl * A_TRl
    BLAS(sgemm)("N", "N", &m22, &n21, &n1, MONE, Workl, ldWorkl, A_TRl, ldA, ONE, A_BRbl, ldA);
    // A_BRtr = A_BRtr - A_BLt * Worku
    BLAS(sgemm)("N", "N", &m21, &n22, &n1, MONE, A_BLt, ldA, Worku, ldWorku, ONE, A_BRtr, ldA);
    // A_BRbr = A_BRbr - Workl * Worku
    BLAS(sgemm)("N", "N", &m22, &n22, &n1, MONE, Workl, ldWorkl, Worku, ldWorku, ONE, A_BRbr, ldA);

    // partially undo swaps in A_L (reverse order of the redo loop above)
    for (i = mn1 - 1; i >= 0; i--) {
        const blasint ip = ipiv_T[i] - 1;
        if (ip != i) {
            if (ip < *kl)
                BLAS(sswap)(&i, A_L + i, ldA, A_L + ip, ldA);
            else
                BLAS(sswap)(&i, A_L + i, ldA, Workl + ip - *kl, ldWorkl);
        }
    }

    // recursion(Ab_BR, ipiv_B)
    // NOTE(review): if n1 were 0 (e.g. kl == 0), then n2 == n and m2 == m and
    // this call would repeat with the same problem size forever. The n > kl
    // check at the top sends kl == 0 to sgbtf2 instead; confirm that
    // SREC_SPLIT(n) is nonzero for every n that reaches this point.
    RELAPACK_sgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
    if (*info)
        *info += n1;

    // shift pivots so they index rows of the full matrix rather than the
    // trailing block
    for (i = 0; i < mn2; i++)
        ipiv_B[i] += n1;
}