@@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) {
       x += 4;
     }
 
-    if (rest_n & 3) {
-      x[0] *= beta;
-      if ((rest_n & 3) > 1)
-        x[1] *= beta;
-      if ((rest_n & 3) > 2)
-        x[2] *= beta;
+    for (BLASLONG i = 0; i < (rest_n & 3); i++) {
+      x[i] *= beta;
     }
   }
   return;
@@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
   bfloat16x8_t a0, a1, a2, a3, a4, a5, a6, a7;
   bfloat16x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+
   bfloat16x8_t x_vec;
+  bfloat16x4_t x_vecx4;
+
   float32x4_t y1_vec, y2_vec;
   float32x4_t fp32_low, fp32_high;
 
@@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
   if (incx == 1 && incy == 1) {
     if (beta != 1) {
-      beta_op(y, n, beta);
+      beta_op(y, m, beta);
    }
 
    for (i = 0; i < n / 8; i++) {
@@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
      a_ptr += 4 * lda;
 
-      bfloat16x4_t x_vecx4 = vld1_bf16(x_ptr);
+      x_vecx4 = vld1_bf16(x_ptr);
      if (alpha != 1) {
-        x_vec = vcombine_bf16(x_vecx4, bf16_zero);
-        fp32_low = vreinterpretq_f32_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(bf16_zero_q),
-                       vreinterpretq_u16_bf16(x_vec)));
+        fp32_low = vcvt_f32_bf16(x_vecx4);
        fp32_low = vmulq_n_f32(fp32_low, alpha);
        x_vecx4 = vcvt_bf16_f32(fp32_low);
      }
@@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
        y1_vec = vld1q_f32(y_ptr);
 
-        a0 = vcombine_bf16(a0x4, bf16_zero);
-        a1 = vcombine_bf16(a1x4, bf16_zero);
-        a2 = vcombine_bf16(a2x4, bf16_zero);
-        a3 = vcombine_bf16(a3x4, bf16_zero);
+        a0 = vcombine_bf16(a0x4, a2x4);
+        a1 = vcombine_bf16(a1x4, a3x4);
 
-        t0 = vreinterpretq_bf16_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
-        t1 = vreinterpretq_bf16_u16(
-            vzip1q_u16(vreinterpretq_u16_bf16(a2), vreinterpretq_u16_bf16(a3)));
+        t0 = vreinterpretq_bf16_u16(vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
+        t1 = vreinterpretq_bf16_u16(vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
 
        y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
        y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
@@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
      }
 
      if (rest_m) {
-        x0 = alpha * vcvtah_f32_bf16(x_ptr[0]);
-        x1 = alpha * vcvtah_f32_bf16(x_ptr[1]);
-        x2 = alpha * vcvtah_f32_bf16(x_ptr[2]);
-        x3 = alpha * vcvtah_f32_bf16(x_ptr[3]);
+        fp32_low = vcvt_f32_bf16(x_vecx4);
+
+        x0 = vgetq_lane_f32(fp32_low, 0);
+        x1 = vgetq_lane_f32(fp32_low, 1);
+        x2 = vgetq_lane_f32(fp32_low, 2);
+        x3 = vgetq_lane_f32(fp32_low, 3);
 
        for (BLASLONG j = 0; j < rest_m; j++) {
          y_ptr[j] += x0 * vcvtah_f32_bf16(a_ptr0[j]);
@@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
396
390
397
391
a_ptr += 2 * lda ;
398
392
399
- bfloat16_t tmp_buffer [4 ];
400
- memset ((void * )tmp_buffer , 0 , sizeof (bfloat16_t ));
401
-
402
- tmp_buffer [0 ] = x_ptr [0 ];
403
- tmp_buffer [1 ] = x_ptr [1 ];
393
+ x_vecx4 = vreinterpret_bf16_u16 (vzip1_u16 (
394
+ vreinterpret_u16_bf16 (vdup_n_bf16 (x_ptr [0 ])),
395
+ vreinterpret_u16_bf16 (vdup_n_bf16 (x_ptr [1 ]))
396
+ ));
404
397
405
- bfloat16x4_t x_vecx4 = vld1_bf16 (tmp_buffer );
406
398
if (alpha != 1 ) {
407
- x_vec = vcombine_bf16 (x_vecx4 , bf16_zero );
408
- fp32_low = vreinterpretq_f32_u16 (
409
- vzip1q_u16 (vreinterpretq_u16_bf16 (bf16_zero_q ),
410
- vreinterpretq_u16_bf16 (x_vec )));
399
+ fp32_low = vcvt_f32_bf16 (x_vecx4 );
411
400
fp32_low = vmulq_n_f32 (fp32_low , alpha );
412
401
x_vecx4 = vcvt_bf16_f32 (fp32_low );
413
402
}
@@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
 
        t0 = vreinterpretq_bf16_u16(
            vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
-        t4 = vreinterpretq_bf16_u16(
+        t1 = vreinterpretq_bf16_u16(
            vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
 
        y1_vec = vbfmlalbq_lane_f32(y1_vec, t0, x_vecx4, 0);
        y1_vec = vbfmlaltq_lane_f32(y1_vec, t0, x_vecx4, 1);
 
-        y2_vec = vbfmlalbq_lane_f32(y2_vec, t4, x_vecx4, 0);
-        y2_vec = vbfmlaltq_lane_f32(y2_vec, t4, x_vecx4, 1);
+        y2_vec = vbfmlalbq_lane_f32(y2_vec, t1, x_vecx4, 0);
+        y2_vec = vbfmlaltq_lane_f32(y2_vec, t1, x_vecx4, 1);
 
        vst1q_f32(y_ptr, y1_vec);
        vst1q_f32(y_ptr + 4, y2_vec);
@@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
449
438
a0 = vcombine_bf16 (a0x4 , bf16_zero );
450
439
a1 = vcombine_bf16 (a1x4 , bf16_zero );
451
440
452
- t0 = vreinterpretq_bf16_u16 (
453
- vzip1q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
454
- t1 = vreinterpretq_bf16_u16 (
455
- vzip1q_u16 (vreinterpretq_u16_bf16 (a2 ), vreinterpretq_u16_bf16 (a3 )));
441
+ t0 = vreinterpretq_bf16_u16 (vzip1q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
456
442
457
443
y1_vec = vbfmlalbq_lane_f32 (y1_vec , t0 , x_vecx4 , 0 );
458
444
y1_vec = vbfmlaltq_lane_f32 (y1_vec , t0 , x_vecx4 , 1 );
459
- y1_vec = vbfmlalbq_lane_f32 (y1_vec , t1 , x_vecx4 , 2 );
460
- y1_vec = vbfmlaltq_lane_f32 (y1_vec , t1 , x_vecx4 , 3 );
461
445
462
446
vst1q_f32 (y_ptr , y1_vec );
463
447
464
448
a_ptr0 += 4 ;
465
449
a_ptr1 += 4 ;
466
- a_ptr2 += 4 ;
467
- a_ptr3 += 4 ;
468
450
469
451
y_ptr += 4 ;
470
452
}
471
453
472
454
if (m & 2 ) {
473
- x0 = alpha * (vcvtah_f32_bf16 (x_ptr [0 ]));
474
- x1 = alpha * (vcvtah_f32_bf16 (x_ptr [1 ]));
455
+ fp32_low = vcvt_f32_bf16 (x_vecx4 );
456
+ x0 = vgetq_lane_f32 (fp32_low , 0 );
457
+ x1 = vgetq_lane_f32 (fp32_low , 1 );
458
+
475
459
476
460
y_ptr [0 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [0 ]);
477
461
y_ptr [0 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [0 ]);
@@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
485
469
}
486
470
487
471
if (m & 1 ) {
488
- x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
489
- x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
472
+ fp32_low = vcvt_f32_bf16 (x_vecx4 );
473
+ x0 = vgetq_lane_f32 (fp32_low , 0 );
474
+ x1 = vgetq_lane_f32 (fp32_low , 1 );
490
475
491
476
y_ptr [0 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [0 ]);
492
477
y_ptr [0 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [0 ]);
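
Note on the pattern the diff leans on: the kernel packs two bf16 columns per 128-bit register, interleaves them with vzip1q_u16/vzip2q_u16, and then lets BFMLALB/BFMLALT pick the even and odd lanes while scaling by one lane of the x vector. Below is a minimal standalone sketch of that accumulation step, not part of the patch; the helper name and the column-pointer parameters are made up for illustration, and it assumes <arm_neon.h> with a BF16-capable target (e.g. -march=armv8.6-a+bf16).

#include <arm_neon.h>

/* y[i] += x[0]*col0[i] + x[1]*col1[i] + x[2]*col2[i] + x[3]*col3[i], i = 0..3 */
static inline float32x4_t dot4x4_bf16(float32x4_t y,
                                      const bfloat16_t *col0, const bfloat16_t *col1,
                                      const bfloat16_t *col2, const bfloat16_t *col3,
                                      bfloat16x4_t x) {
  /* Pack two columns per 128-bit register: [col0 | col2] and [col1 | col3]. */
  bfloat16x8_t a0 = vcombine_bf16(vld1_bf16(col0), vld1_bf16(col2));
  bfloat16x8_t a1 = vcombine_bf16(vld1_bf16(col1), vld1_bf16(col3));

  /* Interleave: t0 holds col0 in even lanes / col1 in odd lanes,
   *             t1 holds col2 in even lanes / col3 in odd lanes. */
  bfloat16x8_t t0 = vreinterpretq_bf16_u16(
      vzip1q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));
  bfloat16x8_t t1 = vreinterpretq_bf16_u16(
      vzip2q_u16(vreinterpretq_u16_bf16(a0), vreinterpretq_u16_bf16(a1)));

  /* BFMLALB consumes the even (bottom) bf16 lanes, BFMLALT the odd (top) ones;
   * each product is widened to f32 and accumulated into the four f32 lanes of y. */
  y = vbfmlalbq_lane_f32(y, t0, x, 0); /* += x[0] * col0[i] */
  y = vbfmlaltq_lane_f32(y, t0, x, 1); /* += x[1] * col1[i] */
  y = vbfmlalbq_lane_f32(y, t1, x, 2); /* += x[2] * col2[i] */
  y = vbfmlaltq_lane_f32(y, t1, x, 3); /* += x[3] * col3[i] */
  return y;
}

The x vector here is just a vld1_bf16(x_ptr); if alpha != 1 it can be pre-scaled once, as the patch does, via vcvt_bf16_f32(vmulq_n_f32(vcvt_f32_bf16(x), alpha)). Using vcvt_f32_bf16 for that widening is what replaces the older zip-with-zero reinterpret trick in the deleted lines.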