|
| 1 | +/** |
| 2 | + * Copyright (c) 2021 Mellanox Technologies. All rights reserved. |
| 3 | + * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. |
| 4 | + * $COPYRIGHT$ |
| 5 | + * |
| 6 | + * Additional copyrights may follow |
| 7 | + * |
| 8 | + */ |
| 9 | + |
| 10 | +#include "coll_ucc_common.h" |
| 11 | + |
| 12 | +static inline |
| 13 | +ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, int scount, |
| 14 | + struct ompi_datatype_t *sdtype, |
| 15 | + void *rbuf, int rcount, |
| 16 | + struct ompi_datatype_t *rdtype, int root, |
| 17 | + mca_coll_ucc_module_t *ucc_module, |
| 18 | + ucc_coll_req_h *req, |
| 19 | + mca_coll_ucc_req_t *coll_req) |
| 20 | +{ |
| 21 | + ucc_datatype_t ucc_sdt, ucc_rdt; |
| 22 | + int comm_rank = ompi_comm_rank(ucc_module->comm); |
| 23 | + int comm_size = ompi_comm_size(ucc_module->comm); |
| 24 | + |
| 25 | + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); |
| 26 | + if (comm_rank == root) { |
| 27 | + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); |
| 28 | + if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) || |
| 29 | + (MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) { |
| 30 | + UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", |
| 31 | + (COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ? |
| 32 | + sdtype->super.name : rdtype->super.name); |
| 33 | + goto fallback; |
| 34 | + } |
| 35 | + } else { |
| 36 | + if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { |
| 37 | + UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", |
| 38 | + rdtype->super.name); |
| 39 | + goto fallback; |
| 40 | + } |
| 41 | + } |
| 42 | + |
| 43 | + ucc_coll_args_t coll = { |
| 44 | + .mask = 0, |
| 45 | + .coll_type = UCC_COLL_TYPE_SCATTER, |
| 46 | + .root = root, |
| 47 | + .src.info = { |
| 48 | + .buffer = (void*)sbuf, |
| 49 | + .count = ((size_t)scount) * comm_size, |
| 50 | + .datatype = ucc_sdt, |
| 51 | + .mem_type = UCC_MEMORY_TYPE_UNKNOWN |
| 52 | + }, |
| 53 | + .dst.info = { |
| 54 | + .buffer = (void*)rbuf, |
| 55 | + .count = rcount, |
| 56 | + .datatype = ucc_rdt, |
| 57 | + .mem_type = UCC_MEMORY_TYPE_UNKNOWN |
| 58 | + }, |
| 59 | + }; |
| 60 | + |
| 61 | + if (MPI_IN_PLACE == rbuf) { |
| 62 | + coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; |
| 63 | + coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; |
| 64 | + } |
| 65 | + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); |
| 66 | + return UCC_OK; |
| 67 | +fallback: |
| 68 | + return UCC_ERR_NOT_SUPPORTED; |
| 69 | +} |
| 70 | + |
| 71 | +int mca_coll_ucc_scatter(const void *sbuf, int scount, |
| 72 | + struct ompi_datatype_t *sdtype, void *rbuf, int rcount, |
| 73 | + struct ompi_datatype_t *rdtype, int root, |
| 74 | + struct ompi_communicator_t *comm, |
| 75 | + mca_coll_base_module_t *module) |
| 76 | +{ |
| 77 | + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; |
| 78 | + ucc_coll_req_h req; |
| 79 | + |
| 80 | + UCC_VERBOSE(3, "running ucc scatter"); |
| 81 | + COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount, |
| 82 | + rdtype, root, ucc_module, &req, |
| 83 | + NULL)); |
| 84 | + COLL_UCC_POST_AND_CHECK(req); |
| 85 | + COLL_UCC_CHECK(coll_ucc_req_wait(req)); |
| 86 | + return OMPI_SUCCESS; |
| 87 | +fallback: |
| 88 | + UCC_VERBOSE(3, "running fallback scatter"); |
| 89 | + return ucc_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, |
| 90 | + rdtype, root, comm, |
| 91 | + ucc_module->previous_scatter_module); |
| 92 | + |
| 93 | +} |
| 94 | + |
| 95 | +int mca_coll_ucc_iscatter(const void *sbuf, int scount, |
| 96 | + struct ompi_datatype_t *sdtype, void *rbuf, int rcount, |
| 97 | + struct ompi_datatype_t *rdtype, int root, |
| 98 | + struct ompi_communicator_t *comm, |
| 99 | + ompi_request_t** request, |
| 100 | + mca_coll_base_module_t *module) |
| 101 | +{ |
| 102 | + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; |
| 103 | + ucc_coll_req_h req; |
| 104 | + mca_coll_ucc_req_t *coll_req = NULL; |
| 105 | + |
| 106 | + UCC_VERBOSE(3, "running ucc iscatter"); |
| 107 | + COLL_UCC_GET_REQ(coll_req); |
| 108 | + COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount, |
| 109 | + rdtype, root, ucc_module, &req, |
| 110 | + coll_req)); |
| 111 | + COLL_UCC_POST_AND_CHECK(req); |
| 112 | + *request = &coll_req->super; |
| 113 | + return OMPI_SUCCESS; |
| 114 | +fallback: |
| 115 | + UCC_VERBOSE(3, "running fallback iscatter"); |
| 116 | + if (coll_req) { |
| 117 | + mca_coll_ucc_req_free((ompi_request_t **)&coll_req); |
| 118 | + } |
| 119 | + return ucc_module->previous_iscatter(sbuf, scount, sdtype, rbuf, rcount, |
| 120 | + rdtype, root, comm, request, |
| 121 | + ucc_module->previous_iscatter_module); |
| 122 | +} |
0 commit comments