
Commit edaf51d

Add sbgemv_t_bfdot kernel for ARM64
This improves performance of sbgemv_t by up to 100x on NEOVERSEV1; the geometric mean speedup is ~61x across sizes M = N = 2 through 512.
1 parent ef9e3f7 commit edaf51d
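
For context, sbgemv_t is the transposed GEMV on bfloat16 data: y := alpha * A^T * x + beta * y, where A (m x n, column-major with leading dimension lda) and x hold bf16 values and y holds fp32. A minimal scalar reference follows; the helper names (bf16_ref, sbgemv_t_ref) are illustrative only, not OpenBLAS API, and this sketch ignores the beta == 0 special case that the kernel below handles so that y need not hold valid data on entry.

#include <stdint.h>
#include <string.h>

/* bf16 -> fp32: a bfloat16 is the upper 16 bits of an IEEE-754 binary32. */
static float bf16_ref(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

/* y := alpha * A^T x + beta * y, for column-major A (m x n, leading dim lda). */
static void sbgemv_t_ref(long m, long n, float alpha, const uint16_t *a, long lda,
                         const uint16_t *x, float beta, float *y) {
    for (long j = 0; j < n; j++) {
        float sum = 0.0f;
        for (long i = 0; i < m; i++)
            sum += bf16_ref(a[j * lda + i]) * bf16_ref(x[i]);
        y[j] = alpha * sum + beta * y[j];
    }
}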

5 files changed: +214 −0 lines changed

CONTRIBUTORS.md (+1)

@@ -236,6 +236,7 @@ In chronological order:
 * Annop Wongwathanarat <annop.wongwathanarat@arm.com>
   * [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1
   * [2025-01-21] Optimize gemv_t_sve_v1x3 kernel
+  * [2025-02-26] Add sbgemv_t_bfdot kernel
 
 * Marek Michalowski <marek.michalowski@arm.com>
   * [2025-01-21] Add thread throttling profile for SGEMV on `NEOVERSEV1`

kernel/arm64/KERNEL.NEOVERSEN2 (+1)

@@ -198,3 +198,4 @@ SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX)
 SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX)
 SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
 SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
+SBGEMVTKERNEL = sbgemv_t_bfdot.c

kernel/arm64/KERNEL.NEOVERSEV1 (+1)

@@ -15,4 +15,5 @@ SBGEMMONCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_N)_neoversev1.c
 SBGEMMOTCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_N)_neoversev1.c
 SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
 SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
+SBGEMVTKERNEL = sbgemv_t_bfdot.c
 endif

kernel/arm64/KERNEL.NEOVERSEV2 (+4)

@@ -1 +1,5 @@
 include $(KERNELDIR)/KERNEL.ARMV8SVE
+
+ifeq ($(BUILD_BFLOAT16), 1)
+SBGEMVTKERNEL = sbgemv_t_bfdot.c
+endif
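
Note that the kernel is only wired in when BF16 support is compiled into OpenBLAS; assuming a standard build, something like `make TARGET=NEOVERSEV2 BUILD_BFLOAT16=1` would select it, while builds without BUILD_BFLOAT16 keep the ARMV8SVE defaults.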

kernel/arm64/sbgemv_t_bfdot.c (+207)

@@ -0,0 +1,207 @@
/***************************************************************************
Copyright (c) 2025, The OpenBLAS Project
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in
   the documentation and/or other materials provided with the
   distribution.
3. Neither the name of the OpenBLAS project nor the names of
   its contributors may be used to endorse or promote products
   derived from this software without specific prior written
   permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#include <arm_neon.h>
#include "common.h"

/* Widen one bfloat16 (stored as a 16-bit integer) to float: a bf16 value is
   the upper 16 bits of an IEEE-754 binary32, so shifting left by 16 recovers
   the full float bit pattern. */
static inline float bf16_to_fp32(bfloat16 bf16) {
  uint32_t fp32 = (uint32_t)bf16 << 16;
  return *((float*)&fp32);
}

int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
{
  if (m < 1 || n < 1) return(0);

  BLASLONG i;
  BLASLONG ix, iy;
  BLASLONG j;
  bfloat16_t *a_ptr;
  bfloat16_t *x_ptr;
  float *y_ptr;
  float temp;

  iy = 0;
  a_ptr = (bfloat16_t*)(a);
  x_ptr = (bfloat16_t*)(x);

  if (incx == 1) {
    /* Split the n columns into four slices of n/4 columns each and process
       one column from every slice per iteration. */
    BLASLONG width = n / 4;

    bfloat16_t *a0_ptr = a_ptr + lda * width * 0;
    bfloat16_t *a1_ptr = a_ptr + lda * width * 1;
    bfloat16_t *a2_ptr = a_ptr + lda * width * 2;
    bfloat16_t *a3_ptr = a_ptr + lda * width * 3;

    float *y0_ptr = y + incy * width * 0;
    float *y1_ptr = y + incy * width * 1;
    float *y2_ptr = y + incy * width * 2;
    float *y3_ptr = y + incy * width * 3;
    for (j = 0; j < width; j++) {
      float32x4_t temp0_vec = vdupq_n_f32(0.0f);
      float32x4_t temp1_vec = vdupq_n_f32(0.0f);
      float32x4_t temp2_vec = vdupq_n_f32(0.0f);
      float32x4_t temp3_vec = vdupq_n_f32(0.0f);

      /* Main loop: 8 bf16 elements per iteration; each BFDOT accumulates
         four 2-element dot products into the fp32 lanes. */
      i = 0;
      while (i + 7 < m) {
        bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i);

        bfloat16x8_t a0_vec = vld1q_bf16(a0_ptr + i);
        bfloat16x8_t a1_vec = vld1q_bf16(a1_ptr + i);
        bfloat16x8_t a2_vec = vld1q_bf16(a2_ptr + i);
        bfloat16x8_t a3_vec = vld1q_bf16(a3_ptr + i);

        temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec);
        temp1_vec = vbfdotq_f32(temp1_vec, a1_vec, x_vec);
        temp2_vec = vbfdotq_f32(temp2_vec, a2_vec, x_vec);
        temp3_vec = vbfdotq_f32(temp3_vec, a3_vec, x_vec);

        i += 8;
      }
      /* 4-element tail: use the 64-bit BFDOT and fold the partial sums into
         the low lanes of the accumulators. */
      if (i + 3 < m) {
        float32x2_t t0 = vdup_n_f32(0.0f);
        float32x2_t t1 = vdup_n_f32(0.0f);
        float32x2_t t2 = vdup_n_f32(0.0f);
        float32x2_t t3 = vdup_n_f32(0.0f);

        bfloat16x4_t x_vec = vld1_bf16(x_ptr + i);

        bfloat16x4_t a0_vec = vld1_bf16(a0_ptr + i);
        bfloat16x4_t a1_vec = vld1_bf16(a1_ptr + i);
        bfloat16x4_t a2_vec = vld1_bf16(a2_ptr + i);
        bfloat16x4_t a3_vec = vld1_bf16(a3_ptr + i);

        t0 = vbfdot_f32(t0, a0_vec, x_vec);
        t1 = vbfdot_f32(t1, a1_vec, x_vec);
        t2 = vbfdot_f32(t2, a2_vec, x_vec);
        t3 = vbfdot_f32(t3, a3_vec, x_vec);

        float32x2_t temp0_vec_low = vget_low_f32(temp0_vec);
        float32x2_t temp1_vec_low = vget_low_f32(temp1_vec);
        float32x2_t temp2_vec_low = vget_low_f32(temp2_vec);
        float32x2_t temp3_vec_low = vget_low_f32(temp3_vec);

        temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec));
        temp1_vec = vcombine_f32(vadd_f32(t1, temp1_vec_low), vget_high_f32(temp1_vec));
        temp2_vec = vcombine_f32(vadd_f32(t2, temp2_vec_low), vget_high_f32(temp2_vec));
        temp3_vec = vcombine_f32(vadd_f32(t3, temp3_vec_low), vget_high_f32(temp3_vec));

        i += 4;
      }
      /* Horizontal reduction; beta == 0 is special-cased so that y need not
         hold valid data on input. */
      if (beta == 0.0f) {
        y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
        y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec);
        y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec);
        y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec);
      }
      else {
        y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy];
        y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy];
        y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy];
        y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy];
      }

      /* Scalar tail for the last m % 4 elements; bfloat16_t is a storage-only
         type, so widen explicitly before multiplying. */
      for (; i < m; ++i) {
        y0_ptr[iy] += alpha * vcvtah_f32_bf16(a0_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
        y1_ptr[iy] += alpha * vcvtah_f32_bf16(a1_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
        y2_ptr[iy] += alpha * vcvtah_f32_bf16(a2_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
        y3_ptr[iy] += alpha * vcvtah_f32_bf16(a3_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
      }

      iy += incy;

      a0_ptr += lda;
      a1_ptr += lda;
      a2_ptr += lda;
      a3_ptr += lda;
    }

    /* Remaining n % 4 columns, one at a time. */
    a_ptr = a3_ptr;
    y_ptr = y3_ptr;
    for (j = width * 4; j < n; j++) {
      float32x4_t temp0_vec = vdupq_n_f32(0.0f);
      i = 0;
      while (i + 7 < m) {
        bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i);
        bfloat16x8_t a0_vec = vld1q_bf16(a_ptr + i);
        temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec);

        i += 8;
      }
      if (i + 3 < m) {
        float32x2_t t0 = vdup_n_f32(0.0f);
        bfloat16x4_t x_vec = vld1_bf16(x_ptr + i);
        bfloat16x4_t a0_vec = vld1_bf16(a_ptr + i);

        t0 = vbfdot_f32(t0, a0_vec, x_vec);
        float32x2_t temp0_vec_low = vget_low_f32(temp0_vec);
        temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec));

        i += 4;
      }
      if (beta == 0.0f) {
        y_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
      }
      else {
        y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy];
      }

      for (; i < m; ++i) {
        y_ptr[iy] += alpha * vcvtah_f32_bf16(a_ptr[i]) * vcvtah_f32_bf16(x_ptr[i]);
      }

      iy += incy;

      a_ptr += lda;
    }
    return(0);
  }

  /* Scalar fallback for non-unit incx. */
  for (j = 0; j < n; j++) {
    temp = 0.0;
    ix = 0;
    for (i = 0; i < m; i++) {
      temp += bf16_to_fp32(a[i]) * bf16_to_fp32(x[ix]);
      ix += incx;
    }
    if (beta == 0.0f) {
      y[iy] = alpha * temp;
    }
    else {
      y[iy] = alpha * temp + beta * y[iy];
    }
    iy += incy;
    a += lda;
  }
  return (0);
}
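
A note on the BFDOT intrinsics used above: vbfdotq_f32(acc, a, b) treats its two bfloat16x8_t operands as four adjacent pairs and accumulates one 2-element dot product into each fp32 lane of acc; vbfdot_f32 is the 64-bit, two-lane variant. A rough scalar model, reusing bf16_to_fp32 from the file and ignoring the real instruction's intermediate-rounding details (bfdotq_model is an illustrative name, not an OpenBLAS or ACLE function):

/* Rough scalar model of acc = vbfdotq_f32(acc, a, b): a and b hold eight
   bf16 values, acc holds four fp32 lanes; lane k accumulates the dot
   product of element pair {2k, 2k+1}. */
static void bfdotq_model(float acc[4], const bfloat16 a[8], const bfloat16 b[8]) {
  for (int k = 0; k < 4; k++)
    acc[k] += bf16_to_fp32(a[2*k])     * bf16_to_fp32(b[2*k])
            + bf16_to_fp32(a[2*k + 1]) * bf16_to_fp32(b[2*k + 1]);
}

Because the pairing runs along the column, the four lanes together hold a partial sum of the full column dot product, which vaddvq_f32 collapses to a scalar at the end of each column.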
