Skip to content

Commit 9de128a

Browse files
author
Sergey Oblomov
committed
OSHMEM: added processing of zero-length collectives
- according to spec 1.4, annex C, shmem collectives should process calls where the number of elements is zero independently of the pointer value - added zero-count processing - it just calls a barrier to sync ranks Signed-off-by: Sergey Oblomov <sergeyo@mellanox.com>
1 parent f05ebe8 commit 9de128a

File tree

10 files changed

+56
-35
lines changed

10 files changed

+56
-35
lines changed

oshmem/mca/scoll/basic/scoll_basic_alltoall.c

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,17 @@ int mca_scoll_basic_alltoall(struct oshmem_group_t *group,
6161
return OSHMEM_ERR_BAD_PARAM;
6262
}
6363

64-
if ((sst == 1) && (dst == 1)) {
65-
rc = a2a_alg_simple(group, target, source, nelems, element_size);
66-
} else {
67-
rc = a2as_alg_simple(group, target, source, dst, sst, nelems,
68-
element_size);
69-
}
64+
if (nelems) {
65+
if ((sst == 1) && (dst == 1)) {
66+
rc = a2a_alg_simple(group, target, source, nelems, element_size);
67+
} else {
68+
rc = a2as_alg_simple(group, target, source, dst, sst, nelems,
69+
element_size);
70+
}
7071

71-
if (rc != OSHMEM_SUCCESS) {
72-
return rc;
72+
if (rc != OSHMEM_SUCCESS) {
73+
return rc;
74+
}
7375
}
7476

7577
/* quiet is needed because scoll level barrier does not

oshmem/mca/scoll/basic/scoll_basic_broadcast.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ static int _algorithm_central_counter(struct oshmem_group_t *group,
131131
group->my_pe, pSync[0], PE_root);
132132

133133
/* Check if this PE is the root */
134-
if (PE_root == group->my_pe) {
134+
if ((PE_root == group->my_pe) && nlong) {
135135
int pe_cur = 0;
136136

137137
SCOLL_VERBOSE(14,
@@ -192,6 +192,16 @@ static int _algorithm_binomial_tree(struct oshmem_group_t *group,
192192
"[#%d] pSync[0] = %ld root = #%d",
193193
group->my_pe, pSync[0], PE_root);
194194

195+
if (OPAL_UNLIKELY(!nlong)) {
196+
SCOLL_VERBOSE(14, "[#%d] Wait for operation completion", group->my_pe);
197+
/* wait until root finishes sending data */
198+
rc = BARRIER_FUNC(group,
199+
(pSync + 1),
200+
SCOLL_DEFAULT_ALG);
201+
return rc;
202+
}
203+
204+
195205
vrank = (my_id + group->proc_count - root_id) % group->proc_count;
196206
hibit = opal_hibit(vrank, dim);
197207

oshmem/mca/scoll/basic/scoll_basic_collect.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ int mca_scoll_basic_collect(struct oshmem_group_t *group,
6666
if ((rc == OSHMEM_SUCCESS) && oshmem_proc_group_is_member(group)) {
6767
int i = 0;
6868

69-
if (nlong_type) {
69+
if (nlong_type && nlong) {
7070
alg = (alg == SCOLL_DEFAULT_ALG ?
7171
mca_scoll_basic_param_collect_algorithm : alg);
7272
switch (alg) {
@@ -156,7 +156,7 @@ static int _algorithm_f_central_counter(struct oshmem_group_t *group,
156156
group->my_pe);
157157
SCOLL_VERBOSE(15, "[#%d] pSync[0] = %ld", group->my_pe, pSync[0]);
158158

159-
if (PE_root == group->my_pe) {
159+
if ((PE_root == group->my_pe) && nlong) {
160160
int pe_cur = 0;
161161

162162
memcpy((void*) ((unsigned char*) target + 0 * nlong),
@@ -543,7 +543,7 @@ static int _algorithm_central_collector(struct oshmem_group_t *group,
543543
/* Set own data size */
544544
pSync[0] = (nlong ? (long)nlong : SHMEM_SYNC_READY);
545545

546-
if (PE_root == group->my_pe) {
546+
if ((PE_root == group->my_pe) && nlong) {
547547
long value = 0;
548548
int pe_cur = 0;
549549
long wait_pe_count = 0;

oshmem/mca/scoll/basic/scoll_basic_reduce.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,9 @@ int mca_scoll_basic_reduce(struct oshmem_group_t *group,
7979
int i = 0;
8080

8181
if (pSync) {
82-
alg = (alg == SCOLL_DEFAULT_ALG ?
83-
mca_scoll_basic_param_reduce_algorithm : alg);
82+
alg = (nlong ? (alg == SCOLL_DEFAULT_ALG ?
83+
mca_scoll_basic_param_reduce_algorithm : alg) :
84+
SCOLL_ALG_REDUCE_CENTRAL_COUNTER );
8485
switch (alg) {
8586
case SCOLL_ALG_REDUCE_CENTRAL_COUNTER:
8687
{
@@ -185,7 +186,7 @@ static int _algorithm_central_counter(struct oshmem_group_t *group,
185186

186187
SCOLL_VERBOSE(12, "[#%d] Reduce algorithm: Central Counter", group->my_pe);
187188

188-
if (PE_root == group->my_pe) {
189+
if ((PE_root == group->my_pe) && nlong) {
189190
int pe_cur = 0;
190191
void *target_cur = NULL;
191192

oshmem/mca/scoll/mpi/scoll_mpi_ops.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ int mca_scoll_mpi_broadcast(struct oshmem_group_t *group,
6161
* and considering this contradiction, we cast size_t to int here
6262
* in case if the value is less than INT_MAX and fallback to previous module otherwise. */
6363
#ifdef INCOMPATIBLE_SHMEM_OMPI_COLL_APIS
64-
if (INT_MAX < nlong) {
64+
if ((INT_MAX < nlong) || !nlong) {
6565
MPI_COLL_VERBOSE(20,"RUNNING FALLBACK BCAST");
6666
PREVIOUS_SCOLL_FN(mpi_module, broadcast, group,
6767
PE_root,
@@ -104,7 +104,7 @@ int mca_scoll_mpi_collect(struct oshmem_group_t *group,
104104
void *sbuf, *rbuf;
105105
MPI_COLL_VERBOSE(20,"RUNNING MPI ALLGATHER");
106106
mpi_module = (mca_scoll_mpi_module_t *) group->g_scoll.scoll_collect_module;
107-
if (nlong_type == true) {
107+
if ((nlong_type == true) && nlong) {
108108
sbuf = (void *) source;
109109
rbuf = target;
110110
stype = &ompi_mpi_char.dt;
@@ -184,7 +184,7 @@ int mca_scoll_mpi_reduce(struct oshmem_group_t *group,
184184
* and considering this contradiction, we cast size_t to int here
185185
* in case if the value is less than INT_MAX and fallback to previous module otherwise. */
186186
#ifdef INCOMPATIBLE_SHMEM_OMPI_COLL_APIS
187-
if (INT_MAX < count) {
187+
if ((INT_MAX < count) || !nlong) {
188188
MPI_COLL_VERBOSE(20,"RUNNING FALLBACK REDUCE");
189189
PREVIOUS_SCOLL_FN(mpi_module, reduce, group,
190190
op,

oshmem/runtime/runtime.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,13 @@ OSHMEM_DECLSPEC int oshmem_shmem_register_params(void);
200200
RUNTIME_CHECK_ERROR("Required address %p is not in symmetric space\n", ((void*)x)); \
201201
oshmem_shmem_abort(-1); \
202202
}
203+
/* Check if address is in symmetric space or size is zero */
204+
#define RUNTIME_CHECK_ADDR_SIZE(x,s) \
205+
if (OPAL_UNLIKELY((s) && !MCA_MEMHEAP_CALL(is_symmetric_addr((x))))) \
206+
{ \
207+
RUNTIME_CHECK_ERROR("Required address %p is not in symmetric space\n", ((void*)x)); \
208+
oshmem_shmem_abort(-1); \
209+
}
203210
#define RUNTIME_CHECK_WITH_MEMHEAP_SIZE(x) \
204211
if (OPAL_UNLIKELY((long)(x) > MCA_MEMHEAP_CALL(size))) \
205212
{ \
@@ -212,6 +219,7 @@ OSHMEM_DECLSPEC int oshmem_shmem_register_params(void);
212219
#define RUNTIME_CHECK_INIT()
213220
#define RUNTIME_CHECK_PE(x)
214221
#define RUNTIME_CHECK_ADDR(x)
222+
#define RUNTIME_CHECK_ADDR_SIZE(x,s)
215223
#define RUNTIME_CHECK_WITH_MEMHEAP_SIZE(x)
216224

217225
#endif /* OSHMEM_PARAM_CHECK */

oshmem/shmem/c/shmem_alltoall.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ static void _shmem_alltoall(void *target,
3030
int PE_size,
3131
long *pSync);
3232

33-
#define SHMEM_TYPE_ALLTOALL(name, element_size) \
33+
#define SHMEM_TYPE_ALLTOALL(name, element_size) \
3434
void shmem##name(void *target, \
3535
const void *source, \
3636
size_t nelems, \
@@ -40,15 +40,15 @@ static void _shmem_alltoall(void *target,
4040
long *pSync) \
4141
{ \
4242
RUNTIME_CHECK_INIT(); \
43-
RUNTIME_CHECK_ADDR(target); \
44-
RUNTIME_CHECK_ADDR(source); \
43+
RUNTIME_CHECK_ADDR_SIZE(target, nelems); \
44+
RUNTIME_CHECK_ADDR_SIZE(source, nelems); \
4545
\
4646
_shmem_alltoall(target, source, 1, 1, nelems, element_size, \
4747
PE_start, logPE_stride, PE_size, \
4848
pSync); \
4949
}
5050

51-
#define SHMEM_TYPE_ALLTOALLS(name, element_size) \
51+
#define SHMEM_TYPE_ALLTOALLS(name, element_size) \
5252
void shmem##name(void *target, \
5353
const void *source, \
5454
ptrdiff_t dst, ptrdiff_t sst, \
@@ -59,8 +59,8 @@ static void _shmem_alltoall(void *target,
5959
long *pSync) \
6060
{ \
6161
RUNTIME_CHECK_INIT(); \
62-
RUNTIME_CHECK_ADDR(target); \
63-
RUNTIME_CHECK_ADDR(source); \
62+
RUNTIME_CHECK_ADDR_SIZE(target, nelems); \
63+
RUNTIME_CHECK_ADDR_SIZE(source, nelems); \
6464
\
6565
_shmem_alltoall(target, source, dst, sst, nelems, element_size, \
6666
PE_start, logPE_stride, PE_size, \

oshmem/shmem/c/shmem_broadcast.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ static void _shmem_broadcast(void *target,
2929
int PE_size,
3030
long *pSync);
3131

32-
#define SHMEM_TYPE_BROADCAST(name, element_size) \
32+
#define SHMEM_TYPE_BROADCAST(name, element_size) \
3333
void shmem##name( void *target, \
3434
const void *source, \
3535
size_t nelems, \
@@ -40,10 +40,10 @@ static void _shmem_broadcast(void *target,
4040
long *pSync) \
4141
{ \
4242
RUNTIME_CHECK_INIT(); \
43-
RUNTIME_CHECK_ADDR(target); \
44-
RUNTIME_CHECK_ADDR(source); \
43+
RUNTIME_CHECK_ADDR_SIZE(target, nelems); \
44+
RUNTIME_CHECK_ADDR_SIZE(source, nelems); \
4545
\
46-
_shmem_broadcast( target, source, nelems * element_size, \
46+
_shmem_broadcast( target, source, nelems * element_size, \
4747
PE_root, PE_start, logPE_stride, PE_size, \
4848
pSync); \
4949
}

oshmem/shmem/c/shmem_collect.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ static void _shmem_collect(void *target,
3939
long *pSync) \
4040
{ \
4141
RUNTIME_CHECK_INIT(); \
42-
RUNTIME_CHECK_ADDR(target); \
43-
RUNTIME_CHECK_ADDR(source); \
42+
RUNTIME_CHECK_ADDR_SIZE(target, nelems); \
43+
RUNTIME_CHECK_ADDR_SIZE(source, nelems); \
4444
\
45-
_shmem_collect( target, source, nelems * element_size, \
45+
_shmem_collect( target, source, nelems * element_size, \
4646
PE_start, logPE_stride, PE_size, \
4747
pSync, \
4848
nelems_type); \

oshmem/shmem/c/shmem_reduce.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
* object of every PE in the active set. The active set of PEs is defined by the triple PE_start,
2727
* logPE_stride and PE_size.
2828
*/
29-
#define SHMEM_TYPE_REDUCE_OP(name, type_name, type, prefix) \
30-
void prefix##type_name##_##name##_to_all( type *target, \
29+
#define SHMEM_TYPE_REDUCE_OP(name, type_name, type, prefix) \
30+
void prefix##type_name##_##name##_to_all( type *target, \
3131
const type *source, \
3232
int nreduce, \
3333
int PE_start, \
@@ -40,8 +40,8 @@
4040
oshmem_group_t* group = NULL; \
4141
\
4242
RUNTIME_CHECK_INIT(); \
43-
RUNTIME_CHECK_ADDR(target); \
44-
RUNTIME_CHECK_ADDR(source); \
43+
RUNTIME_CHECK_ADDR_SIZE(target, nreduce); \
44+
RUNTIME_CHECK_ADDR_SIZE(source, nreduce); \
4545
\
4646
{ \
4747
group = oshmem_proc_group_create_nofail(PE_start, 1<<logPE_stride, PE_size); \

0 commit comments

Comments
 (0)