Skip to content

Commit a1b02c6

Browse files
authored
Merge pull request #11893 from tvegas1/smkey_store
oshmem: Add symmetric remote key handling code
2 parents 3de90b1 + 776e8ba commit a1b02c6

File tree

5 files changed

+241
-31
lines changed

5 files changed

+241
-31
lines changed

config/ompi_check_ucx.m4

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
108108
UCP_PARAM_FIELD_ESTIMATED_NUM_PPN,
109109
UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK,
110110
UCP_OP_ATTR_FLAG_MULTI_SEND,
111-
UCS_MEMORY_TYPE_RDMA],
111+
UCS_MEMORY_TYPE_RDMA,
112+
UCP_MEM_MAP_SYMMETRIC_RKEY],
112113
[], [],
113114
[#include <ucp/api/ucp.h>])
114115
AC_CHECK_DECLS([UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS],
@@ -124,7 +125,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
124125
[#include <ucp/api/ucp.h>])
125126
AC_CHECK_DECLS([ucp_tag_send_nbx,
126127
ucp_tag_send_sync_nbx,
127-
ucp_tag_recv_nbx],
128+
ucp_tag_recv_nbx,
129+
ucp_rkey_compare],
128130
[], [],
129131
[#include <ucp/api/ucp.h>])
130132
AC_CHECK_TYPES([ucp_request_param_t],

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 180 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "opal/datatype/opal_convertor.h"
2323
#include "opal/mca/common/ucx/common_ucx.h"
2424
#include "opal/util/opal_environ.h"
25+
#include "opal/util/minmax.h"
2526
#include "ompi/datatype/ompi_datatype.h"
2627
#include "ompi/mca/pml/pml.h"
2728

@@ -126,6 +127,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
126127
};
127128
#endif
128129

130+
unsigned
131+
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx)
132+
{
133+
#if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
134+
if (spml_ucx->symmetric_rkey_max_count > 0) {
135+
return UCP_MEM_MAP_SYMMETRIC_RKEY;
136+
}
137+
#endif
138+
139+
return 0;
140+
}
141+
142+
void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store)
143+
{
144+
store->array = NULL;
145+
store->count = 0;
146+
store->size = 0;
147+
}
148+
149+
void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store)
150+
{
151+
int i;
152+
153+
for (i = 0; i < store->count; i++) {
154+
if (store->array[i].refcnt != 0) {
155+
SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0",
156+
i, store->count, store->array[i].refcnt);
157+
}
158+
159+
ucp_rkey_destroy(store->array[i].rkey);
160+
}
161+
162+
free(store->array);
163+
}
164+
165+
/**
166+
* Find position in sorted array for existing or future entry
167+
*
168+
* @param[in] store Store of the entries
169+
* @param[in] worker Common worker for rkeys used
170+
* @param[in] rkey Remote key to search for
171+
* @param[out] index Index of entry
172+
*
173+
* @return
174+
* OSHMEM_ERR_NOT_FOUND: index contains the position where future element
175+
* should be inserted to keep array sorted
176+
* OSHMEM_SUCCESS : index contains the position of the element
177+
* Other error : index is not valid
178+
*/
179+
static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store,
180+
const ucp_worker_h worker,
181+
const ucp_rkey_h rkey,
182+
int *index)
183+
{
184+
#if HAVE_DECL_UCP_RKEY_COMPARE
185+
ucp_rkey_compare_params_t params;
186+
int i, result, m, end;
187+
ucs_status_t status;
188+
189+
for (i = 0, end = store->count; i < end;) {
190+
m = (i + end) / 2;
191+
192+
params.field_mask = 0;
193+
status = ucp_rkey_compare(worker, store->array[m].rkey,
194+
rkey, &params, &result);
195+
if (status != UCS_OK) {
196+
return OSHMEM_ERROR;
197+
} else if (result == 0) {
198+
*index = m;
199+
return OSHMEM_SUCCESS;
200+
} else if (result > 0) {
201+
end = m;
202+
} else {
203+
i = m + 1;
204+
}
205+
}
206+
207+
*index = i;
208+
return OSHMEM_ERR_NOT_FOUND;
209+
#else
210+
return OSHMEM_ERROR;
211+
#endif
212+
}
213+
214+
static void mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store,
215+
int i, ucp_rkey_h rkey)
216+
{
217+
int size;
218+
mca_spml_ucx_rkey_t *tmp;
219+
220+
if (store->count >= mca_spml_ucx.symmetric_rkey_max_count) {
221+
return;
222+
}
223+
224+
if (store->count >= store->size) {
225+
size = opal_min(opal_max(store->size, 8) * 2,
226+
mca_spml_ucx.symmetric_rkey_max_count);
227+
tmp = realloc(store->array, size * sizeof(*store->array));
228+
if (tmp == NULL) {
229+
return;
230+
}
231+
232+
store->array = tmp;
233+
store->size = size;
234+
}
235+
236+
memmove(&store->array[i + 1], &store->array[i],
237+
(store->count - i) * sizeof(*store->array));
238+
store->array[i].rkey = rkey;
239+
store->array[i].refcnt = 1;
240+
store->count++;
241+
return;
242+
}
243+
244+
/* Takes ownership of input ucp remote key */
245+
static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store,
246+
ucp_worker_h worker,
247+
ucp_rkey_h rkey)
248+
{
249+
int ret, i;
250+
251+
if (mca_spml_ucx.symmetric_rkey_max_count == 0) {
252+
return rkey;
253+
}
254+
255+
ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
256+
if (ret == OSHMEM_SUCCESS) {
257+
ucp_rkey_destroy(rkey);
258+
store->array[i].refcnt++;
259+
return store->array[i].rkey;
260+
}
261+
262+
if (ret == OSHMEM_ERR_NOT_FOUND) {
263+
mca_spml_ucx_rkey_store_insert(store, i, rkey);
264+
}
265+
266+
return rkey;
267+
}
268+
269+
static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store,
270+
ucp_worker_h worker,
271+
ucp_rkey_h rkey)
272+
{
273+
mca_spml_ucx_rkey_t *entry;
274+
int ret, i;
275+
276+
ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
277+
if (ret != OSHMEM_SUCCESS) {
278+
goto out;
279+
}
280+
281+
entry = &store->array[i];
282+
assert(entry->rkey == rkey);
283+
if (--entry->refcnt > 0) {
284+
return;
285+
}
286+
287+
memmove(&store->array[i], &store->array[i + 1],
288+
(store->count - (i + 1)) * sizeof(*store->array));
289+
store->count--;
290+
291+
out:
292+
ucp_rkey_destroy(rkey);
293+
}
294+
129295
int mca_spml_ucx_enable(bool enable)
130296
{
131297
SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****");
@@ -240,6 +406,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
240406
{
241407
int rc;
242408
ucs_status_t err;
409+
ucp_rkey_h rkey;
243410

244411
rc = mca_spml_ucx_ctx_mkey_new(ucx_ctx, pe, segno, ucx_mkey);
245412
if (OSHMEM_SUCCESS != rc) {
@@ -248,11 +415,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
248415
}
249416

250417
if (mkey->u.data) {
251-
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &((*ucx_mkey)->rkey));
418+
err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &rkey);
252419
if (UCS_OK != err) {
253420
SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err));
254421
return OSHMEM_ERROR;
255422
}
423+
424+
if (!oshmem_proc_on_local_node(pe)) {
425+
rkey = mca_spml_ucx_rkey_store_get(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], rkey);
426+
}
427+
428+
(*ucx_mkey)->rkey = rkey;
429+
256430
rc = mca_spml_ucx_ctx_mkey_cache(ucx_ctx, mkey, segno, pe);
257431
if (OSHMEM_SUCCESS != rc) {
258432
SPML_UCX_ERROR("mca_spml_ucx_ctx_mkey_cache failed");
@@ -267,7 +441,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
267441
ucp_peer_t *ucp_peer;
268442
int rc;
269443
ucp_peer = &(ucx_ctx->ucp_peers[pe]);
270-
ucp_rkey_destroy(ucx_mkey->rkey);
444+
mca_spml_ucx_rkey_store_put(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], ucx_mkey->rkey);
271445
ucx_mkey->rkey = NULL;
272446
rc = mca_spml_ucx_peer_mkey_cache_del(ucp_peer, segno);
273447
if(OSHMEM_SUCCESS != rc){
@@ -725,7 +899,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
725899
UCP_MEM_MAP_PARAM_FIELD_FLAGS;
726900
mem_map_params.address = addr;
727901
mem_map_params.length = size;
728-
mem_map_params.flags = flags;
902+
mem_map_params.flags = flags |
903+
mca_spml_ucx_mem_map_flags_symmetric_rkey(&mca_spml_ucx);
729904

730905
status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &mem_h);
731906
if (UCS_OK != status) {
@@ -917,6 +1092,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
9171092
}
9181093
}
9191094

