Skip to content

Commit 4763822

Browse files
committed
pml_ucx: add ompi datatype attribute to release ucp_datatype
Signed-off-by: Yossi Itigin <yosefe@mellanox.com>
1 parent b0e6d1f commit 4763822

File tree

4 files changed

+96
-15
lines changed

4 files changed

+96
-15
lines changed

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "opal/runtime/opal.h"
1818
#include "opal/mca/pmix/pmix.h"
19+
#include "ompi/attribute/attribute.h"
1920
#include "ompi/message/message.h"
2021
#include "ompi/mca/pml/base/pml_base_bsend.h"
2122
#include "opal/mca/common/ucx/common_ucx.h"
@@ -190,9 +191,9 @@ int mca_pml_ucx_close(void)
190191
int mca_pml_ucx_init(void)
191192
{
192193
ucp_worker_params_t params;
193-
ucs_status_t status;
194194
ucp_worker_attr_t attr;
195-
int rc;
195+
ucs_status_t status;
196+
int i, rc;
196197

197198
PML_UCX_VERBOSE(1, "mca_pml_ucx_init");
198199

@@ -209,30 +210,34 @@ int mca_pml_ucx_init(void)
209210
&ompi_pml_ucx.ucp_worker);
210211
if (UCS_OK != status) {
211212
PML_UCX_ERROR("Failed to create UCP worker");
212-
return OMPI_ERROR;
213+
rc = OMPI_ERROR;
214+
goto err;
213215
}
214216

215217
attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
216218
status = ucp_worker_query(ompi_pml_ucx.ucp_worker, &attr);
217219
if (UCS_OK != status) {
218-
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
219-
ompi_pml_ucx.ucp_worker = NULL;
220220
PML_UCX_ERROR("Failed to query UCP worker thread level");
221-
return OMPI_ERROR;
221+
rc = OMPI_ERROR;
222+
goto err_destroy_worker;
222223
}
223224

224-
if (ompi_mpi_thread_multiple && attr.thread_mode != UCS_THREAD_MODE_MULTI) {
225+
if (ompi_mpi_thread_multiple && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) {
225226
/* UCX does not support multithreading, disqualify current PML for now */
226227
/* TODO: we should let OMPI fall back to THREAD_SINGLE mode */
227-
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
228-
ompi_pml_ucx.ucp_worker = NULL;
229228
PML_UCX_ERROR("UCP worker does not support MPI_THREAD_MULTIPLE");
230-
return OMPI_ERROR;
229+
rc = OMPI_ERR_NOT_SUPPORTED;
230+
goto err_destroy_worker;
231231
}
232232

233233
rc = mca_pml_ucx_send_worker_address();
234234
if (rc < 0) {
235-
return rc;
235+
goto err_destroy_worker;
236+
}
237+
238+
ompi_pml_ucx.datatype_attr_keyval = MPI_KEYVAL_INVALID;
239+
for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
240+
ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
236241
}
237242

238243
/* Initialize the free lists */
@@ -249,14 +254,33 @@ int mca_pml_ucx_init(void)
249254
(void *)ompi_pml_ucx.ucp_context,
250255
(void *)ompi_pml_ucx.ucp_worker);
251256
return OMPI_SUCCESS;
257+
258+
err_destroy_worker:
259+
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
260+
ompi_pml_ucx.ucp_worker = NULL;
261+
err:
262+
return OMPI_ERROR;
252263
}
253264

254265
int mca_pml_ucx_cleanup(void)
255266
{
267+
int i;
268+
256269
PML_UCX_VERBOSE(1, "mca_pml_ucx_cleanup");
257270

258271
opal_progress_unregister(mca_pml_ucx_progress);
259272

273+
if (ompi_pml_ucx.datatype_attr_keyval != MPI_KEYVAL_INVALID) {
274+
ompi_attr_free_keyval(TYPE_ATTR, &ompi_pml_ucx.datatype_attr_keyval, false);
275+
}
276+
277+
for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
278+
if (ompi_pml_ucx.predefined_types[i] != PML_UCX_DATATYPE_INVALID) {
279+
ucp_dt_destroy(ompi_pml_ucx.predefined_types[i]);
280+
ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
281+
}
282+
}
283+
260284
ompi_pml_ucx.completed_send_req.req_state = OMPI_REQUEST_INVALID;
261285
OMPI_REQUEST_FINI(&ompi_pml_ucx.completed_send_req);
262286
OBJ_DESTRUCT(&ompi_pml_ucx.completed_send_req);
@@ -398,6 +422,22 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
398422

399423
int mca_pml_ucx_enable(bool enable)
400424
{
425+
ompi_attribute_fn_ptr_union_t copy_fn;
426+
ompi_attribute_fn_ptr_union_t del_fn;
427+
int ret;
428+
429+
/* Create a key for adding custom attributes to datatypes */
430+
copy_fn.attr_datatype_copy_fn =
431+
(MPI_Type_internal_copy_attr_function*)MPI_TYPE_NULL_COPY_FN;
432+
del_fn.attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn;
433+
ret = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn,
434+
&ompi_pml_ucx.datatype_attr_keyval, NULL, 0,
435+
NULL);
436+
if (ret != OMPI_SUCCESS) {
437+
PML_UCX_ERROR("Failed to create keyval for UCX datatypes: %d", ret);
438+
return ret;
439+
}
440+
401441
PML_UCX_FREELIST_INIT(&ompi_pml_ucx.persistent_reqs,
402442
mca_pml_ucx_persistent_request_t,
403443
128, -1, 128);

