Skip to content

Commit 37b8547

Browse files
authored
Merge pull request #5173 from nakagawa-fj/gemm_load_imbalance
Improving Load Imbalance in Thread-Parallel GEMM
2 parents a3e7b16 + 80d3c2a commit 37b8547

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

driver/level3/level3_thread.c

+19-12
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
591591

592592
BLASLONG nthreads = args -> nthreads;
593593

594-
BLASLONG width, i, j, k, js;
594+
BLASLONG width, width_n, i, j, k, js;
595595
BLASLONG m, n, n_from, n_to;
596596
int mode;
597597
#if defined(DYNAMIC_ARCH)
@@ -740,18 +740,25 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
740740
/* Partition (a step of) n into nthreads regions */
741741
range_N[0] = js;
742742
num_parts = 0;
743-
while (n > 0){
744-
width = blas_quickdivide(n + nthreads - num_parts - 1, nthreads - num_parts);
745-
if (width < switch_ratio) {
746-
width = switch_ratio;
743+
for(j = 0; j < nthreads_n; j++){
744+
width_n = blas_quickdivide(n + nthreads_n - j - 1, nthreads_n - j);
745+
n -= width_n;
746+
for(i = 0; i < nthreads_m; i++){
747+
width = blas_quickdivide(width_n + nthreads_m - i - 1, nthreads_m - i);
748+
if (width < switch_ratio) {
749+
width = switch_ratio;
750+
}
751+
width = round_up(width_n, width, GEMM_PREFERED_SIZE);
752+
753+
width_n -= width;
754+
if (width_n < 0) {
755+
width = width + width_n;
756+
width_n = 0;
757+
}
758+
range_N[num_parts + 1] = range_N[num_parts] + width;
759+
760+
num_parts ++;
747761
}
748-
width = round_up(n, width, GEMM_PREFERED_SIZE);
749-
750-
n -= width;
751-
if (n < 0) width = width + n;
752-
range_N[num_parts + 1] = range_N[num_parts] + width;
753-
754-
num_parts ++;
755762
}
756763
for (j = num_parts; j < MAX_CPU_NUMBER; j++) {
757764
range_N[j + 1] = range_N[num_parts];

0 commit comments

Comments
 (0)