Skip to content

Commit 8471ff3

Browse files
authored
[SYCL] Add LowerInvokeSimd LLVM pass to lower invoke_simd for BEs' consumption. (#5864)
* [SYCL] Add LowerInvokeSimd LLVM pass to lower invoke_simd for BEs' consumption. Signed-off-by: Konstantin S Bobrovsky <konstantin.s.bobrovsky@intel.com>
1 parent b19e2e4 commit 8471ff3

File tree

8 files changed

+362
-0
lines changed

8 files changed

+362
-0
lines changed

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ void initializeStripSymbolsPass(PassRegistry&);
443443
void initializeStructurizeCFGLegacyPassPass(PassRegistry &);
444444
void initializeSYCLLowerWGScopeLegacyPassPass(PassRegistry &);
445445
void initializeSYCLLowerESIMDLegacyPassPass(PassRegistry &);
446+
void initializeSYCLLowerInvokeSimdLegacyPassPass(PassRegistry &);
446447
void initializeSYCLMutatePrintfAddrspaceLegacyPassPass(PassRegistry &);
447448
void initializeSPIRITTAnnotationsLegacyPassPass(PassRegistry &);
448449
void initializeESIMDLowerLoadStorePass(PassRegistry &);

llvm/include/llvm/LinkAllPasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "llvm/IR/Function.h"
3939
#include "llvm/IR/IRPrintingPasses.h"
4040
#include "llvm/SYCLLowerIR/ESIMD/ESIMDVerifier.h"
41+
#include "llvm/SYCLLowerIR/LowerInvokeSimd.h"
4142
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
4243
#include "llvm/Support/Valgrind.h"
4344
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
@@ -210,6 +211,7 @@ namespace {
210211
(void) llvm::createExpandMemCmpPass();
211212
(void) llvm::createExpandVectorPredicationPass();
212213
(void)llvm::createESIMDVerifierPass();
214+
(void)llvm::createSYCLLowerInvokeSimdPass();
213215
std::string buf;
214216
llvm::raw_string_ostream os(buf);
215217
(void) llvm::createPrintModulePass(os);
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===---- LowerInvokeSimd.h - lower invoke_simd calls ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass brings __builtin_invoke_simd intrinsic calls to the form consumable
// by the back ends:
// - determines the "invokee" (call target), i.e. the actual function address
//   link-time constant (in the input IR it can be represented as an SSA value)
// See more comments in the implementation.
//===----------------------------------------------------------------------===//

// Include guard: this header defines a class, so including it more than once
// in the same translation unit must be harmless.
#ifndef LLVM_SYCLLOWERIR_LOWERINVOKESIMD_H
#define LLVM_SYCLLOWERIR_LOWERINVOKESIMD_H

#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"

namespace llvm {
/// New pass manager entry point of the invoke_simd lowering.
class SYCLLowerInvokeSimdPass : public PassInfoMixin<SYCLLowerInvokeSimdPass> {
public:
  /// Lowers the invoke_simd calls found in \p M and reports which analyses
  /// are preserved.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
};

/// Creates the legacy pass manager wrapper of the lowering.
ModulePass *createSYCLLowerInvokeSimdPass();
void initializeSYCLLowerInvokeSimdLegacyPassPass(PassRegistry &);
} // namespace llvm

#endif // LLVM_SYCLLOWERIR_LOWERINVOKESIMD_H

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
#include "llvm/IR/Verifier.h"
8080
#include "llvm/SYCLLowerIR/ESIMD/ESIMDVerifier.h"
8181
#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"
82+
#include "llvm/SYCLLowerIR/LowerInvokeSimd.h"
8283
#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
8384
#include "llvm/SYCLLowerIR/LowerWGScope.h"
8485
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"

llvm/lib/SYCLLowerIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
5353
ESIMD/LowerESIMDVLoadVStore.cpp
5454
ESIMD/LowerESIMDVecArg.cpp
5555
ESIMD/ESIMDVerifier.cpp
56+
LowerInvokeSimd.cpp
5657
LowerWGScope.cpp
5758
LowerWGLocalMemory.cpp
5859
MutatePrintfAddrspace.cpp
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
//===---- LowerInvokeSimd.cpp - lower invoke_simd calls -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Finds and lowers __builtin_invoke_simd calls generated by invoke_simd library
9+
// implementation:
10+
// - Performs data flow analysis for the first argument to determine which
11+
// function link-time constant address it is guaranteed to hold, and replaces
12+
// the argument with found "target" function.
13+
// - Marks target functions with VCStackCall attribute as required by the Intel
14+
// GPU backend.
15+
// TODO:
16+
// - move VCStackCall markup to Intel GPU-specific part (BE) or design a new
17+
// target-neutral attribute for markup.
18+
// - once the BE is ready, allow "unknown" function pointers, i.e. those whose
19+
// actual function address is not deducible by this pass
20+
21+
#include "llvm/SYCLLowerIR/LowerInvokeSimd.h"
22+
23+
#include "llvm/ADT/SmallPtrSet.h"
24+
#include "llvm/GenXIntrinsics/GenXMetadata.h"
25+
#include "llvm/IR/Instructions.h"
26+
#include "llvm/IR/Module.h"
27+
#include "llvm/IR/Operator.h"
28+
#include "llvm/Pass.h"
29+
30+
#define DEBUG_TYPE "LowerInvokeSimd"
31+
32+
using namespace llvm;
33+
34+
namespace {
35+
class SYCLLowerInvokeSimdLegacyPass : public ModulePass {
36+
public:
37+
static char ID; // Pass identification, replacement for typeid
38+
SYCLLowerInvokeSimdLegacyPass() : ModulePass(ID) {
39+
initializeSYCLLowerInvokeSimdLegacyPassPass(
40+
*PassRegistry::getPassRegistry());
41+
}
42+
43+
// run the LowerESIMD pass on the specified module
44+
bool runOnModule(Module &M) override {
45+
ModuleAnalysisManager MAM;
46+
auto PA = Impl.run(M, MAM);
47+
return !PA.areAllPreserved();
48+
}
49+
50+
private:
51+
SYCLLowerInvokeSimdPass Impl;
52+
};
53+
} // namespace
54+
55+
// Storage for the legacy pass identifier declared in the class.
char SYCLLowerInvokeSimdLegacyPass::ID = 0;
56+
// Registers the legacy pass under the "SYCLLowerInvokeSimd" name, which is the
// option the lit test passes to opt (-SYCLLowerInvokeSimd).
// NOTE(review): the two trailing 'false' flags are presumably "CFG-only" and
// "is analysis" - confirm against the INITIALIZE_PASS macro definition.
INITIALIZE_PASS(SYCLLowerInvokeSimdLegacyPass, "SYCLLowerInvokeSimd",
57+
"Lower SYCL's invoke_simd calls", false, false)
58+
59+
// Public interface to the LowerInvokeSimdPass: returns a newly allocated
// instance of the legacy pass manager wrapper.
// NOTE(review): presumably the pass manager the result is added to takes
// ownership of it - confirm at the call sites.
60+
ModulePass *llvm::createSYCLLowerInvokeSimdPass() {
61+
return new SYCLLowerInvokeSimdLegacyPass();
62+
}
63+
64+
namespace {
65+
// TODO support lambda and functor overloads
66+
// This is the common prefix of the mangled names generated from
67+
// sycl/ext/oneapi/experimental/invoke_simd.hpp::__builtin_invoke_simd
68+
// overload instantiations:
69+
constexpr char INVOKE_SIMD_PREF[] = "_Z33__regcall3____builtin_invoke_simd";
70+
71+
bool isCast(const Value *V) {
72+
int Opc = Operator::getOpcode(V);
73+
return (Opc == Instruction::BitCast) || (Opc == Instruction::AddrSpaceCast);
74+
}
75+
76+
using ValueSetImpl = SmallPtrSetImpl<Value *>;
77+
using ValueSet = SmallPtrSet<Value *, 4>;
78+
using ConstValueSetImpl = SmallPtrSetImpl<const Value *>;
79+
using ConstValueSet = SmallPtrSet<const Value *, 4>;
80+
81+
// Looks through any chain of bitcasts/address space casts and returns the
// value the chain starts from. Values of non-pointer (and non-pointer-vector)
// type are returned as is.
Value *stripCasts(Value *V) {
  if (!V->getType()->isPtrOrPtrVectorTy())
    return V;
  // Even though we don't look through PHI nodes, we could be called on an
  // instruction in an unreachable block, which may be on a cycle. Remember
  // every value seen so far to guarantee termination.
  ConstValueSet Seen;
  Seen.insert(V);

  while (isCast(V)) {
    V = cast<Operator>(V)->getOperand(0);
    assert(V->getType()->isPtrOrPtrVectorTy() && "Unexpected operand type!");

    if (!Seen.insert(V).second) {
      // Already visited - this is a cycle of casts, stop here.
      break;
    }
  }
  return V;
}
}
97+
98+
// Follows a chain of casts starting at V "downwards" (from a cast to its
// user) for as long as every cast in the chain has exactly one use. Returns
// the first non-cast value reached (V itself if it is not a cast), or nullptr
// if some cast in the chain does not have exactly one use.
const Value *getSingleUserSkipCasts(const Value *V) {
  for (; isCast(V); V = *V->user_begin()) {
    if (!V->hasOneUse())
      return nullptr;
  }
  return V;
}
107+
108+
// Adds to Uses all uses of V whose user is not a cast. Where the user is a
// cast, the uses of that cast are collected (recursively) instead.
void collectUsesSkipThroughCasts(Value *V, SmallPtrSetImpl<const Use *> &Uses) {
  for (Use &TheUse : V->uses()) {
    Value *Consumer = TheUse.getUser();

    if (!isCast(Consumer)) {
      Uses.insert(&TheUse);
      continue;
    }
    // Casts are transparent: descend into the users of the cast.
    collectUsesSkipThroughCasts(Consumer, Uses);
  }
}
119+
120+
// If CI directly calls a function whose name starts with INVOKE_SIMD_PREF,
// returns the first argument of the call (the invokee); otherwise returns
// nullptr.
Value *getInvokeeIfInvokeSimdCall(const CallInst *CI) {
  const Function *Callee = CI->getCalledFunction();
  const bool IsInvokeSimd =
      Callee && Callee->getName().startswith(INVOKE_SIMD_PREF);
  return IsInvokeSimd ? CI->getArgOperand(0) : nullptr;
}
128+
129+
void getPossibleStoredVals(Value *Addr, ValueSetImpl &Vals) {
130+
ValueSet Visited;
131+
AllocaInst *LocalVar = dyn_cast_or_null<AllocaInst>(stripCasts(Addr));
132+
133+
if (!LocalVar) {
134+
llvm_unreachable("unsupported data flow pattern for invoke_simd 10");
135+
}
136+
SmallPtrSet<const Use *, 4> Uses;
137+
collectUsesSkipThroughCasts(LocalVar, Uses);
138+
139+
for (const Use *U : Uses) {
140+
Value *V = U->getUser();
141+
142+
if (auto *StI = dyn_cast<StoreInst>(V)) {
143+
constexpr int StoreInstValueOperandIndex = 0;
144+
145+
if (U != &StI->getOperandUse(StoreInst::getPointerOperandIndex())) {
146+
assert(U == &StI->getOperandUse(StoreInstValueOperandIndex));
147+
// this is double indirection - not supported
148+
llvm_unreachable("unsupported data flow pattern for invoke_simd 11");
149+
}
150+
V = stripCasts(StI->getValueOperand());
151+
152+
if (auto *LI = dyn_cast<LoadInst>(V)) {
153+
// A value loaded from another address is stored at this address -
154+
// recurse into the other address
155+
getPossibleStoredVals(LI->getPointerOperand(), Vals);
156+
} else {
157+
Vals.insert(V);
158+
}
159+
continue;
160+
}
161+
if (const auto *CI = dyn_cast<CallInst>(V)) {
162+
// only __builtin_invoke_simd is allowed, otherwise the pointer escapes
163+
if (!getInvokeeIfInvokeSimdCall(CI)) {
164+
llvm_unreachable("unsupported data flow pattern for invoke_simd 12");
165+
}
166+
continue;
167+
}
168+
if (const auto *LI = dyn_cast<LoadInst>(V)) {
169+
// LoadInst from this addr is OK, as it does not affect what can be stored
170+
// through the addr
171+
continue;
172+
}
173+
llvm_unreachable("unsupported data flow pattern for invoke_simd 13");
174+
}
175+
}
176+
177+
// Example1 (function is direct argument to
178+
// _Z33__regcall3____builtin_invoke_simd):
179+
// %call6.i = call spir_func float @_Z33__regcall3____builtin_invoke_simd...(
180+
// <16 x float> (float addrspace(4)*, <16 x float>, i32)* %28, <== function
181+
// pointer float addrspace(4)* %arg1, float %arg2, i32 %arg3)
182+
//
183+
// Example 2 (invoke_simd's target function pointer flows through IR):
184+
// %fptr_t = <16 x float> (float addrspace(4)*, <16 x float>, i32)*
185+
// ...
186+
// %fa_as0 = alloca %fptr_t
187+
// ...
188+
// %fa = addrspacecast %fptr_t* %fa_as0 to %fptr_t addrspace(4)*
189+
// ...
190+
// store %fptr_t @__SIMD_CALLEE, %fptr_t addrspace(4)* %fa
191+
// ...
192+
// %f = load %fptr_t, %fptr_t addrspace(4)* %fa
193+
// ...
194+
// %res = call spir_func float @_Z33__regcall3____builtin_invoke_simd...(
195+
// %fptr_t %f, <== function pointer
196+
// float addrspace(4)* %arg1,
197+
// float %arg2,
198+
// i32 %arg3)
199+
//
200+
bool processInvokeSimdCall(CallInst *CI) {
201+
Value *V = getInvokeeIfInvokeSimdCall(CI);
202+
203+
if (!V) {
204+
llvm_unreachable(("bad use of " + Twine(INVOKE_SIMD_PREF)).str().c_str());
205+
}
206+
auto *SimdF = dyn_cast<Function>(V);
207+
bool Modified = false;
208+
209+
if (!SimdF) {
210+
auto *LI = dyn_cast<LoadInst>(stripCasts(V));
211+
212+
if (!LI) {
213+
llvm_unreachable("unsupported data flow pattern for invoke_simd 0");
214+
}
215+
ValueSet Vals;
216+
getPossibleStoredVals(LI->getPointerOperand(), Vals);
217+
218+
if (Vals.size() != 1 || !(SimdF = dyn_cast<Function>(*Vals.begin()))) {
219+
llvm_unreachable("unsupported data flow pattern for invoke_simd 1");
220+
}
221+
// _Z33__regcall3____builtin_invoke_simd invokee is an SSA value, replace it
222+
// with the link-time constant SimdF as computed by getPossibleStoredVals
223+
auto *CI1 = cast<CallInst>(CI->clone());
224+
constexpr int SimdInvokeInvokeeArgIndex = 0;
225+
CI1->setOperand(SimdInvokeInvokeeArgIndex, SimdF);
226+
CI1->insertAfter(CI);
227+
CI->replaceAllUsesWith(CI1);
228+
CI->eraseFromParent();
229+
Modified = true;
230+
}
231+
if (!SimdF->hasFnAttribute(llvm::genx::VCFunctionMD::VCStackCall)) {
232+
SimdF->addFnAttr(llvm::genx::VCFunctionMD::VCStackCall);
233+
}
234+
return Modified;
235+
}
236+
} // namespace
237+
238+
namespace llvm {
239+
// Processes every user of every function declaration whose name starts with
// INVOKE_SIMD_PREF. Returns PreservedAnalyses::none() if processing of at
// least one call reported a change, PreservedAnalyses::all() otherwise.
PreservedAnalyses SYCLLowerInvokeSimdPass::run(Module &M,
                                               ModuleAnalysisManager &MAM) {
  bool Changed = false;

  for (Function &F : M) {
    const bool IsInvokeSimdDecl =
        F.isDeclaration() && F.getName().startswith(INVOKE_SIMD_PREF);

    if (!IsInvokeSimdDecl) {
      continue;
    }
    // Processing a call may modify the IR, so iterate over a snapshot of the
    // users rather than over the live user list.
    SmallVector<User *, 4> Callers(F.users());

    for (User *Caller : Callers) {
      // a call can be the only use of the invoke_simd built-in
      if (processInvokeSimdCall(cast<CallInst>(Caller))) {
        Changed = true;
      }
    }
  }
  if (!Changed) {
    return PreservedAnalyses::all();
  }
  return PreservedAnalyses::none();
}
256+
} // namespace llvm
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
; RUN: opt -SYCLLowerInvokeSimd -S < %s | FileCheck %s
2+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
3+
target triple = "spir64-unknown-unknown"
4+
5+
; Function Attrs: convergent
6+
declare dso_local spir_func <16 x float> @__dummy_read(i64) #4
7+
8+
; Function Attrs: convergent
9+
declare dso_local spir_func float @_Z33__regcall3____builtin_invoke_simdXX(<16 x float> (float addrspace(4)*, <16 x float>, i32)*, float addrspace(4)*, float, i32) local_unnamed_addr #3
10+
11+
; Function Attrs: convergent mustprogress noinline norecurse optnone
12+
define dso_local x86_regcallcc <16 x float> @_SIMD_CALLEE(float addrspace(4)* %A, <16 x float> %non_uni_val, i32 %uni_val) #0 !sycl_explicit_simd !0 !intel_reqd_sub_group_size !0 {
13+
; Verify that correct attributes are attached to the function:
14+
; CHECK: {{.*}} @_SIMD_CALLEE(float addrspace(4)* %A, <16 x float> %non_uni_val, i32 %uni_val) #0
15+
entry:
16+
%AA = ptrtoint float addrspace(4)* %A to i64
17+
%ii = zext i32 %uni_val to i64
18+
%addr = add nuw nsw i64 %ii, %AA
19+
%data = call spir_func <16 x float> @__dummy_read(i64 %addr)
20+
%add = fadd <16 x float> %non_uni_val, %data
21+
ret <16 x float> %add
22+
}
23+
24+
; Function Attrs: convergent mustprogress noinline norecurse optnone
25+
define dso_local x86_regcallcc <16 x float> @_ANOTHER_SIMD_CALLEE(float addrspace(4)* %A, <16 x float> %non_uni_val, i32 %uni_val) #1 !sycl_explicit_simd !0 !intel_reqd_sub_group_size !0 {
26+
; Verify that correct attributes are attached to the function:
27+
; CHECK: {{.*}} @_ANOTHER_SIMD_CALLEE(float addrspace(4)* %A, <16 x float> %non_uni_val, i32 %uni_val) #1
28+
entry:
29+
%AA = ptrtoint float addrspace(4)* %A to i64
30+
%ii = zext i32 %uni_val to i64
31+
%addr = add nuw nsw i64 %ii, %AA
32+
%data = call spir_func <16 x float> @__dummy_read(i64 %addr)
33+
%add = fadd <16 x float> %non_uni_val, %data
34+
ret <16 x float> %add
35+
}
36+
37+
define internal spir_func float @foo(float addrspace(1)* %ptr, <16 x float> (float addrspace(4)*, <16 x float>, i32)* %raw_fptr) align 2 {
38+
entry:
39+
;------------- Typical data flow of the @_SIMD_CALLEE function address in worst
40+
;------------- case (-O0), when invoke_simd uses function name:
41+
;------------- float res = invoke_simd(sg, SIMD_CALLEE, uniform{ A }, x, uniform{ y });
42+
%f.addr.i = alloca <16 x float> (float addrspace(4)*, <16 x float>, i32)*, align 8
43+
%f.addr.ascast.i = addrspacecast <16 x float> (float addrspace(4)*, <16 x float>, i32)** %f.addr.i to <16 x float> (float addrspace(4)*, <16 x float>, i32)* addrspace(4)*
44+
store <16 x float> (float addrspace(4)*, <16 x float>, i32)* @_SIMD_CALLEE, <16 x float> (float addrspace(4)*, <16 x float>, i32)* addrspace(4)* %f.addr.ascast.i, align 8
45+
%FUNC_PTR = load <16 x float> (float addrspace(4)*, <16 x float>, i32)*, <16 x float> (float addrspace(4)*, <16 x float>, i32)* addrspace(4)* %f.addr.ascast.i, align 8
46+
47+
;------------- Data flow for the parameters of SIMD_CALLEE
48+
%param_A = addrspacecast float addrspace(1)* %ptr to float addrspace(4)*
49+
%param_non_uni_val = load float, float addrspace(4)* %param_A, align 4
50+
51+
;------------- The invoke_simd calls.
52+
%res1 = call spir_func float @_Z33__regcall3____builtin_invoke_simdXX(<16 x float> (float addrspace(4)*, <16 x float>, i32)* %FUNC_PTR, float addrspace(4)* %param_A, float %param_non_uni_val, i32 10)
53+
; Verify that %FUNC_PTR is replaced with @_SIMD_CALLEE:
54+
; CHECK: %{{.*}} = call spir_func float @_Z33__regcall3____builtin_invoke_simdXX(<16 x float> (float addrspace(4)*, <16 x float>, i32)* @_SIMD_CALLEE, float addrspace(4)* %param_A, float %param_non_uni_val, i32 10)
55+
56+
%res2 = call spir_func float @_Z33__regcall3____builtin_invoke_simdXX(<16 x float> (float addrspace(4)*, <16 x float>, i32)* @_ANOTHER_SIMD_CALLEE, float addrspace(4)* %param_A, float %param_non_uni_val, i32 10)
57+
; Verify that function address link-time constant is accepted by the pass and left as is:
58+
; CHECK: = call spir_func float @_Z33__regcall3____builtin_invoke_simdXX(<16 x float> (float addrspace(4)*, <16 x float>, i32)* @_ANOTHER_SIMD_CALLEE, float addrspace(4)* %param_A, float %param_non_uni_val, i32 10)
59+
60+
; TODO: enable in the test and LowerInvokeSimd when BE is ready, crash for now:
61+
;%res3 %{{.*}} = call spir_func float @_Z33__regcall3____builtin_invoke_simdXX(<16 x float> (float addrspace(4)*, <16 x float>, i32)* %raw_fptr, float addrspace(4)* %param_A, float %param_non_uni_val, i32 10)
62+
%res = fadd float %res1, %res2
63+
ret float %res
64+
}
65+
66+
; Check that VCStackCall attribute is added to the invoke_simd target functions:
67+
attributes #0 = { convergent mustprogress norecurse "frame-pointer"="all" "min-legal-vector-width"="512" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="invoke_simd.cpp" }
68+
; CHECK: attributes #0 = { convergent mustprogress norecurse "VCStackCall" {{.*}} "sycl-module-id"="invoke_simd.cpp" }
69+
attributes #1 = { convergent mustprogress norecurse "sycl-module-id"="invoke_simd.cpp" }
70+
; CHECK: attributes #1 = { convergent mustprogress norecurse "VCStackCall" "sycl-module-id"="invoke_simd.cpp" }
71+
72+
!0 = !{}
73+
!1 = !{i32 16}

0 commit comments

Comments
 (0)