Skip to content

Commit cc3930b

Browse files
authored
[ESIMD] Fix wrong genx attributes generation in sycl-post-link (#6029)
* [ESIMD] Fix wrong genx attributes generation in sycl-post-link. The attributes NBarrierCount and SLMSize were set incorrectly in cases where a function was shared by two different kernels. An attribute generated in a shared function should be set for all of its callers; previously, only the attributes of the first caller were updated. Signed-off-by: Vyacheslav N Klochkov <vyacheslav.n.klochkov@intel.com>
1 parent 969fb7f commit cc3930b

File tree

2 files changed

+117
-67
lines changed

2 files changed

+117
-67
lines changed

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 77 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -886,55 +886,93 @@ template <typename Ty = llvm::Value> Ty *getVal(llvm::Metadata *M) {
886886
return nullptr;
887887
}
888888

889-
/// Return the MDNode that has the SLM size attribute.
890-
static llvm::MDNode *getSLMSizeMDNode(llvm::Function *F) {
891-
llvm::NamedMDNode *Nodes =
892-
F->getParent()->getNamedMetadata(GENX_KERNEL_METADATA);
893-
assert(Nodes && "invalid genx.kernels metadata");
894-
for (auto Node : Nodes->operands()) {
895-
if (Node->getNumOperands() >= 4 && getVal(Node->getOperand(0)) == F)
896-
return Node;
897-
}
898-
// if F is not a kernel, keep looking into its callers
899-
while (!F->use_empty()) {
900-
auto CI = cast<CallInst>(F->use_begin()->getUser());
901-
auto UF = CI->getParent()->getParent();
902-
if (auto Node = getSLMSizeMDNode(UF))
903-
return Node;
904-
}
905-
return nullptr;
906-
}
907-
908889
static inline llvm::Metadata *getMD(llvm::Value *V) {
909890
return llvm::ValueAsMetadata::get(V);
910891
}
911892

912-
static void translateSLMInit(CallInst &CI) {
913-
auto F = CI.getParent()->getParent();
893+
/// Updates genx.kernels metadata attribute \p MD for the given function \p F.
894+
/// The value of the attribute is updated only if the new value \p NewVal is
895+
/// bigger than what is already stored in the attribute.
896+
// TODO: 1) In general this function is supposed to handle intrinsics
897+
// translated into kernel's metadata. So, the primary/intended usage model is
898+
// when such intrinsics are called from kernels.
899+
// 2) For now such intrinsics are also handled in functions directly called
900+
// from kernels and are translated into the caller kernels' metadata, even though such
901+
// behaviour is not fully specified/documented.
902+
// 3) This code (or the code in FE) must verify that slm_init or other such
903+
// intrinsic is not called from another module because kernels in that other
904+
// module would not get updated meta data attributes.
905+
static void updateGenXMDNodes(llvm::Function *F, genx::KernelMDOp MD,
906+
uint64_t NewVal) {
907+
llvm::NamedMDNode *GenXKernelMD =
908+
F->getParent()->getNamedMetadata(GENX_KERNEL_METADATA);
909+
assert(GenXKernelMD && "invalid genx.kernels metadata");
910+
911+
SmallPtrSet<Function *, 32> FunctionsVisited;
912+
SmallVector<Function *, 32> Worklist{F};
913+
while (!Worklist.empty()) {
914+
Function *CurF = Worklist.pop_back_val();
915+
FunctionsVisited.insert(CurF);
916+
917+
// Update the meta data attribute for the current function.
918+
for (auto Node : GenXKernelMD->operands()) {
919+
if (Node->getNumOperands() <= MD ||
920+
getVal(Node->getOperand(genx::KernelMDOp::FunctionRef)) != CurF)
921+
continue;
914922

915-
auto *ArgV = CI.getArgOperand(0);
916-
if (!isa<ConstantInt>(ArgV)) {
917-
assert(false && "integral constant expected for slm size");
918-
return;
919-
}
920-
auto NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue();
921-
assert(NewVal != 0 && "zero slm bytes being requested");
923+
llvm::Value *Old = getVal(Node->getOperand(MD));
924+
uint64_t OldVal = cast<llvm::ConstantInt>(Old)->getZExtValue();
925+
if (OldVal < NewVal) {
926+
llvm::Value *New = llvm::ConstantInt::get(Old->getType(), NewVal);
927+
Node->replaceOperandWith(MD, getMD(New));
928+
}
929+
}
922930

923-
// find the corresponding kernel metadata and set the SLM size.
924-
if (llvm::MDNode *Node = getSLMSizeMDNode(F)) {
925-
if (llvm::Value *OldSz = getVal(Node->getOperand(4))) {
926-
assert(isa<llvm::ConstantInt>(OldSz) && "integer constant expected");
927-
llvm::Value *NewSz = llvm::ConstantInt::get(OldSz->getType(), NewVal);
928-
uint64_t OldVal = cast<llvm::ConstantInt>(OldSz)->getZExtValue();
929-
if (OldVal < NewVal)
930-
Node->replaceOperandWith(3, getMD(NewSz));
931+
// Update all callers as well.
932+
for (auto It = CurF->use_begin(); It != CurF->use_end(); It++) {
933+
auto FCall = It->getUser();
934+
if (!isa<CallInst>(FCall))
935+
llvm::report_fatal_error(
936+
llvm::Twine(__FILE__ " ") +
937+
"Found an intrinsic violating assumption on usage from a kernel or "
938+
"a func directly called from a kernel");
939+
940+
auto FCaller = cast<CallInst>(FCall)->getFunction();
941+
if (!FunctionsVisited.count(FCaller))
942+
Worklist.push_back(FCaller);
931943
}
932-
} else {
933-
// We check whether this call is inside a kernel function.
934-
assert(false && "slm_init shall be called by a kernel");
935944
}
936945
}
937946

947+
// This function sets/updates VCSLMSize attribute to the kernels
948+
// calling this intrinsic initializing SLM memory.
949+
static void translateSLMInit(CallInst &CI) {
950+
auto F = CI.getFunction();
951+
auto *ArgV = CI.getArgOperand(0);
952+
if (!isa<ConstantInt>(ArgV))
953+
llvm::report_fatal_error(llvm::Twine(__FILE__ " ") +
954+
"integral constant is expected for slm size");
955+
956+
uint64_t NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue();
957+
assert(NewVal != 0 && "zero slm bytes being requested");
958+
updateGenXMDNodes(F, genx::KernelMDOp::SLMSize, NewVal);
959+
}
960+
961+
// This function sets/updates VCNamedBarrierCount attribute to the kernels
962+
// calling this intrinsic initializing the number of named barriers.
963+
static void translateNbarrierInit(CallInst &CI) {
964+
auto F = CI.getFunction();
965+
auto *ArgV = CI.getArgOperand(0);
966+
if (!isa<ConstantInt>(ArgV))
967+
llvm::report_fatal_error(
968+
llvm::Twine(__FILE__ " ") +
969+
"integral constant is expected for named barrier count");
970+
971+
auto NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue();
972+
assert(NewVal != 0 && "zero named barrier count being requested");
973+
updateGenXMDNodes(F, genx::KernelMDOp::NBarrierCnt, NewVal);
974+
}
975+
938976
static void translatePackMask(CallInst &CI) {
939977
using Demangler = id::ManglingParser<SimpleAllocator>;
940978
Function *F = CI.getCalledFunction();
@@ -1022,34 +1060,6 @@ static void translateUnPackMask(CallInst &CI) {
10221060
CI.replaceAllUsesWith(TransCI);
10231061
}
10241062

1025-
// This function sets VCNamedBarrierCount attribute to set
1026-
// the number of named barriers required by a kernel
1027-
static void translateNbarrierInit(CallInst &CI) {
1028-
auto *F = CI.getFunction();
1029-
1030-
auto *ArgV = CI.getArgOperand(0);
1031-
assert(isa<ConstantInt>(ArgV) &&
1032-
"integral constant expected for nbarrier count");
1033-
1034-
auto NewVal = cast<llvm::ConstantInt>(ArgV)->getZExtValue();
1035-
assert(NewVal != 0 && "zero nbarrier count being requested");
1036-
1037-
if (llvm::MDNode *Node = getSLMSizeMDNode(F)) {
1038-
if (llvm::Value *OldCount =
1039-
getVal(Node->getOperand(genx::KernelMDOp::NBarrierCnt))) {
1040-
assert(isa<llvm::ConstantInt>(OldCount) && "integer constant expected");
1041-
llvm::Value *NewCount =
1042-
llvm::ConstantInt::get(OldCount->getType(), NewVal);
1043-
uint64_t OldVal = cast<llvm::ConstantInt>(OldCount)->getZExtValue();
1044-
if (OldVal < NewVal)
1045-
Node->replaceOperandWith(genx::KernelMDOp::NBarrierCnt,
1046-
getMD(NewCount));
1047-
}
1048-
} else {
1049-
llvm_unreachable("esimd_nbarrier_init can only be called by a kernel");
1050-
}
1051-
}
1052-
10531063
static bool translateVLoad(CallInst &CI, SmallPtrSet<Type *, 4> &GVTS) {
10541064
if (GVTS.find(CI.getType()) != GVTS.end())
10551065
return false;

sycl/test/esimd/genx_func_attr.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: %clangxx -O2 -fsycl -fsycl-device-only -Xclang -emit-llvm %s -o %t
2+
// RUN: sycl-post-link -split-esimd -lower-esimd -O2 -S %t -o %t.table
3+
// RUN: FileCheck %s -input-file=%t_esimd_0.ll
4+
5+
// Checks ESIMD intrinsic translation.
6+
// NOTE: this note asks for -O0 (the optimizer optimizes away some of the code), but the RUN lines above pass -O2 — TODO confirm the intended level
7+
8+
#include <CL/sycl.hpp>
9+
#include <CL/sycl/detail/image_ocl_types.hpp>
10+
#include <sycl/ext/intel/esimd.hpp>
11+
12+
using namespace sycl::ext::intel::esimd;
13+
14+
template <typename name, typename Func>
15+
__attribute__((sycl_kernel)) void kernel(Func kernelFunc) {
16+
kernelFunc();
17+
}
18+
19+
SYCL_ESIMD_FUNCTION SYCL_EXTERNAL ESIMD_NOINLINE void callee(int x) {
20+
slm_init(1234);
21+
sycl::ext::intel::experimental::esimd::named_barrier_init<13>();
22+
}
23+
24+
// inherits SLMSize and NBarrierCount from callee
25+
void caller_abc(int x) {
26+
kernel<class kernel_abc>([=]() SYCL_ESIMD_KERNEL { callee(x); });
27+
// CHECK: define dso_local spir_kernel void @_ZTSZ10caller_abciE10kernel_abc(i32 noundef "VCArgumentIOKind"="0" %_arg_x) local_unnamed_addr #2
28+
}
29+
30+
// inherits only NBarrierCount from callee
31+
void caller_xyz(int x) {
32+
kernel<class kernel_xyz>([=]() SYCL_ESIMD_KERNEL {
33+
slm_init(1235);
34+
callee(x);
35+
});
36+
// CHECK: define dso_local spir_kernel void @_ZTSZ10caller_xyziE10kernel_xyz(i32 noundef "VCArgumentIOKind"="0" %_arg_x) local_unnamed_addr #3
37+
}
38+
39+
// CHECK: attributes #2 = { {{.*}} "VCNamedBarrierCount"="13" "VCSLMSize"="1234"
40+
// CHECK: attributes #3 = { {{.*}} "VCNamedBarrierCount"="13" "VCSLMSize"="1235"

0 commit comments

Comments
 (0)