@@ -591,7 +591,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
591
591
592
592
BLASLONG nthreads = args -> nthreads ;
593
593
594
- BLASLONG width , i , j , k , js ;
594
+ BLASLONG width , width_n , i , j , k , js ;
595
595
BLASLONG m , n , n_from , n_to ;
596
596
int mode ;
597
597
#if defined(DYNAMIC_ARCH )
@@ -740,18 +740,25 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
740
740
/* Partition (a step of) n into nthreads regions */
741
741
range_N [0 ] = js ;
742
742
num_parts = 0 ;
743
- while (n > 0 ){
744
- width = blas_quickdivide (n + nthreads - num_parts - 1 , nthreads - num_parts );
745
- if (width < switch_ratio ) {
746
- width = switch_ratio ;
743
+ for (j = 0 ; j < nthreads_n ; j ++ ){
744
+ width_n = blas_quickdivide (n + nthreads_n - j - 1 , nthreads_n - j );
745
+ n -= width_n ;
746
+ for (i = 0 ; i < nthreads_m ; i ++ ){
747
+ width = blas_quickdivide (width_n + nthreads_m - i - 1 , nthreads_m - i );
748
+ if (width < switch_ratio ) {
749
+ width = switch_ratio ;
750
+ }
751
+ width = round_up (width_n , width , GEMM_PREFERED_SIZE );
752
+
753
+ width_n -= width ;
754
+ if (width_n < 0 ) {
755
+ width = width + width_n ;
756
+ width_n = 0 ;
757
+ }
758
+ range_N [num_parts + 1 ] = range_N [num_parts ] + width ;
759
+
760
+ num_parts ++ ;
747
761
}
748
- width = round_up (n , width , GEMM_PREFERED_SIZE );
749
-
750
- n -= width ;
751
- if (n < 0 ) width = width + n ;
752
- range_N [num_parts + 1 ] = range_N [num_parts ] + width ;
753
-
754
- num_parts ++ ;
755
762
}
756
763
for (j = num_parts ; j < MAX_CPU_NUMBER ; j ++ ) {
757
764
range_N [j + 1 ] = range_N [num_parts ];
0 commit comments