Skip to content

Commit 13a8e42

Browse files
authored
Merge pull request #6163 from artpol84/osc/mt_submission
Refactoring of osc/ucx component for MT
2 parents 170d5d1 + 91d6115 commit 13a8e42

12 files changed

+2254
-780
lines changed

ompi/mca/osc/ucx/osc_ucx.h

Lines changed: 12 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -15,25 +15,19 @@
1515
#include "ompi/group/group.h"
1616
#include "ompi/communicator/communicator.h"
1717
#include "opal/mca/common/ucx/common_ucx.h"
18+
#include "opal/mca/common/ucx/common_ucx_wpool.h"
1819

1920
#define OSC_UCX_ASSERT MCA_COMMON_UCX_ASSERT
2021
#define OSC_UCX_ERROR MCA_COMMON_UCX_ERROR
2122
#define OSC_UCX_VERBOSE MCA_COMMON_UCX_VERBOSE
2223

2324
#define OMPI_OSC_UCX_POST_PEER_MAX 32
2425
#define OMPI_OSC_UCX_ATTACH_MAX 32
25-
#define OMPI_OSC_UCX_RKEY_BUF_MAX 1024
26-
27-
typedef struct ompi_osc_ucx_win_info {
28-
ucp_rkey_h rkey;
29-
uint64_t addr;
30-
bool rkey_init;
31-
} ompi_osc_ucx_win_info_t;
26+
#define OMPI_OSC_UCX_MEM_ADDR_MAX_LEN 1024
3227

3328
typedef struct ompi_osc_ucx_component {
3429
ompi_osc_base_component_t super;
35-
ucp_context_h ucp_context;
36-
ucp_worker_h ucp_worker;
30+
opal_common_ucx_wpool_t *wpool;
3731
bool enable_mpi_threads;
3832
opal_free_list_t requests; /* request free list for the r* communication variants */
3933
bool env_initialized; /* UCX environment is initialized or not */
@@ -62,7 +56,6 @@ typedef struct ompi_osc_ucx_epoch_type {
6256
#define TARGET_LOCK_EXCLUSIVE ((uint64_t)(0x0000000100000000ULL))
6357

6458
#define OSC_UCX_IOVEC_MAX 128
65-
#define OSC_UCX_OPS_THRESHOLD 1000000
6659

6760
#define OSC_UCX_STATE_LOCK_OFFSET 0
6861
#define OSC_UCX_STATE_REQ_FLAG_OFFSET sizeof(uint64_t)
@@ -75,11 +68,13 @@ typedef struct ompi_osc_ucx_epoch_type {
7568
typedef struct ompi_osc_dynamic_win_info {
7669
uint64_t base;
7770
size_t size;
78-
char rkey_buffer[OMPI_OSC_UCX_RKEY_BUF_MAX];
71+
char mem_addr[OMPI_OSC_UCX_MEM_ADDR_MAX_LEN];
7972
} ompi_osc_dynamic_win_info_t;
8073

8174
typedef struct ompi_osc_local_dynamic_win_info {
82-
ucp_mem_h memh;
75+
opal_common_ucx_wpmem_t *mem;
76+
char *my_mem_addr;
77+
int my_mem_addr_size;
8378
int refcnt;
8479
} ompi_osc_local_dynamic_win_info_t;
8580

@@ -97,12 +92,10 @@ typedef struct ompi_osc_ucx_state {
9792
typedef struct ompi_osc_ucx_module {
9893
ompi_osc_base_module_t super;
9994
struct ompi_communicator_t *comm;
100-
ucp_mem_h memh; /* remote accessible memory */
10195
int flavor;
10296
size_t size;
103-
ucp_mem_h state_memh;
104-
ompi_osc_ucx_win_info_t *win_info_array;
105-
ompi_osc_ucx_win_info_t *state_info_array;
97+
uint64_t *addrs;
98+
uint64_t *state_addrs;
10699
int disp_unit; /* if disp_unit >= 0, then everyone has the same
107100
* disp unit size; if disp_unit == -1, then we
108101
* need to look at disp_units */
@@ -117,11 +110,12 @@ typedef struct ompi_osc_ucx_module {
117110
opal_list_t pending_posts;
118111
int lock_count;
119112
int post_count;
120-
int global_ops_num;
121-
int *per_target_ops_nums;
122113
uint64_t req_result;
123114
int *start_grp_ranks;
124115
bool lock_all_is_nocheck;
116+
opal_common_ucx_ctx_t *ctx;
117+
opal_common_ucx_wpmem_t *mem;
118+
opal_common_ucx_wpmem_t *state_mem;
125119
} ompi_osc_ucx_module_t;
126120

127121
typedef enum locktype {
@@ -216,7 +210,4 @@ int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_
216210
int min_index, int max_index,
217211
uint64_t base, size_t len, int *insert);
218212

219-
void req_completion(void *request, ucs_status_t status);
220-
void internal_req_init(void *request);
221-
222213
#endif /* OMPI_OSC_UCX_H */

ompi/mca/osc/ucx/osc_ucx_active_target.c

Lines changed: 30 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ static inline void ompi_osc_ucx_handle_incoming_post(ompi_osc_ucx_module_t *modu
6060

6161
int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win) {
6262
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
63-
int ret;
63+
int ret = OMPI_SUCCESS;
6464

6565
if (module->epoch_type.access != NONE_EPOCH &&
6666
module->epoch_type.access != FENCE_EPOCH) {
@@ -74,16 +74,12 @@ int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win) {
7474
}
7575

7676
if (!(assert & MPI_MODE_NOPRECEDE)) {
77-
ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker);
77+
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0/*ignore*/);
7878
if (ret != OMPI_SUCCESS) {
7979
return ret;
8080
}
8181
}
8282

83-
module->global_ops_num = 0;
84-
memset(module->per_target_ops_nums, 0,
85-
sizeof(int) * ompi_comm_size(module->comm));
86-
8783
return module->comm->c_coll->coll_barrier(module->comm,
8884
module->comm->c_coll->coll_barrier_module);
8985
}
@@ -147,7 +143,7 @@ int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t
147143

148144
ompi_osc_ucx_handle_incoming_post(module, &(module->state.post_state[i]), ranks_in_win_grp, size);
149145
}
150-
ucp_worker_progress(mca_osc_ucx_component.ucp_worker);
146+
opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool);
151147
}
152148

