@@ -31,6 +31,8 @@ using WarpFunctionRewriterFactory =
31
31
CallExprRewriterFactory<WarpFunctionRewriter, std::string>;
32
32
using NoRewriteFuncNameRewriterFactory =
33
33
CallExprRewriterFactory<NoRewriteFuncNameRewriter, std::string>;
34
+ using EmptyRewriterFactory =
35
+ CallExprRewriterFactory<EmptyRewriter, std::string>;
34
36
35
37
// / Base class for rewriting math function calls
36
38
class MathCallExprRewriter : public FuncCallExprRewriter {
@@ -349,10 +351,9 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
349
351
: Name(Name), MathAPIRewriters(MathAPIRewritersInput) {
350
352
NoRewriteRewriter = std::make_pair (
351
353
TrueFunctor,
352
- std::make_pair (Name,
353
- std::dynamic_pointer_cast<CallExprRewriterFactoryBase>(
354
- std::make_shared<NoRewriteFuncNameRewriterFactory>(
355
- Name, Name))));
354
+ std::make_pair (
355
+ Name, std::dynamic_pointer_cast<CallExprRewriterFactoryBase>(
356
+ std::make_shared<EmptyRewriterFactory>(Name, Name))));
356
357
}
357
358
// a. Host API priority:
358
359
// 1. host_perf
@@ -561,22 +562,19 @@ createMathAPIRewriterDevice(
561
562
math::IsDefinedInCUDA (),
562
563
std::move (createMathAPIRewriterDeviceImpl (Name, PerfPred, DevicePerf,
563
564
DeviceNodes)),
564
- {Name,
565
- std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)}),
565
+ {Name, std::make_shared<EmptyRewriterFactory>(Name, Name)}),
566
566
createConditionalFactory (
567
567
math::IsUnresolvedLookupExpr,
568
568
createConditionalFactory (
569
569
math::IsDirectCallerPureDevice,
570
570
std::move (createMathAPIRewriterDeviceImpl (
571
571
Name, PerfPred, DevicePerf, DeviceNodes)),
572
- {Name,
573
- std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)}),
572
+ {Name, std::make_shared<EmptyRewriterFactory>(Name, Name)}),
574
573
createConditionalFactory (
575
574
math::IsDefinedInCUDA (),
576
575
std::move (createMathAPIRewriterDeviceImpl (
577
576
Name, PerfPred, DevicePerf, DeviceNodes)),
578
- {Name, std::make_shared<NoRewriteFuncNameRewriterFactory>(
579
- Name, Name)})));
577
+ {Name, std::make_shared<EmptyRewriterFactory>(Name, Name)})));
580
578
}
581
579
582
580
inline std::pair<std::string, std::shared_ptr<CallExprRewriterFactoryBase>>
@@ -590,21 +588,17 @@ createMathAPIRewriterDevice(
590
588
createConditionalFactory (
591
589
math::IsDefinedInCUDA (),
592
590
std::move (createMathAPIRewriterDeviceImpl (Name, DeviceNodes)),
593
- {Name,
594
- std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)}),
591
+ {Name, std::make_shared<EmptyRewriterFactory>(Name, Name)}),
595
592
createConditionalFactory (
596
593
math::IsUnresolvedLookupExpr,
597
594
createConditionalFactory (
598
595
math::IsDirectCallerPureDevice,
599
596
std::move (createMathAPIRewriterDeviceImpl (Name, DeviceNodes)),
600
- {Name,
601
- std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)}),
597
+ {Name, std::make_shared<EmptyRewriterFactory>(Name, Name)}),
602
598
createConditionalFactory (
603
599
math::IsDefinedInCUDA (),
604
- std::move (
605
- createMathAPIRewriterDeviceImpl (Name, DeviceNodes)),
606
- {Name, std::make_shared<NoRewriteFuncNameRewriterFactory>(
607
- Name, Name)})));
600
+ std::move (createMathAPIRewriterDeviceImpl (Name, DeviceNodes)),
601
+ {Name, std::make_shared<EmptyRewriterFactory>(Name, Name)})));
608
602
}
609
603
610
604
template <class T >
@@ -620,13 +614,11 @@ createMathAPIRewriterExperimentalBfloat16(
620
614
if (math::useExtBFloat16Math () && Rewriter1.second )
621
615
return createConditionalFactory (
622
616
math::IsDefinedInCUDA (), std::move (Rewriter1),
623
- {Name,
624
- std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)});
617
+ {Name, std::make_shared<EmptyRewriterFactory>(Name, Name)});
625
618
if (Rewriter2.second )
626
619
return createConditionalFactory (
627
620
math::IsDefinedInCUDA (), std::move (Rewriter2),
628
- {Name,
629
- std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)});
621
+ {Name, std::make_shared<EmptyRewriterFactory>(Name, Name)});
630
622
}
631
623
// report unsupport
632
624
return std::pair<std::string, std::shared_ptr<CallExprRewriterFactoryBase>>(
@@ -643,7 +635,7 @@ createMathAPIRewriterHost(
643
635
T) {
644
636
return createConditionalFactory (
645
637
math::IsDefinedInCUDA (), std::move (HostNormal),
646
- {Name, std::make_shared<NoRewriteFuncNameRewriterFactory >(Name, Name)});
638
+ {Name, std::make_shared<EmptyRewriterFactory >(Name, Name)});
647
639
}
648
640
649
641
template <class T >
@@ -660,7 +652,7 @@ createMathAPIRewriterHost(
660
652
math::IsDefinedInCUDA (),
661
653
createConditionalFactory (makeCheckAnd (math::IsPerf, PerfPred),
662
654
std::move (HostPerf), std::move (HostNormal)),
663
- {Name, std::make_shared<NoRewriteFuncNameRewriterFactory >(Name, Name)});
655
+ {Name, std::make_shared<EmptyRewriterFactory >(Name, Name)});
664
656
}
665
657
666
658
template <bool IsDouble> std::string getPiString () {
0 commit comments