Skip to content

Commit 9a7b1de

Browse files
[SYCLomatic][nvSHMEM][QUDA] Added migration support for 4 nvSHMEM APIs (#2841)
* nvshmem_putmem_nbi * nvshmemx_signal_op * nvshmem_signal_wait_until * nvshmem_putmem_signal_nbi
1 parent 8d96283 commit 9a7b1de

File tree

18 files changed

+402
-260
lines changed

18 files changed

+402
-260
lines changed

clang/lib/DPCT/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ set(RUNTIME_HEADERS
2626
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/lapack_utils.hpp
2727
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/group_utils.hpp
2828
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
29+
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/shmem_utils.hpp
2930
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/compat_service.hpp
3031
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_extras/algorithm.h
3132
${CMAKE_SOURCE_DIR}/../clang/runtime/dpct-rt/include/dpct/dpl_extras/functional.h
@@ -78,6 +79,7 @@ set(PROCESS_FILES_OUTPUT
7879
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/lapack_utils.hpp.inc
7980
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/group_utils.hpp.inc
8081
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/blas_gemm_utils.hpp.inc
82+
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/shmem_utils.hpp.inc
8183
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/compat_service.hpp.inc
8284
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/detail/group_utils_detail.hpp.inc
8385
${CMAKE_BINARY_DIR}/tools/clang/include/clang/DPCT/detail/math_detail.hpp.inc

clang/lib/DPCT/FileGenerator/GenHelperFunction.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ const std::string GroupUtilsAllContentStr =
8282
const std::string BlasGemmUtilsAllContentStr =
8383
#include "clang/DPCT/blas_gemm_utils.hpp.inc"
8484
;
85+
const std::string ShmemUtilsAllContentStr =
86+
#include "clang/DPCT/shmem_utils.hpp.inc"
87+
;
8588
const std::string CompatServiceAllContentStr =
8689
#include "clang/DPCT/compat_service.hpp.inc"
8790
;
@@ -217,6 +220,7 @@ void genHelperFunction(const clang::tooling::UnifiedPath &OutRoot) {
217220
GENERATE_ALL_FILE_CONTENT(LapackUtils, ".", lapack_utils.hpp)
218221
GENERATE_ALL_FILE_CONTENT(GroupUtils, ".", group_utils.hpp)
219222
GENERATE_ALL_FILE_CONTENT(BlasGemmUtils, ".", blas_gemm_utils.hpp)
223+
GENERATE_ALL_FILE_CONTENT(ShmemUtils, ".", shmem_utils.hpp)
220224
GENERATE_ALL_FILE_CONTENT(CompatService, ".", compat_service.hpp)
221225
GENERATE_ALL_FILE_CONTENT(GroupUtilsDetail, "detail", group_utils_detail.hpp)
222226
GENERATE_ALL_FILE_CONTENT(MathDetail, "detail", math_detail.hpp)

clang/lib/DPCT/RuleInfra/MapNames.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -922,11 +922,14 @@ void MapNames::setExplicitNamespaceMap(
922922
{"cudaExternalSemaphoreHandleType",
923923
std::make_shared<TypeNameRule>(getExpNamespace() +
924924
"external_semaphore_handle_type")},
925+
// nvSHMEM
925926
{"nvshmem_team_t", std::make_shared<TypeNameRule>("ishmem_team_t")},
926927
{"nvshmem_team_config_t",
927928
std::make_shared<TypeNameRule>("ishmem_team_config_t")},
928929
{"nvshmemx_init_attr_t",
929930
std::make_shared<TypeNameRule>("ishmemx_attr_t")},
931+
{"nvshmemi_amo_t", std::make_shared<TypeNameRule>("int")},
932+
{"nvshmemi_cmp_type", std::make_shared<TypeNameRule>("int")},
930933
// ...
931934
};
932935
// SYCLcompat unsupport types
@@ -1660,16 +1663,29 @@ void MapNames::setExplicitNamespaceMap(
16601663
? getExpNamespace() +
16611664
"external_semaphore_handle_type::timeline_win32_nt_handle"
16621665
: "cudaExternalSemaphoreHandleTypeTimelineSemaphoreWin32")},
1666+
// nvSHMEM
16631667
{"NVSHMEM_TEAM_WORLD",
16641668
std::make_shared<EnumNameRule>("ISHMEM_TEAM_WORLD")},
16651669
{"NVSHMEM_TEAM_SHARED",
16661670
std::make_shared<EnumNameRule>("ISHMEM_TEAM_SHARED")},
16671671
{"NVSHMEM_TEAM_INVALID",
16681672
std::make_shared<EnumNameRule>("ISHMEM_TEAM_INVALID")},
16691673
{"NVSHMEMX_INIT_WITH_MPI_COMM",
1670-
std::make_shared<EnumNameRule>("ISHMEMX_RUNTIME_MPI")},
1674+
std::make_shared<EnumNameRule>(MapNames::getDpctNamespace() +
1675+
"shmemx::RUNTIME_MPI")},
16711676
{"NVSHMEMX_INIT_WITH_SHMEM",
1672-
std::make_shared<EnumNameRule>("ISHMEMX_RUNTIME_OPENSHMEM")},
1677+
std::make_shared<EnumNameRule>(MapNames::getDpctNamespace() +
1678+
"shmemx::RUNTIME_OPENSHMEM")},
1679+
{"NVSHMEM_SIGNAL_SET",
1680+
std::make_shared<EnumNameRule>("ISHMEM_SIGNAL_SET")},
1681+
{"NVSHMEM_SIGNAL_ADD",
1682+
std::make_shared<EnumNameRule>("ISHMEM_SIGNAL_ADD")},
1683+
{"NVSHMEM_CMP_EQ", std::make_shared<EnumNameRule>("ISHMEM_CMP_EQ")},
1684+
{"NVSHMEM_CMP_NE", std::make_shared<EnumNameRule>("ISHMEM_CMP_NE")},
1685+
{"NVSHMEM_CMP_GT", std::make_shared<EnumNameRule>("ISHMEM_CMP_GT")},
1686+
{"NVSHMEM_CMP_GE", std::make_shared<EnumNameRule>("ISHMEM_CMP_GE")},
1687+
{"NVSHMEM_CMP_LT", std::make_shared<EnumNameRule>("ISHMEM_CMP_LT")},
1688+
{"NVSHMEM_CMP_LE", std::make_shared<EnumNameRule>("ISHMEM_CMP_LE")},
16731689
// ...
16741690
};
16751691

