Skip to content

Commit e2e6a4d

Browse files
authored
Merge pull request #5276 from nakagawa-fj/gemm_2d_thread_partitioning
Improvement of 2D thread-partitioned GEMM for M << N case
2 parents 9ef5995 + 2351a98 commit e2e6a4d

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

driver/level3/level3_thread.c

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -851,9 +851,19 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IF
851851
/* The objective function is the sum of the per-thread partition sizes in m and n: */
852852
/* (n / nthreads_n) + (m / nthreads_m) */
853853
/* = (n * nthreads_m + m * nthreads_n) / (nthreads_n * nthreads_m) */
854-
while (nthreads_m % 2 == 0 && n * nthreads_m + m * nthreads_n > n * (nthreads_m / 2) + m * (nthreads_n * 2)) {
855-
nthreads_m /= 2;
856-
nthreads_n *= 2;
854+
BLASLONG cost = 0, div = 0;
855+
for (BLASLONG i = 1; i <= sqrt(nthreads_m); i++) {
856+
if (nthreads_m % i) continue;
857+
BLASLONG j = nthreads_m / i;
858+
BLASLONG cost_i = n * j + m * nthreads_n * i;
859+
BLASLONG cost_j = n * i + m * nthreads_n * j;
860+
if (cost == 0 ||
861+
cost_i < cost) {cost = cost_i; div = i;}
862+
if (cost_j < cost) {cost = cost_j; div = j;}
863+
}
864+
if (div > 1) {
865+
nthreads_m /= div;
866+
nthreads_n *= div;
857867
}
858868
}
859869

0 commit comments

Comments
 (0)