Skip to content

Commit 81aeea5

Browse files
authored
Merge pull request #9636 from janjust/master-osc-dbg
master: osc/ucx fixes dynamic windows
2 parents 3684a92 + 9bd8115 commit 81aeea5

File tree

6 files changed

+116
-62
lines changed

6 files changed

+116
-62
lines changed

.mailmap

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,5 @@ Wei-keng Liao <wkliao@users.noreply.github.com>
124124

125125
Samuel K. Gutierrez <samuel@lanl.gov> <samuelkgutierrez@users.noreply.github.com>
126126
Samuel K. Gutierrez <samuel@lanl.gov> <samuel@lanl.gov>
127+
128+
Tomislav Janjusic <tomislavj@nvidia.com> Tomislavj Janjusic <tomislavj@nvidia.com>

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -325,18 +325,18 @@ static inline int end_atomicity(
325325
}
326326

327327
static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module_t *module,
328-
int target) {
328+
int target, int *win_idx) {
329329
uint64_t remote_state_addr = (module->state_addrs)[target] + OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET;
330-
size_t len = sizeof(uint64_t) + sizeof(ompi_osc_dynamic_win_info_t) * OMPI_OSC_UCX_ATTACH_MAX;
331-
char *temp_buf = malloc(len);
330+
size_t remote_state_len = sizeof(uint64_t) + sizeof(ompi_osc_dynamic_win_info_t) * OMPI_OSC_UCX_ATTACH_MAX;
331+
char *temp_buf = calloc(remote_state_len, 1);
332332
ompi_osc_dynamic_win_info_t *temp_dynamic_wins;
333333
uint64_t win_count;
334-
int contain, insert = -1;
334+
int insert = -1;
335335
int ret;
336336

337337
ret = opal_common_ucx_wpmem_putget(module->state_mem, OPAL_COMMON_UCX_GET, target,
338338
(void *)((intptr_t)temp_buf),
339-
len, remote_state_addr);
339+
remote_state_len, remote_state_addr);
340340
if (OPAL_SUCCESS != ret) {
341341
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
342342
ret = OMPI_ERROR;
@@ -350,23 +350,27 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
350350
}
351351

352352
memcpy(&win_count, temp_buf, sizeof(uint64_t));
353-
assert(win_count > 0 && win_count <= OMPI_OSC_UCX_ATTACH_MAX);
353+
if (win_count > OMPI_OSC_UCX_ATTACH_MAX) {
354+
return MPI_ERR_RMA_RANGE;
355+
}
354356