clang/lib/DPCT/RulesInclude/HeaderTypes.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ DPCT_HEADER(Atomic, "<dpct/atomic.hpp>")
7474
DPCT_HEADER(SPBLAS_Utils, "<dpct/sparse_utils.hpp>")
7575
DPCT_HEADER(Math, "<dpct/math.hpp>")
7676
DPCT_HEADER(BLAS_GEMM_Utils, "<dpct/blas_gemm_utils.hpp>")
77+
DPCT_HEADER(SHMEM_Utils, "<dpct/shmem_utils.hpp>")
7778
DPCT_HEADER(CodePin_SYCL, "<dpct/codepin/codepin.hpp>")
7879
DPCT_HEADER(CodePin_CUDA, "<dpct/codepin/codepin.hpp>")
7980
DPCT_HEADER(Graph, "<dpct/graph.hpp>")

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,8 @@ void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
360360
"cudaGraphicsRegisterFlags", "cudaExternalMemoryHandleType",
361361
"cudaExternalSemaphoreHandleType", "CUstreamCallback",
362362
"cudaHostFn_t", "cudaGraphNodeType", "CUsurfref",
363-
"CUdevice_P2PAttribute", "cudaIpcMemHandle_t"))))))
363+
"CUdevice_P2PAttribute", "cudaIpcMemHandle_t", "nvshmemi_amo_t",
364+
"nvshmemi_cmp_type"))))))
364365
.bind("cudaTypeDef"),
365366
this);
366367

clang/lib/DPCT/RulesSHMEM/APINamesNvshmem.inc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
1111
CALL_FACTORY_ENTRY("nvshmem_init", CALL("ishmem_init")))
1212

