@@ -846,11 +846,17 @@ std::string CubRule::getOpRepl(const Expr *Operator) {
846
846
auto processOperatorExpr = [&](const Expr *Obj) {
847
847
std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName (
848
848
Obj->getType ().getCanonicalType ());
849
- if (OpType == " cub::Sum" || OpType == " cuda::std::plus<void>" ) {
849
+ if (OpType.find (" cub::Sum" ) != std::string::npos ||
850
+ OpType.find (" cuda::std::plus" ) != std::string::npos ||
851
+ OpType.find (" thrust::plus" ) != std::string::npos) {
850
852
OpRepl = MapNames::getClNamespace () + " plus<>()" ;
851
- } else if (OpType == " cub::Max" || OpType == " cuda::maximum<void>" ) {
853
+ } else if (OpType.find (" cub::Max" ) != std::string::npos ||
854
+ OpType.find (" cuda::maximum" ) != std::string::npos ||
855
+ OpType.find (" thrust::maximum" ) != std::string::npos) {
852
856
OpRepl = MapNames::getClNamespace () + " maximum<>()" ;
853
- } else if (OpType == " cub::Min" || OpType == " cuda::minimum<void>" ) {
857
+ } else if (OpType.find (" cub::Min" ) != std::string::npos ||
858
+ OpType.find (" cuda::minimum" ) != std::string::npos ||
859
+ OpType.find (" thrust::minimum" ) != std::string::npos) {
854
860
OpRepl = MapNames::getClNamespace () + " minimum<>()" ;
855
861
}
856
862
};
@@ -861,17 +867,21 @@ std::string CubRule::getOpRepl(const Expr *Operator) {
861
867
} else {
862
868
auto CtorArg = Op->getArg (0 )->IgnoreImplicitAsWritten ();
863
869
if (auto DRE = dyn_cast<DeclRefExpr>(CtorArg)) {
864
- auto D = DRE->getDecl ();
865
- if (!D)
866
- return OpRepl;
867
- std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName (
868
- D->getType ().getCanonicalType ());
869
- if (OpType == " cub::Sum" || OpType == " cub::Max" ||
870
- OpType == " cub::Min" || OpType == " cuda::std::plus<void>" ||
871
- OpType == " cuda::maximum<void>" ||
872
- OpType == " cuda::minimum<void>" ) {
873
- ExprAnalysis EA (Operator);
874
- OpRepl = EA.getReplacedString ();
870
+ if (auto D = DRE->getDecl ()) {
871
+ std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName (
872
+ D->getType ().getCanonicalType ());
873
+ if (OpType.find (" cub::Sum" ) != std::string::npos ||
874
+ OpType.find (" cub::Max" ) != std::string::npos ||
875
+ OpType.find (" cub::Min" ) != std::string::npos ||
876
+ OpType.find (" cuda::std::plus" ) != std::string::npos ||
877
+ OpType.find (" cuda::maximum" ) != std::string::npos ||
878
+ OpType.find (" cuda::minimum" ) != std::string::npos ||
879
+ OpType.find (" thrust::plus" ) != std::string::npos ||
880
+ OpType.find (" thrust::maximum" ) != std::string::npos ||
881
+ OpType.find (" thrust::minimum" ) != std::string::npos) {
882
+ ExprAnalysis EA (Operator);
883
+ OpRepl = EA.getReplacedString ();
884
+ }
875
885
}
876
886
} else if (auto CXXTempObj = dyn_cast<CXXTemporaryObjectExpr>(CtorArg)) {
877
887
processOperatorExpr (CXXTempObj);
0 commit comments