@@ -840,6 +840,7 @@ static void _common_ucx_tls_cleanup(_tlocal_table_t *tls)
840
840
size = tls -> ctx_tbl_size ;
841
841
for (i = 0 ; i < size ; i ++ ) {
842
842
if (NULL != tls -> ctx_tbl [i ]-> gctx ){
843
+ assert (tls -> ctx_tbl [i ]-> refcnt == 0 );
843
844
_tlocal_ctx_record_cleanup (tls -> ctx_tbl [i ]);
844
845
}
845
846
free (tls -> ctx_tbl [i ]);
@@ -909,6 +910,11 @@ _tlocal_ctx_record_cleanup(_tlocal_ctx_t *ctx_rec)
909
910
if (NULL == ctx_rec -> gctx ) {
910
911
return OPAL_SUCCESS ;
911
912
}
913
+
914
+ if (ctx_rec -> refcnt > 0 ) {
915
+ return OPAL_SUCCESS ;
916
+ }
917
+
912
918
/* Remove myself from the communication context structure
913
919
* This may result in context release as we are using
914
920
* delayed cleanup */
@@ -934,7 +940,7 @@ _tlocal_add_ctx(_tlocal_table_t *tls, opal_common_ucx_ctx_t *ctx)
934
940
/* Try to find available record in the TLS table
935
941
* In parallel perform deferred cleanups */
936
942
for (i = 0 ; i < tls -> ctx_tbl_size ; i ++ ) {
937
- if (NULL != tls -> ctx_tbl [i ]-> gctx ) {
943
+ if (NULL != tls -> ctx_tbl [i ]-> gctx && tls -> ctx_tbl [ i ] -> refcnt == 0 ) {
938
944
if (tls -> ctx_tbl [i ]-> gctx -> released ) {
939
945
/* Found dirty record, need to clean first */
940
946
_tlocal_ctx_record_cleanup (tls -> ctx_tbl [i ]);
@@ -1059,6 +1065,10 @@ _tlocal_mem_record_cleanup(_tlocal_mem_t *mem_rec)
1059
1065
free (mem_rec -> mem_tls_ptr );
1060
1066
}
1061
1067
1068
+ assert (mem_rec -> ctx_rec != NULL );
1069
+ OPAL_ATOMIC_ADD_FETCH32 (& mem_rec -> ctx_rec -> refcnt , -1 );
1070
+ assert (mem_rec -> ctx_rec -> refcnt >= 0 );
1071
+
1062
1072
free (mem_rec -> mem );
1063
1073
1064
1074
memset (mem_rec , 0 , sizeof (* mem_rec ));
@@ -1107,6 +1117,9 @@ static _tlocal_mem_t *_tlocal_add_mem(_tlocal_table_t *tls,
1107
1117
WPOOL_DBG_OUT ("tls = %p, ctx = %p\n" ,
1108
1118
(void * )tls , (void * )mem -> ctx );
1109
1119
1120
+ tls -> mem_tbl [free_idx ]-> ctx_rec = ctx_rec ;
1121
+ OPAL_ATOMIC_ADD_FETCH32 (& ctx_rec -> refcnt , 1 );
1122
+
1110
1123
tls -> mem_tbl [free_idx ]-> mem -> worker = ctx_rec -> winfo ;
1111
1124
tls -> mem_tbl [free_idx ]-> mem -> rkeys = calloc (mem -> ctx -> comm_size ,
1112
1125
sizeof (* tls -> mem_tbl [free_idx ]-> mem -> rkeys ));
0 commit comments