13-
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
14-
CALL_FACTORY_ENTRY("nvshmemx_init_attr",
15-
CALL("ishmemx_init_attr", ARG(1))))
13+
HEADER_INSERT_FACTORY(HeaderType::HT_DPCT_SHMEM_Utils,
14+
FEATURE_REQUEST_FACTORY(
15+
HelperFeatureEnum::device_ext,
16+
CALL_FACTORY_ENTRY("nvshmemx_init_attr",
17+
CALL(MapNames::getDpctNamespace() +
18+
"shmemx::init_attr",
19+
ARG(0), ARG(1)))))
1620

1721
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
1822
CALL_FACTORY_ENTRY("nvshmem_my_pe",
@@ -91,3 +95,30 @@ FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
9195
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
9296
CALL_FACTORY_ENTRY("nvshmem_team_destroy",
9397
CALL("ishmem_team_destroy", ARG(0))))
98+
99+
// Nonblocking RMA
100+
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
101+
CALL_FACTORY_ENTRY("nvshmem_putmem_nbi",
102+
CALL("ishmem_putmem_nbi", ARG(0),
103+
ARG(1), ARG(2), ARG(3))))
104+
105+
// Signalling Operations
106+
HEADER_INSERT_FACTORY(HeaderType::HT_DPCT_SHMEM_Utils,
107+
FEATURE_REQUEST_FACTORY(
108+
HelperFeatureEnum::device_ext,
109+
CALL_FACTORY_ENTRY("nvshmemx_signal_op",
110+
CALL(MapNames::getDpctNamespace() +
111+
"shmemx::signal_op",
112+
ARG(0), ARG(1), ARG(2),
113+
ARG(3)))))
114+
115+
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
116+
CALL_FACTORY_ENTRY("nvshmem_signal_wait_until",
117+
CALL("ishmem_signal_wait_until",
118+
ARG(0), ARG(1), ARG(2))))
119+
120+
FEATURE_REQUEST_FACTORY(HelperFeatureEnum::device_ext,
121+
CALL_FACTORY_ENTRY("nvshmem_putmem_signal_nbi",
122+
CALL("ishmem_putmem_signal_nbi",
123+
ARG(0), ARG(1), ARG(2), ARG(3),
124+
ARG(4), ARG(5), ARG(6))))

clang/lib/DPCT/RulesSHMEM/NVSHMEMAPIMigration.cpp

Lines changed: 14 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using namespace clang::dpct;
1313
using namespace clang::ast_matchers;
1414

