Skip to content

Commit e257292

Browse files
authored
[SYCL][Matrix] Propagate constexpr matrix layout even with O0 (#16628)
Per the SPIR-V specification, the Layout of a matrix must be a constant instruction, i.e. a constexpr or a specialization constant. In the SYCL headers, however, the layout is passed as a parameter to the joint_matrix_load function, so even if the layout is a constant expression in the user's code, this cannot be proven to the compiler: constant propagation happens only after inlining, not in the AST. This means that at O0 the layout would remain a runtime variable in LLVM IR. The SYCL matrix layout is mapped to the SPIR-V matrix layout by the joint_matrix_layout_to_spv function. This patch adds a routine that finds calls to this function and replaces them with the constant it finds. To help this routine, the always_inline attribute was removed from the joint_matrix_layout_to_spv function. --------- Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
1 parent 004d6d9 commit e257292

File tree

3 files changed

+178
-3
lines changed

3 files changed

+178
-3
lines changed

llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace {
2121

2222
static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
2323
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";
24+
static constexpr char MATRIX_LAYOUT[] = "joint_matrix_layout_to_spv";
2425

2526
Type *getInnermostType(Type *Ty) {
2627
while (auto *ArrayTy = dyn_cast<ArrayType>(Ty))
@@ -184,17 +185,99 @@ bool transformAccessChain(Function *F) {
184185
}
185186
return ModuleChanged;
186187
}
188+
189+
// Returns the closest StoreInst that precedes \p Load inside Load's own basic
// block and whose pointer operand is exactly \p Ptr, or nullptr when no such
// store exists. Only the block containing \p Load is scanned, and pointers are
// compared by identity (no alias analysis, no look-through of casts).
StoreInst *findLastStoreBeforeLoad(Value *Ptr, Instruction *Load) {
  // getPrevNode() yields nullptr once the first instruction of the block has
  // been visited, which terminates the backwards walk.
  for (Instruction *Prev = Load->getPrevNode(); Prev != nullptr;
       Prev = Prev->getPrevNode()) {
    auto *Candidate = dyn_cast<StoreInst>(Prev);
    if (Candidate && Candidate->getPointerOperand() == Ptr)
      return Candidate;
  }
  return nullptr;
}
199+
200+
// Per SPIR-V specification Layout of a matrix must be a constant instruction
201+
// aka a constexpr or specialization constant. Meanwhile in SYCL headers
202+
// layout is passed as a parameter to joint_matrix_load function, so even if
203+
// that layout is a constant expression in the user's code - it's not possible
204+
// to prove that to the compiler, so constant propagation will happen only
205+
// after inlining, not in AST. That means, that with O0 layout would remain
206+
// to be a runtime variable in LLVM IR.
207+
// SYCL matrix layout is being mapped on SPIR-V matrix layout by
208+
// joint_matrix_layout_to_spv function. The following routine finds calls to
209+
// this function and replaces them with the found constant.
210+
// This function also cleans up code, that becomes dead. Pattern of the dead
211+
// code is stable, as user's code doesn't affect it.
212+
bool propagateConstexprLayout(Function *F) {
213+
llvm::SmallVector<Instruction *, 8> ToErase;
214+
for (auto I = F->user_begin(), E = F->user_end(); I != E;) {
215+
User *U = *I++;
216+
auto *CI = dyn_cast<CallInst>(U);
217+
if (!CI)
218+
continue;
219+
auto *Op = dyn_cast<Instruction>(CI->getArgOperand(0));
220+
if (!Op || !isa<LoadInst>(Op))
221+
continue;
222+
auto *Ptr = dyn_cast<Instruction>(cast<LoadInst>(Op)->getPointerOperand());
223+
if (!Ptr)
224+
continue;
225+
226+
ConstantInt *ConstLayout = nullptr;
227+
StoreInst *SI = findLastStoreBeforeLoad(Ptr, Op);
228+
if (!SI)
229+
continue;
230+
ConstLayout = dyn_cast<ConstantInt>(SI->getValueOperand());
231+
if (ConstLayout) {
232+
CI->replaceAllUsesWith(ConstLayout);
233+
ToErase.push_back(CI);
234+
ToErase.push_back(SI);
235+
ToErase.push_back(Op);
236+
ToErase.push_back(Ptr);
237+
if (auto *Cast = dyn_cast<AddrSpaceCastInst>(Ptr)) {
238+
auto *OrigPtr = Cast->getPointerOperand();
239+
if (auto *AI = dyn_cast<AllocaInst>(OrigPtr))
240+
ToErase.push_back(AI);
241+
}
242+
}
243+
}
244+
245+
// There are possible cases, when a single instruction result is used multiple
246+
// times. For this case we have to use a vector to store such instructions
247+
// and keep track if we have removed them before to avoid double free().
248+
SmallPtrSet<Instruction *, 8> Erased;
249+
for (Instruction *II : ToErase) {
250+
if (!II->use_empty())
251+
continue;
252+
if (Erased.contains(II))
253+
continue;
254+
II->dropAllReferences();
255+
II->eraseFromParent();
256+
Erased.insert(II);
257+
}
258+
return !ToErase.empty();
259+
}
187260
} // namespace
188261

189262
// Pass entry point.
//  * For the definition named MATRIX_LAYOUT, calls to it are folded to
//    constants by propagateConstexprLayout(), and the function itself is
//    queued for removal.
//  * All other definitions are skipped.
//  * Declarations whose name starts with ACCESS_CHAIN are handed to
//    transformAccessChain().
// Returns PreservedAnalyses::none() if the module was modified in any way,
// PreservedAnalyses::all() otherwise.
PreservedAnalyses
SYCLJointMatrixTransformPass::run(Module &M, ModuleAnalysisManager &MAM) {
  bool ModuleChanged = false;
  llvm::SmallVector<Function *, 1> ToErase;
  for (Function &F : M) {
    if (!F.isDeclaration()) {
      // The layout helper is the only function definition this pass handles.
      if (F.getName() != MATRIX_LAYOUT)
        continue;
      ModuleChanged |= propagateConstexprLayout(&F);
      ToErase.push_back(&F);
    }
    if (F.getName().starts_with(ACCESS_CHAIN))
      ModuleChanged |= transformAccessChain(&F);
  }

  // Functions are erased only after the walk over the module has finished so
  // that the function list is not modified while it is being iterated.
  for (auto *F : ToErase) {
    if (!F->users().empty())
      continue;
    F->eraseFromParent();
    // Removing a function changes the module even when no call was rewritten.
    ModuleChanged = true;
  }

  return ModuleChanged ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
; The test checks that uses of the result of a joint_matrix_layout_to_spv call
2+
; are replaced with the layout constant.
3+
4+
; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s
5+
6+
; ModuleID = 'test.bc'
7+
source_filename = "test.cpp"
8+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
9+
target triple = "spir64-unknown-unknown"
10+
11+
$joint_matrix_layout_to_spv = comdat any
12+
13+
; CHECK: define weak_odr dso_local spir_kernel void @test
14+
; CHECK-NEXT: entry:
15+
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 0, i64 noundef{{.*}}
16+
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 1, i64 noundef{{.*}}
17+
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 1, i64 noundef{{.*}}
18+
; CHECK-NEXT: %{{.*}} = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHR{{.*}}(ptr addrspace(1){{.*}}, i32 noundef 2, i64 noundef{{.*}}
19+
; CHECK-NEXT: ret void
20+
21+
; CHECK-NOT: joint_matrix_layout_to_spv
22+
23+
; Kernel that feeds four cooperative matrix loads with a layout obtained via
; @joint_matrix_layout_to_spv. Each layout value is stored to a private slot
; and loaded back before the call, so the constant is not directly visible at
; the call site.
define weak_odr dso_local spir_kernel void @test(ptr addrspace(1) %matrix.1, ptr addrspace(1) %matrix.2, i64 noundef %stride) {
24+
entry:
25+
%layout.1 = alloca i32, align 4
26+
%layout.2 = alloca i32, align 4
27+
%layout.ascast.1 = addrspacecast ptr %layout.1 to ptr addrspace(4)
28+
%layout.ascast.2 = addrspacecast ptr %layout.2 to ptr addrspace(4)
29+
; Slot 1 holds layout 0; slot 2 initially holds layout 1.
store i32 0, ptr addrspace(4) %layout.ascast.1, align 4
30+
store i32 1, ptr addrspace(4) %layout.ascast.2, align 4
31+
32+
; Case 1: one store, one load, one call; the load is expected to receive 0.
%layout.val.1 = load i32, ptr addrspace(4) %layout.ascast.1, align 4
33+
%layout.spv.1 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.1)
34+
%mload.1 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.1, i32 noundef %layout.spv.1, i64 noundef %stride, i32 noundef 0)
35+
36+
; Case 2: the same loaded value (%layout.val.2) is passed to two separate
; calls (%layout.spv.2 and %layout.spv.3); both are expected to become 1.
%layout.val.2 = load i32, ptr addrspace(4) %layout.ascast.2, align 4
37+
%layout.spv.2 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.2)
38+
%mload.2 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.2, i64 noundef %stride, i32 noundef 0)
39+
40+
%layout.spv.3 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.2)
41+
%mload.3 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.3, i64 noundef %stride, i32 noundef 0)
42+
43+
; Case 3: slot 2 is overwritten with 2 before being loaded again; the last
; matrix load is expected to receive 2, not the earlier value 1.
store i32 2, ptr addrspace(4) %layout.ascast.2, align 4
44+
%layout.val.4 = load i32, ptr addrspace(4) %layout.ascast.2, align 4
45+
%layout.spv.4 = call spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %layout.val.4)
46+
%mload.4 = call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef %matrix.2, i32 noundef %layout.spv.4, i64 noundef %stride, i32 noundef 0)
47+
ret void
48+
}
49+
50+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 16, 16, 2) @_Z32__spirv_CooperativeMatrixLoadKHRIU3AS1ffLm16ELm16ELN5__spv9MatrixUseE2ELNS1_12MatrixLayoutE3ELNS1_5Scope4FlagE3EEPNS1_28__spirv_CooperativeMatrixKHRIT0_XT5_EXT1_EXT2_EXT3_EEEPT_S3_mi(ptr addrspace(1) noundef, i32 noundef, i64 noundef, i32 noundef)
51+
52+
; Unoptimized definition of the layout mapping helper. It spills %Layout to a
; private slot, switches on the reloaded value and returns 0, 1, 2 or 3 for the
; inputs 0, 1, 2 or 3 respectively; any other input reaches @llvm.trap.
; The test expects this definition to be absent from the output
; (see the CHECK-NOT directive above).
define linkonce_odr dso_local spir_func noundef i32 @joint_matrix_layout_to_spv(i32 noundef %Layout) comdat {
53+
entry:
54+
%retval = alloca i32, align 4
55+
%Layout.addr = alloca i32, align 4
56+
%retval.ascast = addrspacecast ptr %retval to ptr addrspace(4)
57+
%Layout.addr.ascast = addrspacecast ptr %Layout.addr to ptr addrspace(4)
58+
store i32 %Layout, ptr addrspace(4) %Layout.addr.ascast, align 4
59+
%0 = load i32, ptr addrspace(4) %Layout.addr.ascast, align 4
60+
; Dispatch on the incoming layout value; the default destination traps.
switch i32 %0, label %sw.epilog [
61+
i32 0, label %sw.bb
62+
i32 1, label %sw.bb1
63+
i32 2, label %sw.bb2
64+
i32 3, label %sw.bb3
65+
]
66+
67+
sw.bb: ; preds = %entry
68+
store i32 0, ptr addrspace(4) %retval.ascast, align 4
69+
br label %return
70+
71+
sw.bb1: ; preds = %entry
72+
store i32 1, ptr addrspace(4) %retval.ascast, align 4
73+
br label %return
74+
75+
sw.bb2: ; preds = %entry
76+
store i32 2, ptr addrspace(4) %retval.ascast, align 4
77+
br label %return
78+
79+
sw.bb3: ; preds = %entry
80+
store i32 3, ptr addrspace(4) %retval.ascast, align 4
81+
br label %return
82+
83+
; Reached only for values outside 0..3.
sw.epilog: ; preds = %entry
84+
call void @llvm.trap()
85+
unreachable
86+
87+
; Common exit: return the value stored by the selected case.
return: ; preds = %sw.bb3, %sw.bb2, %sw.bb1, %sw.bb
88+
%1 = load i32, ptr addrspace(4) %retval.ascast, align 4
89+
ret i32 %1
90+
}

sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ convertMatrixUseStringToEnum(const char *UseString) {
6969
return std::nullopt;
7070
}
7171

72-
inline __SYCL_ALWAYS_INLINE __spv::MatrixLayout joint_matrix_layout_to_spv(
72+
// propagateConstexprLayout matches this function by its exact symbol name, so
73+
// it is declared extern "C" here to keep that name unmangled.
74+
extern "C" constexpr __spv::MatrixLayout joint_matrix_layout_to_spv(
7375
sycl::ext::oneapi::experimental::matrix::layout Layout) {
7476
switch (Layout) {
7577
case sycl::ext::oneapi::experimental::matrix::layout::row_major:

0 commit comments

Comments
 (0)