Skip to content

Commit 8807950

Browse files
authored
[SYCLomatic] Split BLAS rewriters into 6 cpp files (#2872)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent 3bc4252 commit 8807950

10 files changed

+3052
-2257
lines changed

clang/lib/DPCT/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,12 @@ add_clang_library(DPCT
271271
RulesMathLib/RandomAPIMigration.cpp
272272
RulesMathLib/SolverAPIMigration.cpp
273273
CodePin/GenCodePinHeader.cpp
274+
RulesMathLib/CallExprRewriterCUBLASExt.cpp
275+
RulesMathLib/CallExprRewriterCUBLASHelper.cpp
276+
RulesMathLib/CallExprRewriterCUBLASLevel1.cpp
277+
RulesMathLib/CallExprRewriterCUBLASLevel2.cpp
278+
RulesMathLib/CallExprRewriterCUBLASLevel3.cpp
279+
RulesMathLib/CallExprRewriterCUBLASLt.cpp
274280

275281
DEPENDS
276282
ClangDriverOptions

clang/lib/DPCT/RulesMathLib/APINamesCUBLAS.inc

Lines changed: 0 additions & 2156 deletions
This file was deleted.

clang/lib/DPCT/RulesMathLib/CallExprRewriterCUBLAS.cpp

Lines changed: 7 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -6,112 +6,18 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "RuleInfra/CallExprRewriter.h"
10-
#include "RuleInfra/CallExprRewriterCommon.h"
11-
#include <string>
9+
#include "CallExprRewriterCUBLAS.h"
1210

1311
namespace clang {
1412
namespace dpct {
1513

16-
template <class ArgT> class BufferOrUSMPtrCallArgPrinter {
17-
ArgT Arg;
18-
std::string DataType;
19-
20-
public:
21-
BufferOrUSMPtrCallArgPrinter(ArgT &&Arg, std::string DataType)
22-
: Arg(std::forward<ArgT>(Arg)), DataType(DataType) {}
23-
template <class StreamT> void print(StreamT &Stream) const {
24-
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
25-
Stream << MapNames::getLibraryHelperNamespace() << "rvalue_ref_to_lvalue_ref("
26-
<< MapNames::getDpctNamespace() << "get_buffer<" << DataType
27-
<< ">(";
28-
clang::dpct::print(Stream, Arg);
29-
Stream << "))";
30-
} else {
31-
if (DataType == "std::complex<float>" ||
32-
DataType == "std::complex<double>")
33-
Stream << "(" << DataType << "*)";
34-
if constexpr (std::is_same_v<ArgT, const Expr *>)
35-
clang::dpct::print(Stream, Arg->IgnoreCasts());
36-
else
37-
clang::dpct::print(Stream, Arg);
38-
}
39-
}
40-
};
41-
42-
template <class ArgT>
43-
std::function<BufferOrUSMPtrCallArgPrinter<ArgT>(const CallExpr *)>
44-
makeBufferOrUSMPtrCallArgCreator(std::function<ArgT(const CallExpr *)> Arg,
45-
std::string DataType) {
46-
return PrinterCreator<BufferOrUSMPtrCallArgPrinter<ArgT>,
47-
std::function<ArgT(const CallExpr *)>,
48-
std::function<std::string(const CallExpr *)>>(
49-
Arg, [=](const CallExpr *) { return DataType; });
50-
}
51-
52-
class ScalarInputValuePrinter {
53-
const Expr *Arg;
54-
const Expr *Handle;
55-
std::string DataType;
56-
57-
public:
58-
ScalarInputValuePrinter(const Expr *&&Arg, const Expr *&&Handle,
59-
std::string DataType)
60-
: Arg(std::forward<const Expr *>(Arg)),
61-
Handle(std::forward<const Expr *>(Handle)), DataType(DataType) {}
62-
template <class StreamT> void print(StreamT &Stream) const {
63-
const auto *UO = dyn_cast_or_null<UnaryOperator>(Arg->IgnoreImpCasts());
64-
const auto *COCE = dyn_cast<CXXOperatorCallExpr>(Arg->IgnoreImpCasts());
65-
if ((UO && UO->getOpcode() == UO_AddrOf && UO->getSubExpr()) ||
66-
(COCE && COCE->getOperator() == OO_Amp && COCE->getArg(0))) {
67-
const Expr *Sub = UO ? UO->getSubExpr() : COCE->getArg(0);
68-
if (DataType == "std::complex<float>" ||
69-
DataType == "std::complex<double>") {
70-
Stream << DataType << "(";
71-
clang::dpct::print(Stream, Sub);
72-
Stream << ".x(), ";
73-
clang::dpct::print(Stream, Sub);
74-
Stream << ".y())";
75-
} else {
76-
clang::dpct::print(Stream, Sub);
77-
}
78-
} else {
79-
Stream << MapNames::getLibraryHelperNamespace() << "get_value(";
80-
clang::dpct::print(Stream, Arg);
81-
Stream << ", ";
82-
if (needExtraParensInMemberExpr(Handle)) {
83-
Stream << "(";
84-
clang::dpct::print(Stream, Handle);
85-
Stream << ")->get_queue())";
86-
} else {
87-
clang::dpct::print(Stream, Handle);
88-
Stream << "->get_queue())";
89-
}
90-
}
91-
}
92-
};
93-
94-
std::function<ScalarInputValuePrinter(const CallExpr *)>
95-
makeScalarInputValueCreator(
96-
std::function<const Expr *(const CallExpr *)> Arg,
97-
std::function<const Expr *(const CallExpr *)> Handle,
98-
std::string DataType) {
99-
return PrinterCreator<ScalarInputValuePrinter,
100-
std::function<const Expr *(const CallExpr *)>,
101-
std::function<const Expr *(const CallExpr *)>,
102-
std::function<std::string(const CallExpr *)>>(
103-
Arg, Handle, [=](const CallExpr *) { return DataType; });
104-
}
105-
106-
#define BUFFER_OR_USM_PTR(Arg, T) makeBufferOrUSMPtrCallArgCreator(Arg, T)
107-
#define SCALAR_INPUT(Arg, T) makeScalarInputValueCreator(Arg, ARG(0), T)
108-
10914
void CallExprRewriterFactoryBase::initRewriterMapCUBLAS() {
110-
RewriterMap->merge(
111-
std::unordered_map<std::string,
112-
std::shared_ptr<CallExprRewriterFactoryBase>>({
113-
#include "RulesMathLib/APINamesCUBLAS.inc"
114-
}));
15+
RewriterMap->merge(createCUBLASLevel1RewriterMap());
16+
RewriterMap->merge(createCUBLASLevel2RewriterMap());
17+
RewriterMap->merge(createCUBLASLevel3RewriterMap());
18+
RewriterMap->merge(createCUBLASHelperRewriterMap());
19+
RewriterMap->merge(createCUBLASExtRewriterMap());
20+
RewriterMap->merge(createCUBLASLtRewriterMap());
11521
}
11622

11723
} // namespace dpct
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
//===--------------- CallExprRewriterCUBLAS.h -----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef DPCT_REWRITERS_CALL_EXPR_REWRITER_CUBLAS_H
10+
#define DPCT_REWRITERS_CALL_EXPR_REWRITER_CUBLAS_H
11+
12+
#include "RuleInfra/CallExprRewriter.h"
13+
#include "RuleInfra/CallExprRewriterCommon.h"
14+
#include <string>
15+
16+
namespace clang {
17+
namespace dpct {
18+
19+
template <class ArgT> class BufferOrUSMPtrCallArgPrinter {
20+
ArgT Arg;
21+
std::string DataType;
22+
23+
public:
24+
BufferOrUSMPtrCallArgPrinter(ArgT &&Arg, std::string DataType)
25+
: Arg(std::forward<ArgT>(Arg)), DataType(DataType) {}
26+
template <class StreamT> void print(StreamT &Stream) const {
27+
if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_None) {
28+
Stream << MapNames::getLibraryHelperNamespace()
29+
<< "rvalue_ref_to_lvalue_ref(" << MapNames::getDpctNamespace()
30+
<< "get_buffer<" << DataType << ">(";
31+
clang::dpct::print(Stream, Arg);
32+
Stream << "))";
33+
} else {
34+
if (DataType == "std::complex<float>" ||
35+
DataType == "std::complex<double>")
36+
Stream << "(" << DataType << "*)";
37+
if constexpr (std::is_same_v<ArgT, const Expr *>)
38+
clang::dpct::print(Stream, Arg->IgnoreCasts());
39+
else
40+
clang::dpct::print(Stream, Arg);
41+
}
42+
}
43+
};
44+
45+
template <class ArgT>
46+
std::function<BufferOrUSMPtrCallArgPrinter<ArgT>(const CallExpr *)>
47+
makeBufferOrUSMPtrCallArgCreator(std::function<ArgT(const CallExpr *)> Arg,
48+
std::string DataType) {
49+
return PrinterCreator<BufferOrUSMPtrCallArgPrinter<ArgT>,
50+
std::function<ArgT(const CallExpr *)>,
51+
std::function<std::string(const CallExpr *)>>(
52+
Arg, [=](const CallExpr *) { return DataType; });
53+
}
54+
55+
class ScalarInputValuePrinter {
56+
const Expr *Arg;
57+
const Expr *Handle;
58+
std::string DataType;
59+
60+
public:
61+
ScalarInputValuePrinter(const Expr *&&Arg, const Expr *&&Handle,
62+
std::string DataType)
63+
: Arg(std::forward<const Expr *>(Arg)),
64+
Handle(std::forward<const Expr *>(Handle)), DataType(DataType) {}
65+
template <class StreamT> void print(StreamT &Stream) const {
66+
const auto *UO = dyn_cast_or_null<UnaryOperator>(Arg->IgnoreImpCasts());
67+
const auto *COCE = dyn_cast<CXXOperatorCallExpr>(Arg->IgnoreImpCasts());
68+
if ((UO && UO->getOpcode() == UO_AddrOf && UO->getSubExpr()) ||
69+
(COCE && COCE->getOperator() == OO_Amp && COCE->getArg(0))) {
70+
const Expr *Sub = UO ? UO->getSubExpr() : COCE->getArg(0);
71+
if (DataType == "std::complex<float>" ||
72+
DataType == "std::complex<double>") {
73+
Stream << DataType << "(";
74+
clang::dpct::print(Stream, Sub);
75+
Stream << ".x(), ";
76+
clang::dpct::print(Stream, Sub);
77+
Stream << ".y())";
78+
} else {
79+
clang::dpct::print(Stream, Sub);
80+
}
81+
} else {
82+
Stream << MapNames::getLibraryHelperNamespace() << "get_value(";
83+
clang::dpct::print(Stream, Arg);
84+
Stream << ", ";
85+
if (needExtraParensInMemberExpr(Handle)) {
86+
Stream << "(";
87+
clang::dpct::print(Stream, Handle);
88+
Stream << ")->get_queue())";
89+
} else {
90+
clang::dpct::print(Stream, Handle);
91+
Stream << "->get_queue())";
92+
}
93+
}
94+
}
95+
};
96+
97+
std::function<ScalarInputValuePrinter(
98+
const CallExpr
99+
*)> inline makeScalarInputValueCreator(std::
100+
function<const Expr *(
101+
const CallExpr *)>
102+
Arg,
103+
std::function<const Expr *(
104+
const CallExpr *)>
105+
Handle,
106+
std::string DataType) {
107+
return PrinterCreator<ScalarInputValuePrinter,
108+
std::function<const Expr *(const CallExpr *)>,
109+
std::function<const Expr *(const CallExpr *)>,
110+
std::function<std::string(const CallExpr *)>>(
111+
Arg, Handle, [=](const CallExpr *) { return DataType; });
112+
}
113+
114+
#define BUFFER_OR_USM_PTR(Arg, T) makeBufferOrUSMPtrCallArgCreator(Arg, T)
115+
#define SCALAR_INPUT(Arg, T) makeScalarInputValueCreator(Arg, ARG(0), T)
116+
117+
typedef std::unordered_map<std::string,
118+
std::shared_ptr<CallExprRewriterFactoryBase>>
119+
RewriterMap;
120+
121+
RewriterMap createCUBLASLevel1RewriterMap();
122+
RewriterMap createCUBLASLevel2RewriterMap();
123+
RewriterMap createCUBLASLevel3RewriterMap();
124+
RewriterMap createCUBLASHelperRewriterMap();
125+
RewriterMap createCUBLASExtRewriterMap();
126+
RewriterMap createCUBLASLtRewriterMap();
127+
128+
} // namespace dpct
129+
} // namespace clang
130+
131+
#endif // DPCT_REWRITERS_CALL_EXPR_REWRITER_CUBLAS_H

0 commit comments

Comments
 (0)