Skip to content

Commit 4ff42c3

Browse files
[SYCL] Relax work-group metadata generation and processing in sycl-post-link (#7471)
When generating program metadata for `reqd_work_group_size` information sycl-post-link currently expects the existence of all three dimensions in the metadata. As a result the compile-time kernel property `work_group_size` needs to pad itself to adhere to this requirement. This commit relaxes the requirement by instead allowing fewer than three metadata operands in both `reqd_work_group_size` and `work_group_size_hint` when generating the program metadata. Additionally, padding will no longer be added when converting "sycl-work-group-size" and "sycl-work-group-size-hint" into the corresponding metadata. Signed-off-by: Larsen, Steffen <steffen.larsen@intel.com>
1 parent 216137b commit 4ff42c3

File tree

4 files changed

+41
-21
lines changed

4 files changed

+41
-21
lines changed

llvm/test/tools/sycl-post-link/emit_program_metadata.ll

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77
target triple = "spir64-unknown-unknown"
88

99
attributes #0 = { "sycl-work-group-size"="4,2,1" }
10+
attributes #1 = { "sycl-work-group-size"="4,2" }
11+
attributes #2 = { "sycl-work-group-size"="4" }
1012

1113
!0 = !{i32 1, i32 2, i32 4}
14+
!1 = !{i32 2, i32 4}
15+
!2 = !{i32 4}
1216

1317
define weak_odr spir_kernel void @SpirKernel1(float %arg1) !reqd_work_group_size !0 {
1418
call void @foo(float %arg1)
@@ -20,12 +24,36 @@ define weak_odr spir_kernel void @SpirKernel2(float %arg1) #0 {
2024
ret void
2125
}
2226

27+
define weak_odr spir_kernel void @SpirKernel3(float %arg1) !reqd_work_group_size !1 {
28+
call void @foo(float %arg1)
29+
ret void
30+
}
31+
32+
define weak_odr spir_kernel void @SpirKernel4(float %arg1) #1 {
33+
call void @foo(float %arg1)
34+
ret void
35+
}
36+
37+
define weak_odr spir_kernel void @SpirKernel5(float %arg1) !reqd_work_group_size !2 {
38+
call void @foo(float %arg1)
39+
ret void
40+
}
41+
42+
define weak_odr spir_kernel void @SpirKernel6(float %arg1) #2 {
43+
call void @foo(float %arg1)
44+
ret void
45+
}
46+
2347
declare void @foo(float)
2448

2549
; CHECK-PROP: [SYCL/program metadata]
2650
; // Base64 encoding in the prop file (including 8 bytes length):
2751
; CHECK-PROP-NEXT: SpirKernel1@reqd_work_group_size=2|gBAAAAAAAAQAAAAACAAAAQAAAAA
2852
; CHECK-PROP-NEXT: SpirKernel2@reqd_work_group_size=2|gBAAAAAAAAQAAAAACAAAAQAAAAA
53+
; CHECK-PROP-NEXT: SpirKernel3@reqd_work_group_size=2|ABAAAAAAAAgAAAAAEAAAAA
54+
; CHECK-PROP-NEXT: SpirKernel4@reqd_work_group_size=2|ABAAAAAAAAgAAAAAEAAAAA
55+
; CHECK-PROP-NEXT: SpirKernel5@reqd_work_group_size=2|gAAAAAAAAAABAAAA
56+
; CHECK-PROP-NEXT: SpirKernel6@reqd_work_group_size=2|gAAAAAAAAAABAAAA
2957

3058
; CHECK-TABLE: [Code|Properties]
3159
; CHECK-TABLE-NEXT: {{.*}}files_0.prop

llvm/test/tools/sycl-post-link/kernel-properties.ll

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,10 @@ attributes #2 = { convergent norecurse "frame-pointer"="all" "min-legal-vector-w
3636
!3 = !{i32 1, !"wchar_size", i32 4}
3737
!4 = !{i32 7, !"frame-pointer", i32 2}
3838

