@@ -886,55 +886,93 @@ template <typename Ty = llvm::Value> Ty *getVal(llvm::Metadata *M) {
886
886
return nullptr ;
887
887
}
888
888
889
- // / Return the MDNode that has the SLM size attribute.
890
- static llvm::MDNode *getSLMSizeMDNode (llvm::Function *F) {
891
- llvm::NamedMDNode *Nodes =
892
- F->getParent ()->getNamedMetadata (GENX_KERNEL_METADATA);
893
- assert (Nodes && " invalid genx.kernels metadata" );
894
- for (auto Node : Nodes->operands ()) {
895
- if (Node->getNumOperands () >= 4 && getVal (Node->getOperand (0 )) == F)
896
- return Node;
897
- }
898
- // if F is not a kernel, keep looking into its callers
899
- while (!F->use_empty ()) {
900
- auto CI = cast<CallInst>(F->use_begin ()->getUser ());
901
- auto UF = CI->getParent ()->getParent ();
902
- if (auto Node = getSLMSizeMDNode (UF))
903
- return Node;
904
- }
905
- return nullptr ;
906
- }
907
-
908
889
static inline llvm::Metadata *getMD (llvm::Value *V) {
909
890
return llvm::ValueAsMetadata::get (V);
910
891
}
911
892
912
- static void translateSLMInit (CallInst &CI) {
913
- auto F = CI.getParent ()->getParent ();
893
+ /// Updates genx.kernels metadata attribute \p MD for the given function \p F.
894
+ /// The value of the attribute is updated only if the new value \p NewVal is
895
+ /// bigger than what is already stored in the attribute.
896
+ // TODO: 1) In general this function is supposed to handle intrinsics
897
+ // translated into kernel's metadata. So, the primary/intended usage model is
898
+ // when such intrinsics are called from kernels.
899
+ // 2) For now such intrinsics are also handled in functions directly called
900
+ // from kernels and are translated into those caller kernels' metadata, even though such
901
+ // behaviour is not fully specified/documented.
902
+ // 3) This code (or the code in FE) must verify that slm_init or other such
903
+ // intrinsic is not called from another module because kernels in that other
904
+ // module would not get updated meta data attributes.
905
+ static void updateGenXMDNodes (llvm::Function *F, genx::KernelMDOp MD,
906
+ uint64_t NewVal) {
907
+ llvm::NamedMDNode *GenXKernelMD =
908
+ F->getParent ()->getNamedMetadata (GENX_KERNEL_METADATA);
909
+ assert (GenXKernelMD && " invalid genx.kernels metadata" );
910
+
911
+ SmallPtrSet<Function *, 32 > FunctionsVisited;
912
+ SmallVector<Function *, 32 > Worklist{F};
913
+ while (!Worklist.empty ()) {
914
+ Function *CurF = Worklist.pop_back_val ();
915
+ FunctionsVisited.insert (CurF);
916
+
917
+ // Update the meta data attribute for the current function.
918
+ for (auto Node : GenXKernelMD->operands ()) {
919
+ if (Node->getNumOperands () <= MD ||
920
+ getVal (Node->getOperand (genx::KernelMDOp::FunctionRef)) != CurF)
921
+ continue ;
914
922
915
- auto *ArgV = CI. getArgOperand ( 0 );
916
- if (!isa< ConstantInt>(ArgV)) {
917
- assert ( false && " integral constant expected for slm size " );
918
- return ;
919
- }
920
- auto NewVal = cast<llvm::ConstantInt>(ArgV)-> getZExtValue ();
921
- assert (NewVal != 0 && " zero slm bytes being requested " );
923
+ llvm::Value *Old = getVal (Node-> getOperand (MD) );
924
+ uint64_t OldVal = cast<llvm:: ConstantInt>(Old)-> getZExtValue ();
925
+ if (OldVal < NewVal) {
926
+ llvm::Value *New = llvm::ConstantInt::get (Old-> getType (), NewVal) ;
927
+ Node-> replaceOperandWith (MD, getMD (New));
928
+ }
929
+ }
922
930
923
- // find the corresponding kernel metadata and set the SLM size.
924
- if (llvm::MDNode *Node = getSLMSizeMDNode (F)) {
925
- if (llvm::Value *OldSz = getVal (Node->getOperand (4 ))) {
926
- assert (isa<llvm::ConstantInt>(OldSz) && " integer constant expected" );
927
- llvm::Value *NewSz = llvm::ConstantInt::get (OldSz->getType (), NewVal);
928
- uint64_t OldVal = cast<llvm::ConstantInt>(OldSz)->getZExtValue ();
929
- if (OldVal < NewVal)
930
- Node->replaceOperandWith (3 , getMD (NewSz));
931
+ // Update all callers as well.
932
+ for (auto It = CurF->use_begin (); It != CurF->use_end (); It++) {
933
+ auto FCall = It->getUser ();
934
+ if (!isa<CallInst>(FCall))
935
+ llvm::report_fatal_error (
936
+ llvm::Twine (__FILE__ " " ) +
937
+ " Found an intrinsic violating assumption on usage from a kernel or "
938
+ " a func directly called from a kernel" );
939
+
940
+ auto FCaller = cast<CallInst>(FCall)->getFunction ();
941
+ if (!FunctionsVisited.count (FCaller))
942
+ Worklist.push_back (FCaller);
931
943
}
932
- } else {
933
- // We check whether this call is inside a kernel function.
934
- assert (false && " slm_init shall be called by a kernel" );
935
944
}
936
945
}
937
946
947
+ // This function sets/updates the VCSLMSize attribute of the kernels
948
+ // that call this intrinsic, which initializes SLM memory.
949
+ static void translateSLMInit (CallInst &CI) {
950
+ auto F = CI.getFunction ();
951
+ auto *ArgV = CI.getArgOperand (0 );
952
+ if (!isa<ConstantInt>(ArgV))
953
+ llvm::report_fatal_error (llvm::Twine (__FILE__ " " ) +
954
+ " integral constant is expected for slm size" );
955
+
956
+ uint64_t NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue ();
957
+ assert (NewVal != 0 && " zero slm bytes being requested" );
958
+ updateGenXMDNodes (F, genx::KernelMDOp::SLMSize, NewVal);
959
+ }
960
+
961
+ // This function sets/updates the VCNamedBarrierCount attribute of the kernels
962
+ // that call this intrinsic, which initializes the number of named barriers.
963
+ static void translateNbarrierInit (CallInst &CI) {
964
+ auto F = CI.getFunction ();
965
+ auto *ArgV = CI.getArgOperand (0 );
966
+ if (!isa<ConstantInt>(ArgV))
967
+ llvm::report_fatal_error (
968
+ llvm::Twine (__FILE__ " " ) +
969
+ " integral constant is expected for named barrier count" );
970
+
971
+ auto NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue ();
972
+ assert (NewVal != 0 && " zero named barrier count being requested" );
973
+ updateGenXMDNodes (F, genx::KernelMDOp::NBarrierCnt, NewVal);
974
+ }
975
+
938
976
static void translatePackMask (CallInst &CI) {
939
977
using Demangler = id::ManglingParser<SimpleAllocator>;
940
978
Function *F = CI.getCalledFunction ();
@@ -1022,34 +1060,6 @@ static void translateUnPackMask(CallInst &CI) {
1022
1060
CI.replaceAllUsesWith (TransCI);
1023
1061
}
1024
1062
1025
- // This function sets VCNamedBarrierCount attribute to set
1026
- // the number of named barriers required by a kernel
1027
- static void translateNbarrierInit (CallInst &CI) {
1028
- auto *F = CI.getFunction ();
1029
-
1030
- auto *ArgV = CI.getArgOperand (0 );
1031
- assert (isa<ConstantInt>(ArgV) &&
1032
- " integral constant expected for nbarrier count" );
1033
-
1034
- auto NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue ();
1035
- assert (NewVal != 0 && " zero nbarrier count being requested" );
1036
-
1037
- if (llvm::MDNode *Node = getSLMSizeMDNode (F)) {
1038
- if (llvm::Value *OldCount =
1039
- getVal (Node->getOperand (genx::KernelMDOp::NBarrierCnt))) {
1040
- assert (isa<llvm::ConstantInt>(OldCount) && " integer constant expected" );
1041
- llvm::Value *NewCount =
1042
- llvm::ConstantInt::get (OldCount->getType (), NewVal);
1043
- uint64_t OldVal = cast<llvm::ConstantInt>(OldCount)->getZExtValue ();
1044
- if (OldVal < NewVal)
1045
- Node->replaceOperandWith (genx::KernelMDOp::NBarrierCnt,
1046
- getMD (NewCount));
1047
- }
1048
- } else {
1049
- llvm_unreachable (" esimd_nbarrier_init can only be called by a kernel" );
1050
- }
1051
- }
1052
-
1053
1063
static bool translateVLoad (CallInst &CI, SmallPtrSet<Type *, 4 > &GVTS) {
1054
1064
if (GVTS.find (CI.getType ()) != GVTS.end ())
1055
1065
return false ;
0 commit comments