@@ -4512,10 +4512,14 @@ void StreamAPICallRule::runRule(const MatchFinder::MatchResult &Result) {
4512
4512
}
4513
4513
4514
4514
void KernelCallRefRule::registerMatcher (ast_matchers::MatchFinder &MF) {
4515
- MF.addMatcher (declRefExpr (allOf (to (functionDecl (hasAttr (attr::CUDAGlobal))),
4516
- unless (hasAncestor (cudaKernelCallExpr ()))))
4517
- .bind (" kernelRef" ),
4518
- this );
4515
+ MF.addMatcher (
4516
+ functionDecl (
4517
+ forEachDescendant (
4518
+ declRefExpr (allOf (to (functionDecl (hasAttr (attr::CUDAGlobal))),
4519
+ unless (hasAncestor (cudaKernelCallExpr ()))))
4520
+ .bind (" kernelRef" )))
4521
+ .bind (" outerFunc" ),
4522
+ this );
4519
4523
MF.addMatcher (unresolvedLookupExpr (unless (hasAncestor (cudaKernelCallExpr ())))
4520
4524
.bind (" unresolvedRef" ),
4521
4525
this );
@@ -4591,6 +4595,11 @@ void KernelCallRefRule::insertWrapperPostfix(const T *Node,
4591
4595
void KernelCallRefRule::runRule (
4592
4596
const ast_matchers::MatchFinder::MatchResult &Result) {
4593
4597
if (auto DRE = getAssistNodeAsType<DeclRefExpr>(Result, " kernelRef" )) {
4598
+ const FunctionDecl *OuterFD =
4599
+ getAssistNodeAsType<FunctionDecl>(Result, " outerFunc" );
4600
+ if (!OuterFD) {
4601
+ return ;
4602
+ }
4594
4603
if (auto ParentCE = DpctGlobalInfo::findAncestor<CallExpr>(DRE)) {
4595
4604
if (auto Callee = ParentCE->getDirectCallee ()) {
4596
4605
if (dpct::DpctGlobalInfo::isInCudaPath (Callee->getBeginLoc ())) {
@@ -4614,31 +4623,25 @@ void KernelCallRefRule::runRule(
4614
4623
DFI->collectInfoForWrapper (FD);
4615
4624
}
4616
4625
}
4617
- if (auto *OuterFD = DpctGlobalInfo::findAncestor<FunctionDecl>(DRE)) {
4618
- if ((OuterFD->getTemplatedKind () ==
4619
- FunctionDecl::TemplatedKind::TK_NonTemplate) ||
4620
- (OuterFD->getTemplatedKind () ==
4621
- FunctionDecl::TemplatedKind::TK_FunctionTemplate)) {
4622
- std::string TypeRepl;
4623
- if (DpctGlobalInfo::isCVersionCUDALaunchUsed ()) {
4624
- if ((IsTemplateRelated &&
4625
- (!DRE->hasExplicitTemplateArgs () ||
4626
- (DRE->getNumTemplateArgs () <= TemplateParamNum))) ||
4627
- DRE->hadMultipleCandidates ()) {
4628
- TypeRepl = getTypeRepl (DRE);
4629
- }
4626
+ if ((OuterFD->getTemplatedKind () ==
4627
+ FunctionDecl::TemplatedKind::TK_NonTemplate) ||
4628
+ (OuterFD->getTemplatedKind () ==
4629
+ FunctionDecl::TemplatedKind::TK_FunctionTemplate)) {
4630
+ std::string TypeRepl;
4631
+ if (DpctGlobalInfo::isCVersionCUDALaunchUsed ()) {
4632
+ if ((IsTemplateRelated &&
4633
+ (!DRE->hasExplicitTemplateArgs () ||
4634
+ (DRE->getNumTemplateArgs () <= TemplateParamNum))) ||
4635
+ DRE->hadMultipleCandidates ()) {
4636
+ TypeRepl = getTypeRepl (DRE);
4630
4637
}
4631
- insertWrapperPostfix<DeclRefExpr>(
4632
- DRE, std::move (TypeRepl),
4633
- DpctGlobalInfo::isCVersionCUDALaunchUsed ());
4634
4638
}
4639
+ insertWrapperPostfix<DeclRefExpr>(
4640
+ DRE, std::move (TypeRepl), DpctGlobalInfo::isCVersionCUDALaunchUsed ());
4635
4641
}
4636
4642
}
4637
4643
if (auto ULE =
4638
4644
getAssistNodeAsType<UnresolvedLookupExpr>(Result, " unresolvedRef" )) {
4639
- if (!DpctGlobalInfo::isCVersionCUDALaunchUsed ()) {
4640
- return ;
4641
- }
4642
4645
bool KernelRefFound = false ;
4643
4646
for (auto *D : ULE->decls ()) {
4644
4647
const FunctionDecl *FD = dyn_cast<FunctionDecl>(D);
@@ -4670,7 +4673,8 @@ void KernelCallRefRule::runRule(
4670
4673
}
4671
4674
}
4672
4675
}
4673
- insertWrapperPostfix<UnresolvedLookupExpr>(ULE, getTypeRepl (ULE), true );
4676
+ insertWrapperPostfix<UnresolvedLookupExpr>(
4677
+ ULE, getTypeRepl (ULE), DpctGlobalInfo::isCVersionCUDALaunchUsed ());
4674
4678
}
4675
4679
}
4676
4680
0 commit comments