@@ -60,7 +60,7 @@ static inline void ompi_osc_ucx_handle_incoming_post(ompi_osc_ucx_module_t *modu
 
 int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win) {
     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
-    int ret;
+    int ret = OMPI_SUCCESS;
 
     if (module->epoch_type.access != NONE_EPOCH &&
         module->epoch_type.access != FENCE_EPOCH) {
@@ -74,16 +74,12 @@ int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win) {
     }
 
     if (!(assert & MPI_MODE_NOPRECEDE)) {
-        ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker);
+        ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0 /*ignore*/);
         if (ret != OMPI_SUCCESS) {
             return ret;
         }
     }
 
-    module->global_ops_num = 0;
-    memset(module->per_target_ops_nums, 0,
-           sizeof(int) * ompi_comm_size(module->comm));
-
     return module->comm->c_coll->coll_barrier(module->comm,
                                               module->comm->c_coll->coll_barrier_module);
 }
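
Note: in these two hunks the flush moves from the raw UCP worker to the window's wpool memory handle, and the manual operation counters (global_ops_num, per_target_ops_nums) disappear because the wpool layer now tracks outstanding operations itself. A user-level sketch of what this path implements — plain MPI, not part of the commit, runnable as-is:

    /* fence_demo.c - exercises the osc/ucx fence path: each rank puts to
     * its right neighbor, and the two MPI_Win_fence calls map onto
     * ompi_osc_ucx_fence (worker-scope flush + barrier).
     * Build/run: mpicc fence_demo.c -o fence_demo && mpirun -n 4 ./fence_demo */
    #include <mpi.h>
    #include <stdio.h>

    int main(int argc, char **argv) {
        int rank, size, recv = -1;
        MPI_Win win;

        MPI_Init(&argc, &argv);
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        MPI_Comm_size(MPI_COMM_WORLD, &size);
        MPI_Win_create(&recv, sizeof(int), sizeof(int),
                       MPI_INFO_NULL, MPI_COMM_WORLD, &win);

        /* MPI_MODE_NOPRECEDE: no pending ops, so the flush branch is skipped */
        MPI_Win_fence(MPI_MODE_NOPRECEDE, win);
        MPI_Put(&rank, 1, MPI_INT, (rank + 1) % size, 0, 1, MPI_INT, win);
        /* closing fence: wpmem_flush at worker scope, then coll_barrier */
        MPI_Win_fence(MPI_MODE_NOSUCCEED, win);

        printf("rank %d received %d\n", rank, recv);
        MPI_Win_free(&win);
        MPI_Finalize();
        return 0;
    }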
@@ -147,7 +143,7 @@ int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t
 
             ompi_osc_ucx_handle_incoming_post(module, &(module->state.post_state[i]), ranks_in_win_grp, size);
         }
-        ucp_worker_progress(mca_osc_ucx_component.ucp_worker);
+        opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool);
     }
 
     module->post_count = 0;
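
Note: the progress call changes for the same reason — the component no longer owns a ucp_worker, so progress is driven through the shared worker pool (the identical one-line swap appears again in ompi_osc_ucx_wait at the end of this diff). The pattern being preserved is "spin on a counter while driving progress"; a self-contained model with pthreads and C11 atomics, where all names are illustrative stand-ins for the OMPI/UCX internals:

    /* progress_loop.c - models the wait-for-posts loop in
     * ompi_osc_ucx_start: the waiter must keep calling progress() (the
     * stand-in for opal_common_ucx_wpool_progress) or incoming posts
     * would never be noticed on a real transport.
     * Build: cc progress_loop.c -o progress_loop -pthread */
    #include <pthread.h>
    #include <stdatomic.h>
    #include <stdint.h>
    #include <stdio.h>

    static _Atomic uint64_t post_count;  /* stands in for module->post_count */

    static void progress(void) { /* no-op here; drives UCX in the real code */ }

    static void *peer(void *arg) {
        (void)arg;
        atomic_fetch_add(&post_count, 1);  /* a remote rank "posts" to us */
        return NULL;
    }

    int main(void) {
        pthread_t t;
        pthread_create(&t, NULL, peer, NULL);
        while (atomic_load(&post_count) != 1) {
            progress();  /* same shape as the loop in this hunk */
        }
        pthread_join(t, NULL);
        puts("post observed");
        return 0;
    }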
@@ -163,7 +159,6 @@ int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t
 
 int ompi_osc_ucx_complete(struct ompi_win_t *win) {
     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
-    ucs_status_t status;
     int i, size;
     int ret = OMPI_SUCCESS;
 
@@ -173,29 +168,26 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) {
 
     module->epoch_type.access = NONE_EPOCH;
 
-    ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker);
+    ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0 /*ignore*/);
     if (ret != OMPI_SUCCESS) {
         return ret;
     }
-    module->global_ops_num = 0;
-    memset(module->per_target_ops_nums, 0,
-           sizeof(int) * ompi_comm_size(module->comm));
 
     size = ompi_group_size(module->start_group);
     for (i = 0; i < size; i++) {
-        uint64_t remote_addr = (module->state_info_array)[module->start_grp_ranks[i]].addr + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; /* write to state.complete_count on remote side */
-        ucp_rkey_h rkey = (module->state_info_array)[module->start_grp_ranks[i]].rkey;
-        ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, module->start_grp_ranks[i]);
-
-        status = ucp_atomic_post(ep, UCP_ATOMIC_POST_OP_ADD, 1,
-                                 sizeof(uint64_t), remote_addr, rkey);
-        if (status != UCS_OK) {
-            OSC_UCX_VERBOSE(1, "ucp_atomic_post failed: %d", status);
+        uint64_t remote_addr = module->state_addrs[module->start_grp_ranks[i]] + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; // write to state.complete_count on remote side
+
+        ret = opal_common_ucx_wpmem_post(module->mem, UCP_ATOMIC_POST_OP_ADD,
+                                         1, module->start_grp_ranks[i], sizeof(uint64_t),
+                                         remote_addr);
+        if (ret != OMPI_SUCCESS) {
+            OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_post failed: %d", ret);
         }
 
-        ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker);
-        if (OMPI_SUCCESS != ret) {
-            OSC_UCX_VERBOSE(1, "opal_common_ucx_ep_flush failed: %d", ret);
+        ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP,
+                                          module->start_grp_ranks[i]);
+        if (ret != OMPI_SUCCESS) {
+            return ret;
         }
     }
 
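
Note: ompi_osc_ucx_complete now signals each target with a wpool atomic post (add 1 to the remote complete_count) followed by an endpoint-scope flush, replacing the ucp_atomic_post/ep_flush pair; the per-target ep/rkey bookkeeping is gone because the memory handle plus a target rank identifies the destination. The ordering is the load-bearing part of the contract: RMA traffic to a target must be remotely complete before the counter increment becomes visible to the waiter. A local model of that contract using C11 release/acquire atomics — names are illustrative, not the OMPI API:

    /* complete_wait.c - models complete/wait: the poster publishes its
     * payload, then increments complete_count with release semantics
     * (standing in for "flush, then atomic post"); a waiter that
     * observes the increment may safely read the payload, which is what
     * ompi_osc_ucx_wait relies on at the end of this diff. */
    #include <stdatomic.h>
    #include <stdint.h>
    #include <stdio.h>

    typedef struct {
        uint64_t data;                    /* the exposed window memory */
        _Atomic uint64_t complete_count;  /* state.complete_count */
    } peer_state_t;

    static void complete_one(peer_state_t *peer, uint64_t payload) {
        peer->data = payload;             /* the preceding RMA (put/acc) */
        /* real code: wpmem_flush(EP scope) before wpmem_post; modeled
         * here by the release ordering on the increment */
        atomic_fetch_add_explicit(&peer->complete_count, 1,
                                  memory_order_release);
    }

    static uint64_t wait_all(peer_state_t *self, uint64_t nposters) {
        /* ompi_osc_ucx_wait spins like this while driving progress */
        while (atomic_load_explicit(&self->complete_count,
                                    memory_order_acquire) != nposters)
            ;
        return self->data;                /* safe: payloads have landed */
    }

    int main(void) {
        peer_state_t me = { 0 };
        complete_one(&me, 42);            /* single-threaded demo */
        printf("data = %llu\n", (unsigned long long)wait_all(&me, 1));
        return 0;
    }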
@@ -243,25 +235,29 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int assert, struct ompi_win_t
     }
 
     for (i = 0; i < size; i++) {
-        uint64_t remote_addr = (module->state_info_array)[ranks_in_win_grp[i]].addr + OSC_UCX_STATE_POST_INDEX_OFFSET; /* write to state.post_index on remote side */
-        ucp_rkey_h rkey = (module->state_info_array)[ranks_in_win_grp[i]].rkey;
-        ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, ranks_in_win_grp[i]);
+        uint64_t remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_INDEX_OFFSET; // write to state.post_index on remote side
         uint64_t curr_idx = 0, result = 0;
 
         /* do fop first to get an post index */
-        opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_FADD, 1,
-                                     &result, sizeof(result),
-                                     remote_addr, rkey, mca_osc_ucx_component.ucp_worker);
+        ret = opal_common_ucx_wpmem_fetch(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
+                                          1, ranks_in_win_grp[i], &result,
+                                          sizeof(result), remote_addr);
+        if (ret != OMPI_SUCCESS) {
+            return OMPI_ERROR;
+        }
 
         curr_idx = result & (OMPI_OSC_UCX_POST_PEER_MAX - 1);
 
-        remote_addr = (module->state_info_array)[ranks_in_win_grp[i]].addr + OSC_UCX_STATE_POST_STATE_OFFSET + sizeof(uint64_t) * curr_idx;
+        remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_STATE_OFFSET + sizeof(uint64_t) * curr_idx;
 
         /* do cas to send post message */
         do {
-            opal_common_ucx_atomic_cswap(ep, 0, (uint64_t)myrank + 1, &result,
-                                         sizeof(result), remote_addr, rkey,
-                                         mca_osc_ucx_component.ucp_worker);
+            ret = opal_common_ucx_wpmem_cmpswp(module->mem, 0, result,
+                                               myrank + 1, &result, sizeof(result),
+                                               remote_addr);
+            if (ret != OMPI_SUCCESS) {
+                return OMPI_ERROR;
+            }
 
             if (result == 0)
                 break;
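
Note: the post path keeps its two-step protocol, only rebased onto the wpool API: a remote fetch-and-add (wpmem_fetch with UCP_ATOMIC_FETCH_OP_FADD) claims a slot index in the target's fixed-size post ring, then a compare-and-swap (wpmem_cmpswp) publishes rank+1 into that slot, retrying while the slot is occupied. The new code also checks return codes, which the old opal_common_ucx_atomic_* call sites here did not surface. A self-contained single-process model of the ring protocol — POST_PEER_MAX and the function names are illustrative, not the OMPI symbols:

    /* post_ring.c - models the fop+cas post protocol: fetch-add hands out
     * slot indices modulo a power-of-two ring; cas publishes rank+1 into
     * the claimed slot, where 0 means "free". */
    #include <stdatomic.h>
    #include <stdint.h>
    #include <stdio.h>

    #define POST_PEER_MAX 32  /* power of two, like OMPI_OSC_UCX_POST_PEER_MAX */

    static _Atomic uint64_t post_index;                  /* state.post_index */
    static _Atomic uint64_t post_state[POST_PEER_MAX];   /* state.post_state[] */

    static void post(uint64_t myrank) {
        /* "do fop first to get a post index" */
        uint64_t idx = atomic_fetch_add(&post_index, 1) & (POST_PEER_MAX - 1);

        /* "do cas to send post message": spin until the slot reads 0,
         * then install myrank+1 (the consumer resets slots to 0) */
        uint64_t expected;
        do {
            expected = 0;
        } while (!atomic_compare_exchange_strong(&post_state[idx], &expected,
                                                 myrank + 1));
    }

    int main(void) {
        for (uint64_t r = 0; r < 4; r++)
            post(r);
        for (int i = 0; i < 4; i++)
            printf("slot %d <- rank %llu\n", i,
                   (unsigned long long)(atomic_load(&post_state[i]) - 1));
        return 0;
    }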
@@ -302,7 +298,7 @@ int ompi_osc_ucx_wait(struct ompi_win_t *win) {
 
     while (module->state.complete_count != (uint64_t)size) {
         /* not sure if this is required */
-        ucp_worker_progress(mca_osc_ucx_component.ucp_worker);
+        opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool);
     }
 
     module->state.complete_count = 0;