153149
module->post_count = 0;
@@ -163,7 +159,6 @@ int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t
163159

164160
int ompi_osc_ucx_complete(struct ompi_win_t *win) {
165161
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
166-
ucs_status_t status;
167162
int i, size;
168163
int ret = OMPI_SUCCESS;
169164

@@ -173,29 +168,26 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) {
173168

174169
module->epoch_type.access = NONE_EPOCH;
175170

176-
ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker);
171+
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0/*ignore*/);
177172
if (ret != OMPI_SUCCESS) {
178173
return ret;
179174
}
180-
module->global_ops_num = 0;
181-
memset(module->per_target_ops_nums, 0,
182-
sizeof(int) * ompi_comm_size(module->comm));
183175

184176
size = ompi_group_size(module->start_group);
185177
for (i = 0; i < size; i++) {
186-
uint64_t remote_addr = (module->state_info_array)[module->start_grp_ranks[i]].addr + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; /* write to state.complete_count on remote side */
187-
ucp_rkey_h rkey = (module->state_info_array)[module->start_grp_ranks[i]].rkey;
188-
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, module->start_grp_ranks[i]);
189-
190-
status = ucp_atomic_post(ep, UCP_ATOMIC_POST_OP_ADD, 1,
191-
sizeof(uint64_t), remote_addr, rkey);
192-
if (status != UCS_OK) {
193-
OSC_UCX_VERBOSE(1, "ucp_atomic_post failed: %d", status);
178+
uint64_t remote_addr = module->state_addrs[module->start_grp_ranks[i]] + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; // write to state.complete_count on remote side
179+
180+
ret = opal_common_ucx_wpmem_post(module->mem, UCP_ATOMIC_POST_OP_ADD,
181+
1, module->start_grp_ranks[i], sizeof(uint64_t),
182+
remote_addr);
183+
if (ret != OMPI_SUCCESS) {
184+
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_post failed: %d", ret);
194185
}
195186

196-
ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker);
197-
if (OMPI_SUCCESS != ret) {
198-
OSC_UCX_VERBOSE(1, "opal_common_ucx_ep_flush failed: %d", ret);
187+
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP,
188+
module->start_grp_ranks[i]);
189+
if (ret != OMPI_SUCCESS) {
190+
return ret;
199191
}
200192
}
201193

@@ -243,25 +235,29 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int assert, struct ompi_win_t
243235
}
244236

245237
for (i = 0; i < size; i++) {
246-
uint64_t remote_addr = (module->state_info_array)[ranks_in_win_grp[i]].addr + OSC_UCX_STATE_POST_INDEX_OFFSET; /* write to state.post_index on remote side */
247-
ucp_rkey_h rkey = (module->state_info_array)[ranks_in_win_grp[i]].rkey;
248-
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, ranks_in_win_grp[i]);
238+
uint64_t remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_INDEX_OFFSET; // write to state.post_index on remote side
249239
uint64_t curr_idx = 0, result = 0;
250240

251241
/* do fop first to get an post index */
252-
opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_FADD, 1,
253-
&result, sizeof(result),
254-
remote_addr, rkey, mca_osc_ucx_component.ucp_worker);
242+
ret = opal_common_ucx_wpmem_fetch(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
243+
1, ranks_in_win_grp[i], &result,
244+
sizeof(result), remote_addr);
245+
if (ret != OMPI_SUCCESS) {
246+
return OMPI_ERROR;
247+
}
255248

256249
curr_idx = result & (OMPI_OSC_UCX_POST_PEER_MAX - 1);
257250

258-
remote_addr = (module->state_info_array)[ranks_in_win_grp[i]].addr + OSC_UCX_STATE_POST_STATE_OFFSET + sizeof(uint64_t) * curr_idx;
251+
remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_STATE_OFFSET + sizeof(uint64_t) * curr_idx;
259252

260253
/* do cas to send post message */
261254
do {
262-
opal_common_ucx_atomic_cswap(ep, 0, (uint64_t)myrank + 1, &result,
263-
sizeof(result), remote_addr, rkey,
264-
mca_osc_ucx_component.ucp_worker);
255+
ret = opal_common_ucx_wpmem_cmpswp(module->mem, 0, result,
256+
myrank + 1, &result, sizeof(result),
257+
remote_addr);
258+
if (ret != OMPI_SUCCESS) {
259+
return OMPI_ERROR;
260+
}
265261

266262
if (result == 0)
267263
break;
@@ -302,7 +298,7 @@ int ompi_osc_ucx_wait(struct ompi_win_t *win) {
302298

303299
while (module->state.complete_count != (uint64_t)size) {
304300
/* not sure if this is required */
305-
ucp_worker_progress(mca_osc_ucx_component.ucp_worker);
301+
opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool);
306302
}
307303

308304
module->state.complete_count = 0;

0 commit comments

Comments
 (0)