355357
temp_dynamic_wins = (ompi_osc_dynamic_win_info_t *)(temp_buf + sizeof(uint64_t));
356-
contain = ompi_osc_find_attached_region_position(temp_dynamic_wins, 0, win_count,
358+
*win_idx = ompi_osc_find_attached_region_position(temp_dynamic_wins, 0, win_count,
357359
remote_addr, 1, &insert);
358-
assert(contain >= 0 && (uint64_t)contain < win_count);
360+
if (*win_idx < 0 || (uint64_t)*win_idx >= win_count) {
361+
return MPI_ERR_RMA_RANGE;
362+
}
359363

360-
if (module->local_dynamic_win_info[contain].mem->mem_addrs == NULL) {
361-
module->local_dynamic_win_info[contain].mem->mem_addrs = calloc(ompi_comm_size(module->comm),
364+
if (module->local_dynamic_win_info[*win_idx].mem->mem_addrs == NULL) {
365+
module->local_dynamic_win_info[*win_idx].mem->mem_addrs = calloc(ompi_comm_size(module->comm),
362366
OMPI_OSC_UCX_MEM_ADDR_MAX_LEN);
363-
module->local_dynamic_win_info[contain].mem->mem_displs =calloc(ompi_comm_size(module->comm),
367+
module->local_dynamic_win_info[*win_idx].mem->mem_displs = calloc(ompi_comm_size(module->comm),
364368
sizeof(int));
365369
}
366370

367-
memcpy(module->local_dynamic_win_info[contain].mem->mem_addrs + target * OMPI_OSC_UCX_MEM_ADDR_MAX_LEN,
368-
temp_dynamic_wins[contain].mem_addr, OMPI_OSC_UCX_MEM_ADDR_MAX_LEN);
369-
module->local_dynamic_win_info[contain].mem->mem_displs[target] = target * OMPI_OSC_UCX_MEM_ADDR_MAX_LEN;
371+
memcpy(module->local_dynamic_win_info[*win_idx].mem->mem_addrs + target * OMPI_OSC_UCX_MEM_ADDR_MAX_LEN,
372+
temp_dynamic_wins[*win_idx].mem_addr, OMPI_OSC_UCX_MEM_ADDR_MAX_LEN);
373+
module->local_dynamic_win_info[*win_idx].mem->mem_displs[target] = target * OMPI_OSC_UCX_MEM_ADDR_MAX_LEN;
370374

371375
cleanup:
372376
free(temp_buf);
@@ -416,17 +420,20 @@ static int do_atomic_op_intrinsic(
416420
void *result_addr,
417421
ompi_osc_ucx_request_t *ucx_req)
418422
{
419-
int ret = OMPI_SUCCESS;
423+
int ret = OMPI_SUCCESS,
424+
win_idx = -1;
420425
size_t origin_dt_bytes;
426+
opal_common_ucx_wpmem_t *mem = module->mem;
421427
ompi_datatype_type_size(dt, &origin_dt_bytes);
422428

423429
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
424430

425431
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
426-
ret = get_dynamic_win_info(remote_addr, module, target);
432+
ret = get_dynamic_win_info(remote_addr, module, target, &win_idx);
427433
if (ret != OMPI_SUCCESS) {
428434
return ret;
429435
}
436+
mem = module->local_dynamic_win_info[win_idx].mem;
430437
}
431438

432439
ucp_atomic_fetch_op_t opcode;
@@ -454,7 +461,7 @@ static int do_atomic_op_intrinsic(
454461
user_req_ptr = ucx_req;
455462
// issue a fence if this is the last but not the only element
456463
if (0 < i) {
457-
ret = opal_common_ucx_wpmem_fence(module->mem);
464+
ret = opal_common_ucx_wpmem_fence(mem);
458465
if (ret != OMPI_SUCCESS) {
459466
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret);
460467
return OMPI_ERROR;
@@ -466,7 +473,7 @@ static int do_atomic_op_intrinsic(
466473
} else {
467474
value = opal_common_ucx_load_uint64(origin_addr, origin_dt_bytes);
468475
}
469-
ret = opal_common_ucx_wpmem_fetch_nb(module->mem, opcode, value, target,
476+
ret = opal_common_ucx_wpmem_fetch_nb(mem, opcode, value, target,
470477
output_addr, origin_dt_bytes, remote_addr,
471478
user_req_cb, user_req_ptr);
472479

@@ -485,21 +492,23 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
485492
int target, ptrdiff_t target_disp, int target_count,
486493
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
487494
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
495+
opal_common_ucx_wpmem_t *mem = module->mem;
488496
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
489497
bool is_origin_contig = false, is_target_contig = false;
490498
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
491-
int ret = OMPI_SUCCESS;
499+
int ret = OMPI_SUCCESS, win_idx = -1;
492500

493501
ret = check_sync_state(module, target, false);
494502
if (ret != OMPI_SUCCESS) {
495503
return ret;
496504
}
497505

498506
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
499-
ret = get_dynamic_win_info(remote_addr, module, target);
507+
ret = get_dynamic_win_info(remote_addr, module, target, &win_idx);
500508
if (ret != OMPI_SUCCESS) {
501509
return ret;
502510
}
511+
mem = module->local_dynamic_win_info[win_idx].mem;
503512
}
504513

505514
if (!target_count) {
@@ -519,7 +528,7 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
519528
ompi_datatype_type_size(origin_dt, &origin_len);
520529
origin_len *= origin_count;
521530

522-
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_PUT, target,
531+
ret = opal_common_ucx_wpmem_putget(mem, OPAL_COMMON_UCX_PUT, target,
523532
(void *)((intptr_t)origin_addr + origin_lb),
524533
origin_len, remote_addr + target_lb);
525534
if (OPAL_SUCCESS != ret) {
@@ -539,21 +548,23 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
539548
int target, ptrdiff_t target_disp, int target_count,
540549
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
541550
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
551+
opal_common_ucx_wpmem_t *mem = module->mem;
542552
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
543553
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
544554
bool is_origin_contig = false, is_target_contig = false;
545-
int ret = OMPI_SUCCESS;
555+
int ret = OMPI_SUCCESS, win_idx = -1;
546556

547557
ret = check_sync_state(module, target, false);
548558
if (ret != OMPI_SUCCESS) {
549559
return ret;
550560
}
551561

552562
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
553-
ret = get_dynamic_win_info(remote_addr, module, target);
563+
ret = get_dynamic_win_info(remote_addr, module, target, &win_idx);
554564
if (ret != OMPI_SUCCESS) {
555565
return ret;
556566
}
567+
mem = module->local_dynamic_win_info[win_idx].mem;
557568
}
558569

559570
if (!target_count) {
@@ -574,7 +585,7 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
574585
ompi_datatype_type_size(origin_dt, &origin_len);
575586
origin_len *= origin_count;
576587

577-
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET, target,
588+
ret = opal_common_ucx_wpmem_putget(mem, OPAL_COMMON_UCX_GET, target,
578589
(void *)((intptr_t)origin_addr + origin_lb),
579590
origin_len, remote_addr + target_lb);
580591
if (OPAL_SUCCESS != ret) {
@@ -771,9 +782,10 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
771782
int target, ptrdiff_t target_disp,
772783
struct ompi_win_t *win) {
773784
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module;
785+
opal_common_ucx_wpmem_t *mem = module->mem;
774786
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
775787
size_t dt_bytes;
776-
int ret = OMPI_SUCCESS;
788+
int ret = OMPI_SUCCESS, win_idx = -1;
777789
bool lock_acquired = false;
778790

779791
ret = check_sync_state(module, target, false);
@@ -782,10 +794,11 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
782794
}
783795

784796
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
785-
ret = get_dynamic_win_info(remote_addr, module, target);
797+
ret = get_dynamic_win_info(remote_addr, module, target, &win_idx);
786798
if (ret != OMPI_SUCCESS) {
787799
return ret;
788800
}
801+
mem = module->local_dynamic_win_info[win_idx].mem;
789802
}
790803

791804
ompi_datatype_type_size(dt, &dt_bytes);
@@ -803,21 +816,21 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
803816
return ret;
804817
}
805818

806-
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_GET, target,
819+
ret = opal_common_ucx_wpmem_putget(mem, OPAL_COMMON_UCX_GET, target,
807820
&result_addr, dt_bytes, remote_addr);
808821
if (OPAL_SUCCESS != ret) {
809822
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
810823
return OMPI_ERROR;
811824
}
812825

813-
ret = opal_common_ucx_wpmem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target);
826+
ret = opal_common_ucx_wpmem_flush(mem, OPAL_COMMON_UCX_SCOPE_EP, target);
814827
if (ret != OPAL_SUCCESS) {
815828
return ret;
816829
}
817830

818831
if (0 == memcmp(result_addr, compare_addr, dt_bytes)) {
819832
// write the new value
820-
ret = opal_common_ucx_wpmem_putget(module->mem, OPAL_COMMON_UCX_PUT, target,
833+
ret = opal_common_ucx_wpmem_putget(mem, OPAL_COMMON_UCX_PUT, target,
821834
(void*)origin_addr, dt_bytes, remote_addr);
822835
if (OPAL_SUCCESS != ret) {
823836
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret);
@@ -834,7 +847,8 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
834847
struct ompi_win_t *win) {
835848
size_t dt_bytes;
836849
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
837-
int ret = OMPI_SUCCESS;
850+
opal_common_ucx_wpmem_t *mem = module->mem;
851+
int ret = OMPI_SUCCESS, win_idx = -1;
838852

839853
ret = check_sync_state(module, target, false);
840854
if (ret != OMPI_SUCCESS) {
@@ -860,10 +874,11 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
860874
}
861875

862876
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
863-
ret = get_dynamic_win_info(remote_addr, module, target);
877+
ret = get_dynamic_win_info(remote_addr, module, target, &win_idx);
864878
if (ret != OMPI_SUCCESS) {
865879
return ret;
866880
}
881+
mem = module->local_dynamic_win_info[win_idx].mem;
867882
}
868883

869884
value = origin_addr ? opal_common_ucx_load_uint64(origin_addr, dt_bytes) : 0;
@@ -877,7 +892,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
877892
}
878893
}
879894

880-
ret = opal_common_ucx_wpmem_fetch_nb(module->mem, opcode, value, target,
895+
ret = opal_common_ucx_wpmem_fetch_nb(mem, opcode, value, target,
881896
(void *)result_addr, dt_bytes,
882897
remote_addr, NULL, NULL);
883898

@@ -1049,20 +1064,22 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
10491064
struct ompi_datatype_t *target_dt,
10501065
struct ompi_win_t *win, struct ompi_request_t **request) {
10511066
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
1067+
opal_common_ucx_wpmem_t *mem = module->mem;
10521068
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
10531069
ompi_osc_ucx_request_t *ucx_req = NULL;
1054-
int ret = OMPI_SUCCESS;
1070+
int ret = OMPI_SUCCESS, win_idx = -1;
10551071

10561072
ret = check_sync_state(module, target, true);
10571073
if (ret != OMPI_SUCCESS) {
10581074
return ret;
10591075
}
10601076

10611077
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
1062-
ret = get_dynamic_win_info(remote_addr, module, target);
1078+
ret = get_dynamic_win_info(remote_addr, module, target, &win_idx);
10631079
if (ret != OMPI_SUCCESS) {
10641080
return ret;
10651081
}
1082+
mem = module->local_dynamic_win_info[win_idx].mem;
10661083
}
10671084

10681085
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
@@ -1074,15 +1091,15 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
10741091
return ret;
10751092
}
10761093

1077-
ret = opal_common_ucx_wpmem_fence(module->mem);
1094+
ret = opal_common_ucx_wpmem_fence(mem);
10781095
if (ret != OMPI_SUCCESS) {
10791096
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret);
10801097
return OMPI_ERROR;
10811098
}
10821099

10831100
mca_osc_ucx_component.num_incomplete_req_ops++;
10841101
/* TODO: investigate whether ucp_worker_flush_nb is a better choice here */
1085-
ret = opal_common_ucx_wpmem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
1102+
ret = opal_common_ucx_wpmem_fetch_nb(mem, UCP_ATOMIC_FETCH_OP_FADD,
10861103
0, target, &(module->req_result),
10871104
sizeof(uint64_t), remote_addr & (~0x7),
10881105
req_completion, ucx_req);
@@ -1102,20 +1119,22 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
11021119
struct ompi_datatype_t *target_dt, struct ompi_win_t *win,
11031120
struct ompi_request_t **request) {
11041121
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
1122+
opal_common_ucx_wpmem_t *mem = module->mem;
11051123
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
11061124
ompi_osc_ucx_request_t *ucx_req = NULL;
1107-
int ret = OMPI_SUCCESS;
1125+
int ret = OMPI_SUCCESS, win_idx = -1;
11081126

11091127
ret = check_sync_state(module, target, true);
11101128
if (ret != OMPI_SUCCESS) {
11111129
return ret;
11121130
}
11131131

11141132
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
1115-
ret = get_dynamic_win_info(remote_addr, module, target);
1133+
ret = get_dynamic_win_info(remote_addr, module, target, &win_idx);
11161134
if (ret != OMPI_SUCCESS) {
11171135
return ret;
11181136
}
1137+
mem = module->local_dynamic_win_info[win_idx].mem;
11191138
}
11201139

11211140
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
@@ -1127,15 +1146,15 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
11271146
return ret;
11281147
}
11291148

1130-
ret = opal_common_ucx_wpmem_fence(module->mem);
1149+
ret = opal_common_ucx_wpmem_fence(mem);
11311150
if (ret != OMPI_SUCCESS) {
11321151
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret);
11331152
return OMPI_ERROR;
11341153
}
11351154

11361155
mca_osc_ucx_component.num_incomplete_req_ops++;
11371156
/* TODO: investigate whether ucp_worker_flush_nb is a better choice here */
1138-
ret = opal_common_ucx_wpmem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD,
1157+
ret = opal_common_ucx_wpmem_fetch_nb(mem, UCP_ATOMIC_FETCH_OP_FADD,
11391158
0, target, &(module->req_result),
11401159
sizeof(uint64_t), remote_addr & (~0x7),
11411160
req_completion, ucx_req);

0 commit comments

Comments
 (0)