@@ -1522,7 +1522,7 @@ class KernelObjVisitor {
15221522 void visitParam (ParmVarDecl *Param, QualType ParamTy,
15231523 HandlerTys &...Handlers) {
15241524 if (isSyclSpecialType (ParamTy, SemaSYCLRef))
1525- KP_FOR_EACH (handleOtherType , Param, ParamTy);
1525+ KP_FOR_EACH (handleSyclSpecialType , Param, ParamTy);
15261526 else if (ParamTy->isStructureOrClassType ()) {
15271527 if (KP_FOR_EACH (handleStructType, Param, ParamTy)) {
15281528 CXXRecordDecl *RD = ParamTy->getAsCXXRecordDecl ();
@@ -2075,8 +2075,11 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
20752075 }
20762076
20772077 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2078- Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type) << ParamTy;
2079- IsInvalid = true ;
2078+ if (!SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
2079+ Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type)
2080+ << ParamTy;
2081+ IsInvalid = true ;
2082+ }
20802083 return isValid ();
20812084 }
20822085
@@ -2228,8 +2231,8 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
22282231 }
22292232
22302233 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2231- // TODO
2232- unsupportedFreeFunctionParamType ();
2234+ if (! SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
2235+ unsupportedFreeFunctionParamType (); // TODO
22332236 return true ;
22342237 }
22352238
@@ -3013,9 +3016,26 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30133016 return handleSpecialType (FD, FieldTy);
30143017 }
30153018
3016- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
3017- // TODO
3018- unsupportedFreeFunctionParamType ();
3019+ bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
3020+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
3021+ const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
3022+ assert (RecordDecl && " The type must be a RecordDecl" );
3023+ CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
3024+ assert (InitMethod && " The type must have the __init method" );
3025+ // Don't do -1 here because we count on this to be the first parameter
3026+ // added (if any).
3027+ size_t ParamIndex = Params.size ();
3028+ for (const ParmVarDecl *Param : InitMethod->parameters ()) {
3029+ QualType ParamTy = Param->getType ();
3030+ addParam (Param, ParamTy.getCanonicalType ());
3031+ // Propagate add_ir_attributes_kernel_parameter attribute.
3032+ if (const auto *AddIRAttr =
3033+ Param->getAttr <SYCLAddIRAttributesKernelParameterAttr>())
3034+ Params.back ()->addAttr (AddIRAttr->clone (SemaSYCLRef.getASTContext ()));
3035+ }
3036+ LastParamIndex = ParamIndex;
3037+ } else // TODO
3038+ unsupportedFreeFunctionParamType ();
30193039 return true ;
30203040 }
30213041
@@ -3291,9 +3311,7 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
32913311 }
32923312
32933313 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
3294- // TODO
3295- unsupportedFreeFunctionParamType ();
3296- return true ;
3314+ return handleSpecialType (ParamTy);
32973315 }
32983316
32993317 bool handleSyclSpecialType (const CXXRecordDecl *, const CXXBaseSpecifier &BS,
@@ -4442,6 +4460,45 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44424460 {});
44434461 }
44444462
4463+ MemberExpr *buildMemberExpr (Expr *Base, ValueDecl *Member) {
4464+ DeclAccessPair MemberDAP = DeclAccessPair::make (Member, AS_none);
4465+ MemberExpr *Result = SemaSYCLRef.SemaRef .BuildMemberExpr (
4466+ Base, /* IsArrow */ false , FreeFunctionSrcLoc, NestedNameSpecifierLoc (),
4467+ FreeFunctionSrcLoc, Member, MemberDAP,
4468+ /* HadMultipleCandidates*/ false ,
4469+ DeclarationNameInfo (Member->getDeclName (), FreeFunctionSrcLoc),
4470+ Member->getType (), VK_LValue, OK_Ordinary);
4471+ return Result;
4472+ }
4473+
4474+ void createSpecialMethodCall (const CXXRecordDecl *RD, StringRef MethodName,
4475+ Expr *MemberBaseExpr,
4476+ SmallVectorImpl<Stmt *> &AddTo) {
4477+ CXXMethodDecl *Method = getMethodByName (RD, MethodName);
4478+ if (!Method)
4479+ return ;
4480+ unsigned NumParams = Method->getNumParams ();
4481+ llvm::SmallVector<Expr *, 4 > ParamDREs (NumParams);
4482+ llvm::ArrayRef<ParmVarDecl *> KernelParameters =
4483+ DeclCreator.getParamVarDeclsForCurrentField ();
4484+ for (size_t I = 0 ; I < NumParams; ++I) {
4485+ QualType ParamType = KernelParameters[I]->getOriginalType ();
4486+ ParamDREs[I] = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4487+ KernelParameters[I], ParamType, VK_LValue, FreeFunctionSrcLoc);
4488+ }
4489+ MemberExpr *MethodME = buildMemberExpr (MemberBaseExpr, Method);
4490+ QualType ResultTy = Method->getReturnType ();
4491+ ExprValueKind VK = Expr::getValueKindForType (ResultTy);
4492+ ResultTy = ResultTy.getNonLValueExprType (SemaSYCLRef.getASTContext ());
4493+ llvm::SmallVector<Expr *, 4 > ParamStmts;
4494+ const auto *Proto = cast<FunctionProtoType>(Method->getType ());
4495+ SemaSYCLRef.SemaRef .GatherArgumentsForCall (FreeFunctionSrcLoc, Method,
4496+ Proto, 0 , ParamDREs, ParamStmts);
4497+ AddTo.push_back (CXXMemberCallExpr::Create (
4498+ SemaSYCLRef.getASTContext (), MethodME, ParamStmts, ResultTy, VK,
4499+ FreeFunctionSrcLoc, FPOptionsOverride ()));
4500+ }
4501+
44454502public:
44464503 static constexpr const bool VisitInsideSimpleContainers = false ;
44474504
@@ -4461,9 +4518,53 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
44614518 return true ;
44624519 }
44634520
4464- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
4465- // TODO
4466- unsupportedFreeFunctionParamType ();
4521+ // Default inits the type, then calls the init-method in the body.
4522+ // A type may not have a public default constructor as per its spec so
4523+ // typically if this is the case the default constructor will be private and
4524+ // in such cases we must manually override the access specifier from private
4525+ // to public just for the duration of this default initialization.
4526+ // TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061
4527+ // is closed.
4528+ bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
4529+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
4530+ const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
4531+ AccessSpecifier DefaultConstructorAccess;
4532+ auto DefaultConstructor =
4533+ std::find_if (RecordDecl->ctor_begin (), RecordDecl->ctor_end (),
4534+ [](auto it) { return it->isDefaultConstructor (); });
4535+ DefaultConstructorAccess = DefaultConstructor->getAccess ();
4536+ DefaultConstructor->setAccess (AS_public);
4537+
4538+ QualType Ty = PD->getOriginalType ();
4539+ ASTContext &Ctx = SemaSYCLRef.SemaRef .getASTContext ();
4540+ VarDecl *WorkGroupMemoryClone = VarDecl::Create (
4541+ Ctx, DeclCreator.getKernelDecl (), FreeFunctionSrcLoc,
4542+ FreeFunctionSrcLoc, PD->getIdentifier (), PD->getType (),
4543+ Ctx.getTrivialTypeSourceInfo (Ty), SC_None);
4544+ InitializedEntity VarEntity =
4545+ InitializedEntity::InitializeVariable (WorkGroupMemoryClone);
4546+ InitializationKind InitKind =
4547+ InitializationKind::CreateDefault (FreeFunctionSrcLoc);
4548+ InitializationSequence InitSeq (SemaSYCLRef.SemaRef , VarEntity, InitKind,
4549+ std::nullopt );
4550+ ExprResult Init = InitSeq.Perform (SemaSYCLRef.SemaRef , VarEntity,
4551+ InitKind, std::nullopt );
4552+ WorkGroupMemoryClone->setInit (
4553+ SemaSYCLRef.SemaRef .MaybeCreateExprWithCleanups (Init.get ()));
4554+ WorkGroupMemoryClone->setInitStyle (VarDecl::CallInit);
4555+ DefaultConstructor->setAccess (DefaultConstructorAccess);
4556+
4557+ Stmt *DS = new (SemaSYCLRef.getASTContext ())
4558+ DeclStmt (DeclGroupRef (WorkGroupMemoryClone), FreeFunctionSrcLoc,
4559+ FreeFunctionSrcLoc);
4560+ BodyStmts.push_back (DS);
4561+ Expr *MemberBaseExpr = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4562+ WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
4563+ createSpecialMethodCall (RecordDecl, InitMethodName, MemberBaseExpr,
4564+ BodyStmts);
4565+ ArgExprs.push_back (MemberBaseExpr);
4566+ } else // TODO
4567+ unsupportedFreeFunctionParamType ();
44674568 return true ;
44684569 }
44694570
@@ -4748,9 +4849,11 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
47484849 return true ;
47494850 }
47504851
4751- bool handleSyclSpecialType (ParmVarDecl *, QualType) final {
4752- // TODO
4753- unsupportedFreeFunctionParamType ();
4852+ bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
4853+ if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
4854+ addParam (PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4855+ else
4856+ unsupportedFreeFunctionParamType (); // TODO
47544857 return true ;
47554858 }
47564859
@@ -6227,7 +6330,6 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
62276330 O << " #include <sycl/detail/defines_elementary.hpp>\n " ;
62286331 O << " #include <sycl/detail/kernel_desc.hpp>\n " ;
62296332 O << " #include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n " ;
6230-
62316333 O << " \n " ;
62326334
62336335 LangOptions LO;
@@ -6502,6 +6604,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
65026604
65036605 O << " \n " ;
65046606 O << " // Forward declarations of kernel and its argument types:\n " ;
6607+ Policy.SuppressDefaultTemplateArgs = false ;
65056608 FwdDeclEmitter.Visit (K.SyclKernel ->getType ());
65066609 O << " \n " ;
65076610
@@ -6510,6 +6613,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
65106613 std::string ParmList;
65116614 bool FirstParam = true ;
65126615 Policy.SuppressDefaultTemplateArgs = false ;
6616+ Policy.PrintCanonicalTypes = true ;
65136617 for (ParmVarDecl *Param : K.SyclKernel ->parameters ()) {
65146618 if (FirstParam)
65156619 FirstParam = false ;
@@ -6518,6 +6622,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
65186622 ParmList += Param->getType ().getCanonicalType ().getAsString (Policy);
65196623 }
65206624 FunctionTemplateDecl *FTD = K.SyclKernel ->getPrimaryTemplate ();
6625+ Policy.PrintCanonicalTypes = false ;
65216626 Policy.SuppressDefinition = true ;
65226627 Policy.PolishForDeclaration = true ;
65236628 Policy.FullyQualifiedName = true ;
@@ -6577,6 +6682,8 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
65776682 }
65786683 O << " ;\n " ;
65796684 O << " }\n " ;
6685+ Policy.SuppressDefaultTemplateArgs = true ;
6686+ Policy.EnforceDefaultTemplateArgs = false ;
65806687
65816688 // Generate is_kernel, is_single_task_kernel and nd_range_kernel functions.
65826689 O << " namespace sycl {\n " ;
0 commit comments