@@ -1522,7 +1522,7 @@ class KernelObjVisitor {
1522
1522
void visitParam (ParmVarDecl *Param, QualType ParamTy,
1523
1523
HandlerTys &...Handlers) {
1524
1524
if (isSyclSpecialType (ParamTy, SemaSYCLRef))
1525
- KP_FOR_EACH (handleOtherType , Param, ParamTy);
1525
+ KP_FOR_EACH (handleSyclSpecialType , Param, ParamTy);
1526
1526
else if (ParamTy->isStructureOrClassType ()) {
1527
1527
if (KP_FOR_EACH (handleStructType, Param, ParamTy)) {
1528
1528
CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl ();
@@ -2075,8 +2075,11 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
2075
2075
}
2076
2076
2077
2077
bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2078
- Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type) << ParamTy;
2079
- IsInvalid = true ;
2078
+ if (!SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
2079
+ Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type)
2080
+ << ParamTy;
2081
+ IsInvalid = true ;
2082
+ }
2080
2083
return isValid ();
2081
2084
}
2082
2085
@@ -2228,8 +2231,8 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
2228
2231
}
2229
2232
2230
2233
bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2231
- // TODO
2232
- unsupportedFreeFunctionParamType ();
2234
+ if (! SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
2235
+ unsupportedFreeFunctionParamType (); // TODO
2233
2236
return true ;
2234
2237
}
2235
2238
@@ -3013,9 +3016,26 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
3013
3016
return handleSpecialType (FD, FieldTy);
3014
3017
}
3015
3018
3016
- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
3017
- // TODO
3018
- unsupportedFreeFunctionParamType ();
3019
+ bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
3020
+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
3021
+ const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
3022
+ assert (RecordDecl && " The type must be a RecordDecl" );
3023
+ CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
3024
+ assert (InitMethod && " The type must have the __init method" );
3025
+ // Don't do -1 here because we count on this to be the first parameter
3026
+ // added (if any).
3027
+ size_t ParamIndex = Params.size ();
3028
+ for (const ParmVarDecl *Param : InitMethod->parameters ()) {
3029
+ QualType ParamTy = Param->getType ();
3030
+ addParam (Param, ParamTy.getCanonicalType ());
3031
+ // Propagate add_ir_attributes_kernel_parameter attribute.
3032
+ if (const auto *AddIRAttr =
3033
+ Param->getAttr <SYCLAddIRAttributesKernelParameterAttr>())
3034
+ Params.back ()->addAttr (AddIRAttr->clone (SemaSYCLRef.getASTContext ()));
3035
+ }
3036
+ LastParamIndex = ParamIndex;
3037
+ } else // TODO
3038
+ unsupportedFreeFunctionParamType ();
3019
3039
return true ;
3020
3040
}
3021
3041
@@ -3291,9 +3311,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
3291
3311
}
3292
3312
3293
3313
bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
3294
- // TODO
3295
- unsupportedFreeFunctionParamType ();
3296
- return true ;
3314
+ return handleSpecialType (ParamTy);
3297
3315
}
3298
3316
3299
3317
bool handleSyclSpecialType (const CXXRecordDecl *, const CXXBaseSpecifier &BS,
@@ -4442,6 +4460,45 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
4442
4460
{});
4443
4461
}
4444
4462
4463
+ MemberExpr *buildMemberExpr (Expr *Base, ValueDecl *Member) {
4464
+ DeclAccessPair MemberDAP = DeclAccessPair::make (Member, AS_none);
4465
+ MemberExpr *Result = SemaSYCLRef.SemaRef .BuildMemberExpr (
4466
+ Base, /* IsArrow */ false , FreeFunctionSrcLoc, NestedNameSpecifierLoc (),
4467
+ FreeFunctionSrcLoc, Member, MemberDAP,
4468
+ /* HadMultipleCandidates*/ false ,
4469
+ DeclarationNameInfo (Member->getDeclName (), FreeFunctionSrcLoc),
4470
+ Member->getType (), VK_LValue, OK_Ordinary);
4471
+ return Result;
4472
+ }
4473
+
4474
+ void createSpecialMethodCall (const CXXRecordDecl *RD, StringRef MethodName,
4475
+ Expr *MemberBaseExpr,
4476
+ SmallVectorImpl<Stmt *> &AddTo) {
4477
+ CXXMethodDecl *Method = getMethodByName (RD, MethodName);
4478
+ if (!Method)
4479
+ return ;
4480
+ unsigned NumParams = Method->getNumParams ();
4481
+ llvm::SmallVector<Expr *, 4 > ParamDREs (NumParams);
4482
+ llvm::ArrayRef<ParmVarDecl *> KernelParameters =
4483
+ DeclCreator.getParamVarDeclsForCurrentField ();
4484
+ for (size_t I = 0 ; I < NumParams; ++I) {
4485
+ QualType ParamType = KernelParameters[I]->getOriginalType ();
4486
+ ParamDREs[I] = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4487
+ KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc);
4488
+ }
4489
+ MemberExpr *MethodME = buildMemberExpr (MemberBaseExpr, Method);
4490
+ QualType ResultTy = Method->getReturnType ();
4491
+ ExprValueKind VK = Expr::getValueKindForType (ResultTy);
4492
+ ResultTy = ResultTy.getNonLValueExprType (SemaSYCLRef.getASTContext ());
4493
+ llvm::SmallVector<Expr *, 4 > ParamStmts;
4494
+ const auto *Proto = cast<FunctionProtoType>(Method->getType ());
4495
+ SemaSYCLRef.SemaRef .GatherArgumentsForCall (FreeFunctionSrcLoc, Method,
4496
+ Proto, 0 , ParamDREs, ParamStmts);
4497
+ AddTo.push_back (CXXMemberCallExpr::Create (
4498
+ SemaSYCLRef.getASTContext (), MethodME, ParamStmts, ResultTy, VK,
4499
+ FreeFunctionSrcLoc, FPOptionsOverride ()));
4500
+ }
4501
+
4445
4502
public:
4446
4503
static constexpr const bool VisitInsideSimpleContainers = false ;
4447
4504
@@ -4461,9 +4518,53 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
4461
4518
return true ;
4462
4519
}
4463
4520
4464
- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
4465
- // TODO
4466
- unsupportedFreeFunctionParamType ();
4521
+ // Default inits the type, then calls the init-method in the body.
4522
+ // A type may not have a public default constructor as per its spec so
4523
+ // typically if this is the case the default constructor will be private and
4524
+ // in such cases we must manually override the access specifier from private
4525
+ // to public just for the duration of this default initialization.
4526
+ // TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061
4527
+ // is closed.
4528
+ bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
4529
+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
4530
+ const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
4531
+ AccessSpecifier DefaultConstructorAccess;
4532
+ auto DefaultConstructor =
4533
+ std::find_if (RecordDecl->ctor_begin (), RecordDecl->ctor_end (),
4534
+ [](auto it) { return it->isDefaultConstructor (); });
4535
+ DefaultConstructorAccess = DefaultConstructor->getAccess ();
4536
+ DefaultConstructor->setAccess (AS_public);
4537
+
4538
+ QualType Ty = PD->getOriginalType ();
4539
+ ASTContext &Ctx = SemaSYCLRef.SemaRef .getASTContext ();
4540
+ VarDecl *WorkGroupMemoryClone = VarDecl::Create (
4541
+ Ctx, DeclCreator.getKernelDecl (), FreeFunctionSrcLoc,
4542
+ FreeFunctionSrcLoc, PD->getIdentifier (), PD->getType (),
4543
+ Ctx.getTrivialTypeSourceInfo (Ty), SC_None);
4544
+ InitializedEntity VarEntity =
4545
+ InitializedEntity::InitializeVariable (WorkGroupMemoryClone);
4546
+ InitializationKind InitKind =
4547
+ InitializationKind::CreateDefault (FreeFunctionSrcLoc);
4548
+ InitializationSequence InitSeq (SemaSYCLRef.SemaRef , VarEntity, InitKind,
4549
+ std::nullopt);
4550
+ ExprResult Init = InitSeq.Perform (SemaSYCLRef.SemaRef , VarEntity,
4551
+ InitKind, std::nullopt);
4552
+ WorkGroupMemoryClone->setInit (
4553
+ SemaSYCLRef.SemaRef .MaybeCreateExprWithCleanups (Init.get ()));
4554
+ WorkGroupMemoryClone->setInitStyle (VarDecl::CallInit);
4555
+ DefaultConstructor->setAccess (DefaultConstructorAccess);
4556
+
4557
+ Stmt *DS = new (SemaSYCLRef.getASTContext ())
4558
+ DeclStmt (DeclGroupRef (WorkGroupMemoryClone), FreeFunctionSrcLoc,
4559
+ FreeFunctionSrcLoc);
4560
+ BodyStmts.push_back (DS);
4561
+ Expr *MemberBaseExpr = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4562
+ WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
4563
+ createSpecialMethodCall (RecordDecl, InitMethodName, MemberBaseExpr,
4564
+ BodyStmts);
4565
+ ArgExprs.push_back (MemberBaseExpr);
4566
+ } else // TODO
4567
+ unsupportedFreeFunctionParamType ();
4467
4568
return true ;
4468
4569
}
4469
4570
@@ -4748,9 +4849,11 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
4748
4849
return true ;
4749
4850
}
4750
4851
4751
- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
4752
- // TODO
4753
- unsupportedFreeFunctionParamType ();
4852
+ bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
4853
+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
4854
+ addParam (PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4855
+ else
4856
+ unsupportedFreeFunctionParamType (); // TODO
4754
4857
return true ;
4755
4858
}
4756
4859
@@ -6227,7 +6330,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
6227
6330
O << " #include <sycl/detail/defines_elementary.hpp>\n " ;
6228
6331
O << " #include <sycl/detail/kernel_desc.hpp>\n " ;
6229
6332
O << " #include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n " ;
6230
-
6231
6333
O << " \n " ;
6232
6334
6233
6335
LangOptions LO;
@@ -6502,6 +6604,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
6502
6604
6503
6605
O << " \n " ;
6504
6606
O << " // Forward declarations of kernel and its argument types:\n " ;
6607
+ Policy.SuppressDefaultTemplateArgs = false ;
6505
6608
FwdDeclEmitter.Visit (K.SyclKernel ->getType ());
6506
6609
O << " \n " ;
6507
6610
@@ -6510,6 +6613,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
6510
6613
std::string ParmList;
6511
6614
bool FirstParam = true ;
6512
6615
Policy.SuppressDefaultTemplateArgs = false ;
6616
+ Policy.PrintCanonicalTypes = true ;
6513
6617
for (ParmVarDecl *Param : K.SyclKernel ->parameters ()) {
6514
6618
if (FirstParam)
6515
6619
FirstParam = false ;
@@ -6518,6 +6622,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
6518
6622
ParmList += Param->getType ().getCanonicalType ().getAsString (Policy);
6519
6623
}
6520
6624
FunctionTemplateDecl *FTD = K.SyclKernel ->getPrimaryTemplate ();
6625
+ Policy.PrintCanonicalTypes = false ;
6521
6626
Policy.SuppressDefinition = true ;
6522
6627
Policy.PolishForDeclaration = true ;
6523
6628
Policy.FullyQualifiedName = true ;
@@ -6577,6 +6682,8 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
6577
6682
}
6578
6683
O << " ;\n " ;
6579
6684
O << " }\n " ;
6685
+ Policy.SuppressDefaultTemplateArgs = true ;
6686
+ Policy.EnforceDefaultTemplateArgs = false ;
6580
6687
6581
6688
// Generate is_kernel, is_single_task_kernel and nd_range_kernel functions.
6582
6689
O << " namespace sycl {\n " ;
0 commit comments