1095+
mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store);
1096+
9201097
*ucx_ctx_p = ucx_ctx;
9211098

9221099
return OSHMEM_SUCCESS;

oshmem/mca/spml/ucx/spml_ucx.h

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,31 @@ struct ucp_peer {
7676
size_t mkeys_cnt;
7777
};
7878
typedef struct ucp_peer ucp_peer_t;
79-
79+
80+
/* An rkey_store entry */
81+
typedef struct mca_spml_ucx_rkey {
82+
ucp_rkey_h rkey;
83+
int refcnt;
84+
} mca_spml_ucx_rkey_t;
85+
86+
typedef struct mca_spml_ucx_rkey_store {
87+
mca_spml_ucx_rkey_t *array;
88+
int size;
89+
int count;
90+
} mca_spml_ucx_rkey_store_t;
91+
8092
struct mca_spml_ucx_ctx {
81-
ucp_worker_h *ucp_worker;
82-
ucp_peer_t *ucp_peers;
83-
long options;
84-
opal_bitmap_t put_op_bitmap;
85-
unsigned long nb_progress_cnt;
86-
unsigned int ucp_workers;
87-
int *put_proc_indexes;
88-
unsigned put_proc_count;
89-
bool synchronized_quiet;
90-
int strong_sync;
93+
ucp_worker_h *ucp_worker;
94+
ucp_peer_t *ucp_peers;
95+
long options;
96+
opal_bitmap_t put_op_bitmap;
97+
unsigned long nb_progress_cnt;
98+
unsigned int ucp_workers;
99+
int *put_proc_indexes;
100+
unsigned put_proc_count;
101+
bool synchronized_quiet;
102+
int strong_sync;
103+
mca_spml_ucx_rkey_store_t rkey_store;
91104
};
92105
typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t;
93106

@@ -128,6 +141,7 @@ struct mca_spml_ucx {
128141
unsigned long nb_ucp_worker_progress;
129142
unsigned int ucp_workers;
130143
unsigned int ucp_worker_cnt;
144+
int symmetric_rkey_max_count;
131145
};
132146
typedef struct mca_spml_ucx mca_spml_ucx_t;
133147

@@ -280,6 +294,11 @@ extern int mca_spml_ucx_team_fcollect(shmem_team_t team, void
280294
extern int mca_spml_ucx_team_reduce(shmem_team_t team, void
281295
*dest, const void *source, size_t nreduce, int operation, int datatype);
282296

297+
extern unsigned
298+
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx);
299+
300+
extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store);
301+
extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store);
283302

284303
static inline int
285304
mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey)

0 commit comments

Comments
 (0)