Skip to content

Commit f27ba5e

Browse files
committed
fix bugs in aarch64 sbgemv_n kernel
1 parent e9fbe0a commit f27ba5e

File tree

1 file changed

+34
-49
lines changed

1 file changed

+34
-49
lines changed

kernel/arm64/sbgemv_n_neon.c

+34-49
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) {
6969
x += 4;
7070
}
7171

72-
if (rest_n & 3) {
73-
x[0] *= beta;
74-
if ((rest_n & 3) > 1)
75-
x[1] *= beta;
76-
if ((rest_n & 3) > 2)
77-
x[2] *= beta;
72+
for (BLASLONG i = 0; i < (rest_n & 3); i ++) {
73+
x[i] *= beta;
7874
}
7975
}
8076
return;
@@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
8884

8985
bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7;
9086
bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7;
87+
9188
bfloat16x8_t x_vec;
89+
bfloat16x4_t x_vecx4;
90+
9291
float32x4_t y1_vec, y2_vec;
9392
float32x4_t fp32_low, fp32_high;
9493

@@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
106105

107106
if (incx == 1 && incy == 1) {
108107
if (beta != 1) {
109-
beta_op(y, n, beta);
108+
beta_op(y, m, beta);
110109
}
111110

112111
for (i = 0; i < n / 8; i++) {
@@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
290289

291290
a_ptr += 4 * lda;
292291

293-
bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr);
292+
x_vecx4 = vld1_bf16(x_ptr);
294293
if (alpha != 1) {
295-
x_vec = vcombine_bf16(x_vecx4, bf16_zero);
296-
fp32_low = vreinterpretq_f32_u16(
297-
vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
298-
vreinterpretq_u16_bf16(x_vec)));
294+
fp32_low = vcvt_f32_bf16(x_vecx4);
299295
fp32_low = vmulq_n_f32(fp32_low, alpha);
300296
x_vecx4 = vcvt_bf16_f32(fp32_low);
301297
}
@@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
348344

349345
y1_vec = vld1q_f32(y_ptr);
350346

351-
a0 = vcombine_bf16(a0x4, bf16_zero);
352-
a1 = vcombine_bf16(a1x4, bf16_zero);
353-
a2 = vcombine_bf16(a2x4, bf16_zero);
354-
a3 = vcombine_bf16(a3x4, bf16_zero);
347+
a0 = vcombine_bf16(a0x4, a2x4);
348+
a1 = vcombine_bf16(a1x4, a3x4);
355349

356-
t0 = vreinterpretq_bf16_u16(
357-
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
358-
t1 = vreinterpretq_bf16_u16(
359-
vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
350+
t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
351+
t1 = vreinterpretq_bf16_u16(vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
360352

361353
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
362354
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
@@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
374366
}
375367

376368
if (rest_m) {
377-
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
378-
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
379-
x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
380-
x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);
369+
fp32_low = vcvt_f32_bf16(x_vecx4);
370+
371+
x0 = vgetq_lane_f32(fp32_low, 0);
372+
x1 = vgetq_lane_f32(fp32_low, 1);
373+
x2 = vgetq_lane_f32(fp32_low, 2);
374+
x3 = vgetq_lane_f32(fp32_low, 3);
381375

382376
for (BLASLONG j = 0; j < rest_m; j++) {
383377
y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
@@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
396390

397391
a_ptr += 2 * lda;
398392

399-
bfloat16_t tmp_buffer[4];
400-
memset((void*)tmp_buffer, 0, sizeof(bfloat16_t));
401-
402-
tmp_buffer[0] = x_ptr[0];
403-
tmp_buffer[1] = x_ptr[1];
393+
x_vecx4 = vreinterpret_bf16_u16(vzip1_u16(
394+
vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[0])),
395+
vreinterpret_u16_bf16(vdup_n_bf16(x_ptr[1]))
396+
));
404397

405-
bfloat16x4_t x_vecx4 = vld1_bf16(tmp_buffer);
406398
if (alpha != 1) {
407-
x_vec = vcombine_bf16(x_vecx4, bf16_zero);
408-
fp32_low = vreinterpretq_f32_u16(
409-
vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
410-
vreinterpretq_u16_bf16(x_vec)));
399+
fp32_low = vcvt_f32_bf16(x_vecx4);
411400
fp32_low = vmulq_n_f32(fp32_low, alpha);
412401
x_vecx4 = vcvt_bf16_f32(fp32_low);
413402
}
@@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
422411

423412
t0 = vreinterpretq_bf16_u16(
424413
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
425-
t4 = vreinterpretq_bf16_u16(
414+
t1 = vreinterpretq_bf16_u16(
426415
vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
427416

428417
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
429418
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
430419

431-
y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0);
432-
y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1);
420+
y2_vec = vbfmlalbq_lane_f32(y2_vec, t1, x_vecx4, 0);
421+
y2_vec = vbfmlaltq_lane_f32(y2_vec, t1, x_vecx4, 1);
433422

434423
vst1q_f32(y_ptr, y1_vec);
435424
vst1q_f32(y_ptr + 4, y2_vec);
@@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
449438
a0 = vcombine_bf16(a0x4, bf16_zero);
450439
a1 = vcombine_bf16(a1x4, bf16_zero);
451440

452-
t0 = vreinterpretq_bf16_u16(
453-
vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
454-
t1 = vreinterpretq_bf16_u16(
455-
vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
441+
t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
456442

457443
y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
458444
y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
459-
y1_vec = vbfmlalbq_lane_f32(y1_vec, t1, x_vecx4, 2);
460-
y1_vec = vbfmlaltq_lane_f32(y1_vec, t1, x_vecx4, 3);
461445

462446
vst1q_f32(y_ptr, y1_vec);
463447

464448
a_ptr0 += 4;
465449
a_ptr1 += 4;
466-
a_ptr2 += 4;
467-
a_ptr3 += 4;
468450

469451
y_ptr += 4;
470452
}
471453

472454
if (m & 2) {
473-
x0 = alpha * (vcvtah_f32_bf16(x_ptr[0]));
474-
x1 = alpha * (vcvtah_f32_bf16(x_ptr[1]));
455+
fp32_low = vcvt_f32_bf16(x_vecx4);
456+
x0 = vgetq_lane_f32(fp32_low, 0);
457+
x1 = vgetq_lane_f32(fp32_low, 1);
458+
475459

476460
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
477461
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);
@@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
485469
}
486470

487471
if (m & 1) {
488-
x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
489-
x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
472+
fp32_low = vcvt_f32_bf16(x_vecx4);
473+
x0 = vgetq_lane_f32(fp32_low, 0);
474+
x1 = vgetq_lane_f32(fp32_low, 1);
490475

491476
y_ptr[0] += x0 * vcvtah_f32_bf16(a_ptr0[0]);
492477
y_ptr[0] += x1 * vcvtah_f32_bf16(a_ptr1[0]);

0 commit comments

Comments
 (0)