Skip to content

Commit f3065a0

Browse files
authored
Fix race conditions in multithreaded GEMM3M
by adding barriers (and a mutex lock for the non-OpenMP case), as was already done for GEMM in level3_thread.c some time ago
1 parent 7887c45 commit f3065a0

File tree

1 file changed

+42
-11
lines changed

1 file changed

+42
-11
lines changed

driver/level3/level3_gemm3m_thread.c

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
408408

409409
/* Make sure if no one is using another buffer */
410410
for (i = 0; i < args -> nthreads; i++)
411-
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
411+
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};
412412

413413
STOP_RPCC(waiting1);
414414

@@ -441,7 +441,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
441441

442442
for (i = 0; i < args -> nthreads; i++)
443443
job[mypos].working[i][CACHE_LINE_SIZE * bufferside] = (BLASLONG)buffer[bufferside];
444-
}
444+
WMB;
445+
}
445446

446447
current = mypos;
447448

@@ -458,7 +459,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
458459
START_RPCC();
459460

460461
/* thread has to wait */
461-
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
462+
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};
462463

463464
STOP_RPCC(waiting2);
464465

@@ -477,6 +478,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
477478

478479
if (m_to - m_from == min_i) {
479480
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
481+
WMB;
480482
}
481483
}
482484
} while (current != mypos);
@@ -517,6 +519,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
517519
if (is + min_i >= m_to) {
518520
/* Thread doesn't need this buffer any more */
519521
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
522+
WMB;
520523
}
521524
}
522525

@@ -541,7 +544,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
541544

542545
/* Make sure if no one is using another buffer */
543546
for (i = 0; i < args -> nthreads; i++)
544-
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
547+
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};
545548

546549
STOP_RPCC(waiting1);
547550

@@ -595,7 +598,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
595598
START_RPCC();
596599

597600
/* thread has to wait */
598-
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
601+
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};
599602

600603
STOP_RPCC(waiting2);
601604

@@ -613,6 +616,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
613616

614617
if (m_to - m_from == min_i) {
615618
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
619+
WMB;
616620
}
617621
}
618622
} while (current != mypos);
@@ -677,7 +681,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
677681

678682
/* Make sure if no one is using another buffer */
679683
for (i = 0; i < args -> nthreads; i++)
680-
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
684+
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};
681685

682686
STOP_RPCC(waiting1);
683687

@@ -731,7 +735,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
731735
START_RPCC();
732736

733737
/* thread has to wait */
734-
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
738+
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};
735739

736740
STOP_RPCC(waiting2);
737741

@@ -748,8 +752,9 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
748752
}
749753

750754
if (m_to - m_from == min_i) {
751-
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
752-
}
755+
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] &= 0;
756+
WMB;
757+
}
753758
}
754759
} while (current != mypos);
755760

@@ -787,7 +792,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
787792
#endif
788793
if (is + min_i >= m_to) {
789794
/* Thread doesn't need this buffer any more */
790-
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
795+
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] &= 0;
796+
WMB;
791797
}
792798
}
793799

@@ -804,7 +810,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
804810

805811
for (i = 0; i < args -> nthreads; i++) {
806812
for (xxx = 0; xxx < DIVIDE_RATE; xxx++) {
807-
while (job[mypos].working[i][CACHE_LINE_SIZE * xxx] ) {YIELDING;};
813+
while (job[mypos].working[i][CACHE_LINE_SIZE * xxx] ) {YIELDING;MB;};
808814
}
809815
}
810816

@@ -840,6 +846,15 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
840846
static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
841847
*range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){
842848

849+
#ifndef USE_OPENMP
850+
#ifndef OS_WINDOWS
851+
static pthread_mutex_t level3_lock = PTHREAD_MUTEX_INITIALIZER;
852+
#else
853+
CRITICAL_SECTION level3_lock;
854+
InitializeCriticalSection((PCRITICAL_SECTION)&level3_lock);
855+
#endif
856+
#endif
857+
843858
blas_arg_t newarg;
844859

845860
blas_queue_t queue[MAX_CPU_NUMBER];
@@ -869,6 +884,14 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
869884
mode = BLAS_SINGLE | BLAS_REAL | BLAS_NODE;
870885
#endif
871886

887+
#ifndef USE_OPENMP
888+
#ifndef OS_WINDOWS
889+
pthread_mutex_lock(&level3_lock);
890+
#else
891+
EnterCriticalSection((PCRITICAL_SECTION)&level3_lock);
892+
#endif
893+
#endif
894+
872895
newarg.m = args -> m;
873896
newarg.n = args -> n;
874897
newarg.k = args -> k;
@@ -973,6 +996,14 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
973996
free(job);
974997
#endif
975998

999+
#ifndef USE_OPENMP
1000+
#ifndef OS_WINDOWS
1001+
pthread_mutex_unlock(&level3_lock);
1002+
#else
1003+
LeaveCriticalSection((PCRITICAL_SECTION)&level3_lock);
1004+
#endif
1005+
#endif
1006+
9761007
return 0;
9771008
}
9781009

0 commit comments

Comments
 (0)