39-
; Note that work-group sizes are padded with 1's after being reversed.
4039
; CHECK-IR-DAG: ![[SGSizeMD0]] = !{i32 3}
41-
; CHECK-IR-DAG: ![[WGSizeMD0]] = !{i{{[0-9]+}} 1, i{{[0-9]+}} 1, i{{[0-9]+}} 1}
42-
; CHECK-IR-DAG: ![[WGSizeHintMD0]] = !{i{{[0-9]+}} 2, i{{[0-9]+}} 1, i{{[0-9]+}} 1}
43-
; CHECK-IR-DAG: ![[WGSizeMD1]] = !{i{{[0-9]+}} 5, i{{[0-9]+}} 4, i{{[0-9]+}} 1}
44-
; CHECK-IR-DAG: ![[WGSizeHintMD1]] = !{i{{[0-9]+}} 7, i{{[0-9]+}} 6, i{{[0-9]+}} 1}
40+
; CHECK-IR-DAG: ![[WGSizeMD0]] = !{i{{[0-9]+}} 1}
41+
; CHECK-IR-DAG: ![[WGSizeHintMD0]] = !{i{{[0-9]+}} 2}
42+
; CHECK-IR-DAG: ![[WGSizeMD1]] = !{i{{[0-9]+}} 5, i{{[0-9]+}} 4}
43+
; CHECK-IR-DAG: ![[WGSizeHintMD1]] = !{i{{[0-9]+}} 7, i{{[0-9]+}} 6}
4544
; CHECK-IR-DAG: ![[WGSizeMD2]] = !{i{{[0-9]+}} 10, i{{[0-9]+}} 9, i{{[0-9]+}} 8}
4645
; CHECK-IR-DAG: ![[WGSizeHintMD2]] = !{i{{[0-9]+}} 13, i{{[0-9]+}} 12, i{{[0-9]+}} 11}

llvm/tools/sycl-post-link/CompileTimePropertiesPass.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,6 @@ attributeToExecModeMetadata(Module &M, const Attribute &Attr) {
186186
MDVals.push_back(ConstantAsMetadata::get(Constant::getIntegerValue(
187187
SizeTTy, APInt(SizeTBitSize, ValStr, 10))));
188188

189-
// The SPIR-V translator expects 3 values, so we pad the remaining
190-
// dimensions with 1.
191-
for (size_t I = MDVals.size(); I < 3; ++I)
192-
MDVals.push_back(ConstantAsMetadata::get(
193-
Constant::getIntegerValue(SizeTTy, APInt(SizeTBitSize, 1))));
194-
195189
const char *MDName = (AttrKindStr == "sycl-work-group-size")
196190
? "reqd_work_group_size"
197191
: "work_group_size_hint";

llvm/tools/sycl-post-link/sycl-post-link.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,18 +302,17 @@ std::vector<StringRef> getKernelNamesUsingAssert(const Module &M) {
302302

303303
// Gets reqd_work_group_size information for function Func.
304304
std::vector<uint32_t> getKernelReqdWorkGroupSizeMetadata(const Function &Func) {
305-
auto *ReqdWorkGroupSizeMD = Func.getMetadata("reqd_work_group_size");
305+
MDNode *ReqdWorkGroupSizeMD = Func.getMetadata("reqd_work_group_size");
306306
if (!ReqdWorkGroupSizeMD)
307307
return {};
308-
// TODO: Remove 3-operand assumption when it is relaxed.
309-
assert(ReqdWorkGroupSizeMD->getNumOperands() == 3);
310-
uint32_t X = mdconst::extract<ConstantInt>(ReqdWorkGroupSizeMD->getOperand(0))
311-
->getZExtValue();
312-
uint32_t Y = mdconst::extract<ConstantInt>(ReqdWorkGroupSizeMD->getOperand(1))
313-
->getZExtValue();
314-
uint32_t Z = mdconst::extract<ConstantInt>(ReqdWorkGroupSizeMD->getOperand(2))
315-
->getZExtValue();
316-
return {X, Y, Z};
308+
size_t NumOperands = ReqdWorkGroupSizeMD->getNumOperands();
309+
assert(NumOperands >= 1 && NumOperands <= 3 &&
310+
"reqd_work_group_size does not have between 1 and 3 operands.");
311+
std::vector<uint32_t> OutVals;
312+
OutVals.reserve(NumOperands);
313+
for (const MDOperand &MDOp : ReqdWorkGroupSizeMD->operands())
314+
OutVals.push_back(mdconst::extract<ConstantInt>(MDOp)->getZExtValue());
315+
return OutVals;
317316
}
318317

319318
// Creates a filename based on current output filename, given extension,

0 commit comments

Comments
 (0)