1515
void clang::dpct::NVSHMEMRule::registerMatcher(ast_matchers::MatchFinder &MF) {
16-
auto nvSHMEM_API = [&]() {
16+
auto NvshmemAPI = [&]() {
1717
return hasAnyName(
1818
// Library Setup, Exit & Query
1919
"nvshmem_init", "nvshmem_my_pe", "nvshmem_n_pes", "nvshmem_finalize",
@@ -24,7 +24,12 @@ void clang::dpct::NVSHMEMRule::registerMatcher(ast_matchers::MatchFinder &MF) {
2424
// Team Management
2525
"nvshmem_team_my_pe", "nvshmem_team_n_pes", "nvshmem_team_get_config",
2626
"nvshmem_team_translate_pe", "nvshmem_team_split_strided",
27-
"nvshmem_team_split_2d", "nvshmem_team_destroy");
27+
"nvshmem_team_split_2d", "nvshmem_team_destroy",
28+
// Nonblocking RMA
29+
"nvshmem_putmem_nbi",
30+
// Signalling Operations
31+
"nvshmemx_signal_op", "nvshmem_signal_wait_until",
32+
"nvshmem_putmem_signal_nbi");
2833
};
2934

3035
MF.addMatcher(typeLoc(loc(qualType(hasDeclaration(namedDecl(hasAnyName(
@@ -41,15 +46,18 @@ void clang::dpct::NVSHMEMRule::registerMatcher(ast_matchers::MatchFinder &MF) {
4146
.bind("memberAccess"),
4247
this);
4348

44-
MF.addMatcher(callExpr(callee(functionDecl(nvSHMEM_API()))).bind("call"),
49+
MF.addMatcher(callExpr(callee(functionDecl(NvshmemAPI()))).bind("call"),
4550
this);
4651

4752
MF.addMatcher(
4853
declRefExpr(to(enumConstantDecl(hasAnyName(
4954
"NVSHMEM_TEAM_WORLD", "NVSHMEM_TEAM_INVALID",
5055
"NVSHMEM_TEAM_SHARED", "NVSHMEMX_INIT_WITH_MPI_COMM",
51-
"NVSHMEMX_INIT_WITH_SHMEM"))))
52-
.bind("enum"),
56+
"NVSHMEMX_INIT_WITH_SHMEM", "NVSHMEM_SIGNAL_SET",
57+
"NVSHMEM_SIGNAL_ADD", "NVSHMEM_CMP_EQ", "NVSHMEM_CMP_NE",
58+
"NVSHMEM_CMP_GT", "NVSHMEM_CMP_GE", "NVSHMEM_CMP_LT",
59+
"NVSHMEM_CMP_LE"))))
60+
.bind("enumConstant"),
5361
this);
5462
}
5563

@@ -93,94 +101,9 @@ void clang::dpct::NVSHMEMRule::runRule(
93101

94102
EA.analyze(*TL);
95103
} else if (const CallExpr *CE = getNodeAsType<CallExpr>(Result, "call")) {
96-
std::string FuncName = "";
97-
const FunctionDecl *FD = CE->getDirectCallee();
98-
if (FD) {
99-
FuncName = FD->getNameInfo().getName().getAsString();
100-
}
101-
102-
if (!FuncName.empty()) {
103-
if (FuncName == "nvshmemx_init_attr") {
104-
// Get function arguments
105-
std::string nvshmem_rt = "";
106-
std::string nvshmem_init_rt = "";
107-
std::string attr_arg = "";
108-
109-
// Get the first argument's data
110-
const Expr *Arg0 = CE->getArg(0);
111-
112-
// Binary op on first argument is not supported
113-
if (dyn_cast<BinaryOperator>(Arg0->IgnoreImpCasts())) {
114-
report(CE->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false,
115-
FuncName);
116-
return;
117-
}
118-
119-
// Get first argument's init value
120-
if (auto DRE = dyn_cast<DeclRefExpr>(Arg0->IgnoreImpCasts())) {
121-
if (const VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
122-
if (VD->hasInit()) {
123-
// get the init value from definition
124-
if (auto Init =
125-
dyn_cast<DeclRefExpr>(VD->getInit()->IgnoreImplicit())) {
126-
nvshmem_init_rt = Init->getNameInfo().getName().getAsString();
127-
}
128-
}
129-
}
130-
}
131-
132-
// Get the first argument as string
133-
nvshmem_rt = Lexer::getSourceText(
134-
CharSourceRange::getTokenRange(Arg0->getSourceRange()),
135-
DpctGlobalInfo::getSourceManager(), LangOptions());
136-
137-
if (nvshmem_rt == "0" || nvshmem_init_rt == "0") {
138-
emplaceTransformation(new ReplaceStmt(CE, "ishmem_init()"));
139-
return;
140-
}
141-
142-
std::string ishmem_rt = "";
143-
if (nvshmem_rt == "NVSHMEMX_INIT_WITH_MPI_COMM") {
144-
ishmem_rt = "ISHMEMX_RUNTIME_MPI";
145-
} else if (nvshmem_rt == "NVSHMEMX_INIT_WITH_SHMEM") {
146-
ishmem_rt = "ISHMEMX_RUNTIME_OPENSHMEM";
147-
}
148-
149-
if (nvshmem_init_rt == "NVSHMEMX_INIT_WITH_MPI_COMM" ||
150-
nvshmem_init_rt == "NVSHMEMX_INIT_WITH_SHMEM") {
151-
ishmem_rt = "static_cast<ishmemx_runtime_type_t>(" + nvshmem_rt + ")";
152-
}
153-
154-
// Get the second argument as string
155-
attr_arg = Lexer::getSourceText(
156-
CharSourceRange::getTokenRange(CE->getArg(1)->getSourceRange()),
157-
DpctGlobalInfo::getSourceManager(), LangOptions());
158-
159-
if (ishmem_rt.empty()) {
160-
report(CE->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false,
161-
FuncName);
162-
return;
163-
} else {
164-
auto &SM = DpctGlobalInfo::getSourceManager();
165-
166-
auto IndentLoc = CE->getBeginLoc();
167-
if (IndentLoc.isMacroID())
168-
IndentLoc = SM.getExpansionLoc(IndentLoc);
169-
170-
std::string set_ishmem_runtime =
171-
"(" + attr_arg + ")->runtime = " + ishmem_rt + ";";
172-
set_ishmem_runtime += getNL();
173-
set_ishmem_runtime += getIndent(IndentLoc, SM).str();
174-
175-
emplaceTransformation(
176-
new InsertBeforeStmt(CE, std::move(set_ishmem_runtime)));
177-
}
178-
}
179-
}
180-
181104
EA.analyze(CE);
182105
} else if (const DeclRefExpr *DRE =
183-
getNodeAsType<DeclRefExpr>(Result, "enum")) {
106+
getNodeAsType<DeclRefExpr>(Result, "enumConstant")) {
184107
EA.analyze(DRE);
185108
} else {
186109
return;

clang/lib/DPCT/SrcAPI/APINames_nvSHMEM.inc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ ENTRY(nvshmemx_put32_nbi_warp, nvshmemx_put32_nbi_warp, false, NO_FLAG, P4, "com
781781
ENTRY(nvshmemx_put64_nbi_warp, nvshmemx_put64_nbi_warp, false, NO_FLAG, P4, "comment")
782782
ENTRY(nvshmemx_put128_nbi_warp, nvshmemx_put128_nbi_warp, false, NO_FLAG, P4, "comment")
783783

784-
ENTRY(nvshmem_putmem_nbi, nvshmem_putmem_nbi, false, NO_FLAG, P4, "comment")
784+
ENTRY(nvshmem_putmem_nbi, nvshmem_putmem_nbi, true, NO_FLAG, P4, "Succeessful")
785785
ENTRY(nvshmemx_putmem_nbi_on_stream, nvshmemx_putmem_nbi_on_stream, false, NO_FLAG, P4, "comment")
786786
ENTRY(nvshmemx_putmem_nbi_block, nvshmemx_putmem_nbi_block, false, NO_FLAG, P4, "comment")
787787
ENTRY(nvshmemx_putmem_nbi_warp, nvshmemx_putmem_nbi_warp, false, NO_FLAG, P4, "comment")
@@ -1300,7 +1300,7 @@ ENTRY(nvshmemx_put32_signal_nbi_warp, nvshmemx_put32_signal_nbi_warp, false, NO_
13001300
ENTRY(nvshmemx_put64_signal_nbi_warp, nvshmemx_put64_signal_nbi_warp, false, NO_FLAG, P4, "comment")
13011301
ENTRY(nvshmemx_put128_signal_nbi_warp, nvshmemx_put128_signal_nbi_warp, false, NO_FLAG, P4, "comment")
13021302

1303-
ENTRY(nvshmem_putmem_signal_nbi, nvshmem_putmem_signal_nbi, false, NO_FLAG, P4, "comment")
1303+
ENTRY(nvshmem_putmem_signal_nbi, nvshmem_putmem_signal_nbi, true, NO_FLAG, P4, "Succeessful")
13041304
ENTRY(nvshmemx_putmem_signal_nbi_on_stream, nvshmemx_putmem_signal_nbi_on_stream, false, NO_FLAG, P4, "comment")
13051305
ENTRY(nvshmemx_putmem_signal_nbi_block, nvshmemx_putmem_signal_nbi_block, false, NO_FLAG, P4, "comment")
13061306
ENTRY(nvshmemx_putmem_signal_nbi_warp, nvshmemx_putmem_signal_nbi_warp, false, NO_FLAG, P4, "comment")
@@ -1323,7 +1323,7 @@ ENTRY(nvshmem_signal_fetch, nvshmem_signal_fetch, false, NO_FLAG, P4, "comment")
13231323
// ENTRY(nvshmemx_size_signal, nvshmemx_size_signal, false, NO_FLAG, P4, "comment")
13241324
// ENTRY(nvshmemx_ptrdiff_signal, nvshmemx_ptrdiff_signal, false, NO_FLAG, P4, "comment")
13251325

1326-
ENTRY(nvshmemx_signal_op, nvshmemx_signal_op, false, NO_FLAG, P4, "comment")
1326+
ENTRY(nvshmemx_signal_op, nvshmemx_signal_op, true, NO_FLAG, P4, "Succeessful")
13271327

13281328

13291329
// Collective Operations
@@ -2534,7 +2534,7 @@ ENTRY(nvshmem_size_test_some_vector, nvshmem_size_test_some_vector, false, NO_FL
25342534
ENTRY(nvshmem_ptrdiff_test_some_vector, nvshmem_ptrdiff_test_some_vector, false, NO_FLAG, P4, "comment")
25352535

25362536
ENTRY(nvshmemx_signal_wait_until_on_stream, nvshmemx_signal_wait_until_on_stream, false, NO_FLAG, P4, "comment")
2537-
ENTRY(nvshmem_signal_wait_until, nvshmem_signal_wait_until, false, NO_FLAG, P4, "comment")
2537+
ENTRY(nvshmem_signal_wait_until, nvshmem_signal_wait_until, true, NO_FLAG, P4, "Succeessful")
25382538

25392539

25402540
// Memory Ordering

clang/runtime/dpct-rt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ set(dpct_rt_files
2020
include/dpct/lapack_utils.hpp
2121
include/dpct/group_utils.hpp
2222
include/dpct/blas_gemm_utils.hpp
23+
include/dpct/shmem_utils.hpp
2324
include/dpct/graph.hpp
2425
include/dpct/compat_service.hpp
2526
)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//==--------------- shmem_utils.hpp ---------------*- C++ -*----------------==//
2+
//
3+
// Copyright (C) Intel Corporation
4+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef __DPCT_SHMEM_UTILS_HPP__
10+
#define __DPCT_SHMEM_UTILS_HPP__
11+
12+
namespace dpct::shmemx {
13+
14+
enum flags { RUNTIME_MPI = 1 << 1, RUNTIME_OPENSHMEM = 1 << 2 };
15+
16+
/// Initialize Intel SHMEM library based on exisiting communicator
17+
/// \param [in] runtime_flags specify the underlying communicator like MPI/
18+
/// OpenSHMEM to initialize iSHMEM for launcher agnostic bootstrapping
19+
/// \param [in] attr Additional attributes for initializing the iSHMEM library.
20+
void init_attr(unsigned runtime_flags, ishmemx_attr_t *attr) {
21+
if (runtime_flags == 0) {
22+
// if no runtime flags are present, initialize iSHMEM normally
23+
ishmem_init();
24+
} else {
25+
unsigned ishmem_runtime_flags = 0;
26+
27+
if (runtime_flags & RUNTIME_MPI)
28+
ishmem_runtime_flags |= ISHMEMX_RUNTIME_MPI;
29+
if (runtime_flags & RUNTIME_OPENSHMEM)
30+
ishmem_runtime_flags |= ISHMEMX_RUNTIME_OPENSHMEM;
31+
32+
attr->runtime = static_cast<ishmemx_runtime_type_t>(ishmem_runtime_flags);
33+
ishmemx_init_attr(attr);
34+
}
35+
}
36+
37+
/// Update signal address with signal using signal operation
38+
/// \param [out] sig_addr Symmetric address of the signal data object to be
39+
/// updated on the remote PE.
40+
/// \param [in] signal Unsigned 64-bit value that is used for updating the
41+
/// remote sig_addr signal data object.
42+
/// \param [in] sig_op Operator used to update signal data object
43+
/// \param [in] pe Processing element
44+
void signal_op(uint64_t *sig_addr, uint64_t signal, int sig_op, int pe) {
45+
if (sig_op == ISHMEM_SIGNAL_SET) {
46+
ishmemx_signal_set(sig_addr, signal, pe);
47+
} else if (sig_op == ISHMEM_SIGNAL_ADD) {
48+
ishmemx_signal_add(sig_addr, signal, pe);
49+
} else {
50+
throw std::runtime_error("Unsupported signal operator!");
51+
}
52+
}
53+
54+
} // namespace dpct::shmemx
55+
56+
#endif // __DPCT_SHMEM_UTILS_HPP__

0 commit comments

Comments
 (0)