ompi/mca/pml/ucx/pml_ucx.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "ompi/mca/pml/pml.h"
1616
#include "ompi/mca/pml/base/base.h"
1717
#include "ompi/datatype/ompi_datatype.h"
18+
#include "ompi/datatype/ompi_datatype_internal.h"
1819
#include "ompi/communicator/communicator.h"
1920
#include "ompi/request/request.h"
2021
#include "opal/mca/common/ucx/common_ucx.h"
@@ -42,6 +43,10 @@ struct mca_pml_ucx_module {
4243
ucp_context_h ucp_context;
4344
ucp_worker_h ucp_worker;
4445

46+
/* Datatypes */
47+
int datatype_attr_keyval;
48+
ucp_datatype_t predefined_types[OMPI_DATATYPE_MPI_MAX_PREDEFINED];
49+
4550
/* Requests */
4651
mca_pml_ucx_freelist_t persistent_reqs;
4752
ompi_request_t completed_send_req;

ompi/mca/pml/ucx/pml_ucx_datatype.c

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "pml_ucx_datatype.h"
1111

1212
#include "ompi/runtime/mpiruntime.h"
13+
#include "ompi/attribute/attribute.h"
1314

1415
#include <inttypes.h>
1516

@@ -127,12 +128,25 @@ static ucp_generic_dt_ops_t pml_ucx_generic_datatype_ops = {
127128
.finish = pml_ucx_generic_datatype_finish
128129
};
129130

131+
/**
 * Delete callback registered for the UCX datatype attribute keyval.
 *
 * Destroys the UCX datatype carried in the attribute value and resets the
 * OMPI datatype's pml_data to PML_UCX_DATATYPE_INVALID.
 *
 * @param datatype  OMPI datatype the attribute is attached to.
 * @param keyval    Attribute key (unused).
 * @param attr_val  Attribute value; carries the ucp_datatype_t handle.
 * @param extra     Extra state supplied at keyval creation (unused).
 *
 * @return OMPI_SUCCESS.
 */
int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
                                     void *attr_val, void *extra)
{
    /* Convert the pointer-typed attribute value back to the integer handle
     * via uintptr_t to avoid a pointer/integer size-mismatch cast. */
    ucp_datatype_t ucp_datatype = (ucp_datatype_t)(uintptr_t)attr_val;

    /* pml_data is assigned from / read into ucp_datatype_t elsewhere in this
     * component, so compare the handles as integers rather than comparing a
     * pointer against an integer. */
    PML_UCX_ASSERT(ucp_datatype == datatype->pml_data);

    ucp_dt_destroy(ucp_datatype);
    datatype->pml_data = PML_UCX_DATATYPE_INVALID;
    return OMPI_SUCCESS;
}
142+
130143
ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
131144
{
132145
ucp_datatype_t ucp_datatype;
133146
ucs_status_t status;
134147
ptrdiff_t lb;
135148
size_t size;
149+
int ret;
136150

137151
ompi_datatype_type_lb(datatype, &lb);
138152

@@ -147,16 +161,33 @@ ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
147161
}
148162

149163
status = ucp_dt_create_generic(&pml_ucx_generic_datatype_ops,
150-
datatype, &ucp_datatype);
164+
datatype, &ucp_datatype);
151165
if (status != UCS_OK) {
152166
PML_UCX_ERROR("Failed to create UCX datatype for %s", datatype->name);
153167
ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
154168
}
155169

170+
datatype->pml_data = ucp_datatype;
171+
172+
/* Add custom attribute, to clean up UCX resources when OMPI datatype is
173+
* released.
174+
*/
175+
if (ompi_datatype_is_predefined(datatype)) {
176+
PML_UCX_ASSERT(datatype->id < OMPI_DATATYPE_MAX_PREDEFINED);
177+
ompi_pml_ucx.predefined_types[datatype->id] = ucp_datatype;
178+
} else {
179+
ret = ompi_attr_set_c(TYPE_ATTR, datatype, &datatype->d_keyhash,
180+
ompi_pml_ucx.datatype_attr_keyval,
181+
(void*)ucp_datatype, false);
182+
if (ret != OMPI_SUCCESS) {
183+
PML_UCX_ERROR("Failed to add UCX datatype attribute for %s: %d",
184+
datatype->name, ret);
185+
ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
186+
}
187+
}
188+
156189
PML_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype)
157-
// TODO put this on a list to be destroyed later
158190

159-
datatype->pml_data = ucp_datatype;
160191
return ucp_datatype;
161192
}
162193

ompi/mca/pml/ucx/pml_ucx_datatype.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "pml_ucx.h"
1414

1515

16+
/* Sentinel for datatype->pml_data and predefined_types[] slots meaning that
 * no UCX datatype is currently attached. */
#define PML_UCX_DATATYPE_INVALID 0
17+
1618
struct pml_ucx_convertor {
1719
opal_free_list_item_t super;
1820
ompi_datatype_t *datatype;
@@ -23,14 +25,17 @@ struct pml_ucx_convertor {
2325

2426
ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype);
2527

28+
int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
29+
void *attr_val, void *extra);
30+
2631
OBJ_CLASS_DECLARATION(mca_pml_ucx_convertor_t);
2732

2833

2934
static inline ucp_datatype_t mca_pml_ucx_get_datatype(ompi_datatype_t *datatype)
3035
{
3136
ucp_datatype_t ucp_type = datatype->pml_data;
3237

33-
if (OPAL_LIKELY(ucp_type != 0)) {
38+
if (OPAL_LIKELY(ucp_type != PML_UCX_DATATYPE_INVALID)) {
3439
return ucp_type;
3540
}
3641

0 commit comments

Comments
 (0)