|
| 1 | +//===---- LowerInvokeSimd.cpp - lower invoke_simd calls -------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// Finds and lowers __builtin_invoke_simd calls generated by invoke_simd library |
| 9 | +// implementation: |
| 10 | +// - Performs data flow analysis for the first argument to determine which |
| 11 | +// function link-time constant address it is guaranteed to hold, and replaces |
| 12 | +// the argument with found "target" function. |
| 13 | +// - Marks target functions with VCStackCall attribute as required by the Intel |
| 14 | +// GPU backend. |
| 15 | +// TODO: |
| 16 | +// - move VCStackCall markup to Intel GPU-specific part (BE) or design a new |
| 17 | +// target-neutral attribute for markup. |
| 18 | +// - allow "unknown" function pointers, where actual function's address is not |
| 19 | +// deducible when BE is ready |
| 20 | + |
| 21 | +#include "llvm/SYCLLowerIR/LowerInvokeSimd.h" |
| 22 | + |
| 23 | +#include "llvm/ADT/SmallPtrSet.h" |
| 24 | +#include "llvm/GenXIntrinsics/GenXMetadata.h" |
| 25 | +#include "llvm/IR/Instructions.h" |
| 26 | +#include "llvm/IR/Module.h" |
| 27 | +#include "llvm/IR/Operator.h" |
| 28 | +#include "llvm/Pass.h" |
| 29 | + |
| 30 | +#define DEBUG_TYPE "LowerInvokeSimd" |
| 31 | + |
| 32 | +using namespace llvm; |
| 33 | + |
| 34 | +namespace { |
| 35 | +class SYCLLowerInvokeSimdLegacyPass : public ModulePass { |
| 36 | +public: |
| 37 | + static char ID; // Pass identification, replacement for typeid |
| 38 | + SYCLLowerInvokeSimdLegacyPass() : ModulePass(ID) { |
| 39 | + initializeSYCLLowerInvokeSimdLegacyPassPass( |
| 40 | + *PassRegistry::getPassRegistry()); |
| 41 | + } |
| 42 | + |
| 43 | + // run the LowerESIMD pass on the specified module |
| 44 | + bool runOnModule(Module &M) override { |
| 45 | + ModuleAnalysisManager MAM; |
| 46 | + auto PA = Impl.run(M, MAM); |
| 47 | + return !PA.areAllPreserved(); |
| 48 | + } |
| 49 | + |
| 50 | +private: |
| 51 | + SYCLLowerInvokeSimdPass Impl; |
| 52 | +}; |
| 53 | +} // namespace |
| 54 | + |
// Definition of the static pass identifier declared in the class; it is the
// object passed to the ModulePass constructor.
char SYCLLowerInvokeSimdLegacyPass::ID = 0;
// Registers the legacy pass under the argument name "SYCLLowerInvokeSimd" with
// the description "Lower SYCL's invoke_simd calls".
INITIALIZE_PASS(SYCLLowerInvokeSimdLegacyPass, "SYCLLowerInvokeSimd",
                "Lower SYCL's invoke_simd calls", false, false)
| 58 | + |
// Public interface to the LowerInvokeSimdPass: returns a newly heap-allocated
// legacy pass manager wrapper. The raw pointer produced by `new` is handed to
// the caller.
ModulePass *llvm::createSYCLLowerInvokeSimdPass() {
  return new SYCLLowerInvokeSimdLegacyPass();
}
| 63 | + |
| 64 | +namespace { |
// TODO support lambda and functor overloads
// Common prefix of the mangled names generated for instantiations of the
// sycl/ext/oneapi/experimental/invoke_simd.hpp::__builtin_invoke_simd
// overloads. Functions are recognized as the invoke_simd built-in by matching
// the start of their name against this string.
constexpr char INVOKE_SIMD_PREF[] = "_Z33__regcall3____builtin_invoke_simd";
| 70 | + |
| 71 | +bool isCast(const Value *V) { |
| 72 | + int Opc = Operator::getOpcode(V); |
| 73 | + return (Opc == Instruction::BitCast) || (Opc == Instruction::AddrSpaceCast); |
| 74 | +} |
| 75 | + |
// Shorthand for the pointer-set types used by the data flow helpers in this
// file. The *Impl aliases name the SmallPtrSetImpl base and are used for
// reference parameters; the non-Impl aliases are the concrete sets declared as
// locals.
using ValueSetImpl = SmallPtrSetImpl<Value *>;
using ValueSet = SmallPtrSet<Value *, 4>;
using ConstValueSetImpl = SmallPtrSetImpl<const Value *>;
using ConstValueSet = SmallPtrSet<const Value *, 4>;
| 80 | + |
| 81 | +Value *stripCasts(Value *V) { |
| 82 | + if (!V->getType()->isPtrOrPtrVectorTy()) |
| 83 | + return V; |
| 84 | + // Even though we don't look through PHI nodes, we could be called on an |
| 85 | + // instruction in an unreachable block, which may be on a cycle. |
| 86 | + ConstValueSet Visited; |
| 87 | + Visited.insert(V); |
| 88 | + |
| 89 | + do { |
| 90 | + if (isCast(V)) { |
| 91 | + V = cast<Operator>(V)->getOperand(0); |
| 92 | + } |
| 93 | + assert(V->getType()->isPtrOrPtrVectorTy() && "Unexpected operand type!"); |
| 94 | + } while (Visited.insert(V).second); |
| 95 | + return V; |
| 96 | +} |
| 97 | + |
| 98 | +const Value *getSingleUserSkipCasts(const Value *V) { |
| 99 | + while (isCast(V)) { |
| 100 | + if (V->getNumUses() != 1) { |
| 101 | + return nullptr; |
| 102 | + } |
| 103 | + V = *(V->user_begin()); |
| 104 | + } |
| 105 | + return V; |
| 106 | +} |
| 107 | + |
| 108 | +void collectUsesSkipThroughCasts(Value *V, SmallPtrSetImpl<const Use *> &Uses) { |
| 109 | + for (Use &U : V->uses()) { |
| 110 | + Value *VV = U.getUser(); |
| 111 | + |
| 112 | + if (isCast(VV)) { |
| 113 | + collectUsesSkipThroughCasts(VV, Uses); |
| 114 | + } else { |
| 115 | + Uses.insert(&U); |
| 116 | + } |
| 117 | + } |
| 118 | +} |
| 119 | + |
| 120 | +Value *getInvokeeIfInvokeSimdCall(const CallInst *CI) { |
| 121 | + Function *F = CI->getCalledFunction(); |
| 122 | + |
| 123 | + if (F && F->getName().startswith(INVOKE_SIMD_PREF)) { |
| 124 | + return CI->getArgOperand(0); |
| 125 | + } |
| 126 | + return nullptr; |
| 127 | +} |
| 128 | + |
| 129 | +void getPossibleStoredVals(Value *Addr, ValueSetImpl &Vals) { |
| 130 | + ValueSet Visited; |
| 131 | + AllocaInst *LocalVar = dyn_cast_or_null<AllocaInst>(stripCasts(Addr)); |
| 132 | + |
| 133 | + if (!LocalVar) { |
| 134 | + llvm_unreachable("unsupported data flow pattern for invoke_simd 10"); |
| 135 | + } |
| 136 | + SmallPtrSet<const Use *, 4> Uses; |
| 137 | + collectUsesSkipThroughCasts(LocalVar, Uses); |
| 138 | + |
| 139 | + for (const Use *U : Uses) { |
| 140 | + Value *V = U->getUser(); |
| 141 | + |
| 142 | + if (auto *StI = dyn_cast<StoreInst>(V)) { |
| 143 | + constexpr int StoreInstValueOperandIndex = 0; |
| 144 | + |
| 145 | + if (U != &StI->getOperandUse(StoreInst::getPointerOperandIndex())) { |
| 146 | + assert(U == &StI->getOperandUse(StoreInstValueOperandIndex)); |
| 147 | + // this is double indirection - not supported |
| 148 | + llvm_unreachable("unsupported data flow pattern for invoke_simd 11"); |
| 149 | + } |
| 150 | + V = stripCasts(StI->getValueOperand()); |
| 151 | + |
| 152 | + if (auto *LI = dyn_cast<LoadInst>(V)) { |
| 153 | + // A value loaded from another address is stored at this address - |
| 154 | + // recurse into the other address |
| 155 | + getPossibleStoredVals(LI->getPointerOperand(), Vals); |
| 156 | + } else { |
| 157 | + Vals.insert(V); |
| 158 | + } |
| 159 | + continue; |
| 160 | + } |
| 161 | + if (const auto *CI = dyn_cast<CallInst>(V)) { |
| 162 | + // only __builtin_invoke_simd is allowed, otherwise the pointer escapes |
| 163 | + if (!getInvokeeIfInvokeSimdCall(CI)) { |
| 164 | + llvm_unreachable("unsupported data flow pattern for invoke_simd 12"); |
| 165 | + } |
| 166 | + continue; |
| 167 | + } |
| 168 | + if (const auto *LI = dyn_cast<LoadInst>(V)) { |
| 169 | + // LoadInst from this addr is OK, as it does not affect what can be stored |
| 170 | + // through the addr |
| 171 | + continue; |
| 172 | + } |
| 173 | + llvm_unreachable("unsupported data flow pattern for invoke_simd 13"); |
| 174 | + } |
| 175 | +} |
| 176 | + |
| 177 | +// Example1 (function is direct argument to |
| 178 | +// _Z33__regcall3____builtin_invoke_simd): |
| 179 | +// %call6.i = call spir_func float @_Z33__regcall3____builtin_invoke_simd...( |
| 180 | +// <16 x float> (float addrspace(4)*, <16 x float>, i32)* %28, <== function |
| 181 | +// pointer float addrspace(4)* %arg1, float %arg2, i32 %arg3) |
| 182 | +// |
| 183 | +// Example 2 (invoke_simd's target function pointer flows through IR): |
| 184 | +// %fptr_t = <16 x float> (float addrspace(4)*, <16 x float>, i32)* |
| 185 | +// ... |
| 186 | +// %fa_as0 = alloca %fptr_t |
| 187 | +// ... |
| 188 | +// %fa = addrspacecast %fptr_t* %fa_as0 to %fptr_t addrspace(4)* |
| 189 | +// ... |
| 190 | +// store %fptr_t @__SIMD_CALLEE, %fptr_t addrspace(4)* %fa |
| 191 | +// ... |
| 192 | +// %f = load %fptr_t, %fptr_t addrspace(4)* %fa |
| 193 | +// ... |
| 194 | +// %res = call spir_func float @_Z33__regcall3____builtin_invoke_simd...( |
| 195 | +// %fptr_t %f, <== function pointer |
| 196 | +// float addrspace(4)* %arg1, |
| 197 | +// float %arg2, |
| 198 | +// i32 %arg3) |
| 199 | +// |
| 200 | +bool processInvokeSimdCall(CallInst *CI) { |
| 201 | + Value *V = getInvokeeIfInvokeSimdCall(CI); |
| 202 | + |
| 203 | + if (!V) { |
| 204 | + llvm_unreachable(("bad use of " + Twine(INVOKE_SIMD_PREF)).str().c_str()); |
| 205 | + } |
| 206 | + auto *SimdF = dyn_cast<Function>(V); |
| 207 | + bool Modified = false; |
| 208 | + |
| 209 | + if (!SimdF) { |
| 210 | + auto *LI = dyn_cast<LoadInst>(stripCasts(V)); |
| 211 | + |
| 212 | + if (!LI) { |
| 213 | + llvm_unreachable("unsupported data flow pattern for invoke_simd 0"); |
| 214 | + } |
| 215 | + ValueSet Vals; |
| 216 | + getPossibleStoredVals(LI->getPointerOperand(), Vals); |
| 217 | + |
| 218 | + if (Vals.size() != 1 || !(SimdF = dyn_cast<Function>(*Vals.begin()))) { |
| 219 | + llvm_unreachable("unsupported data flow pattern for invoke_simd 1"); |
| 220 | + } |
| 221 | + // _Z33__regcall3____builtin_invoke_simd invokee is an SSA value, replace it |
| 222 | + // with the link-time constant SimdF as computed by getPossibleStoredVals |
| 223 | + auto *CI1 = cast<CallInst>(CI->clone()); |
| 224 | + constexpr int SimdInvokeInvokeeArgIndex = 0; |
| 225 | + CI1->setOperand(SimdInvokeInvokeeArgIndex, SimdF); |
| 226 | + CI1->insertAfter(CI); |
| 227 | + CI->replaceAllUsesWith(CI1); |
| 228 | + CI->eraseFromParent(); |
| 229 | + Modified = true; |
| 230 | + } |
| 231 | + if (!SimdF->hasFnAttribute(llvm::genx::VCFunctionMD::VCStackCall)) { |
| 232 | + SimdF->addFnAttr(llvm::genx::VCFunctionMD::VCStackCall); |
| 233 | + } |
| 234 | + return Modified; |
| 235 | +} |
| 236 | +} // namespace |
| 237 | + |
| 238 | +namespace llvm { |
| 239 | +PreservedAnalyses SYCLLowerInvokeSimdPass::run(Module &M, |
| 240 | + ModuleAnalysisManager &MAM) { |
| 241 | + bool Modified = false; |
| 242 | + |
| 243 | + for (Function &F : M) { |
| 244 | + if (!F.isDeclaration() || !F.getName().startswith(INVOKE_SIMD_PREF)) { |
| 245 | + continue; |
| 246 | + } |
| 247 | + SmallVector<User *, 4> Users(F.users()); |
| 248 | + for (User *Usr : Users) { |
| 249 | + // a call can be the only use of the invoke_simd built-in |
| 250 | + CallInst *CI = cast<CallInst>(Usr); |
| 251 | + Modified |= processInvokeSimdCall(CI); |
| 252 | + } |
| 253 | + } |
| 254 | + return Modified ? PreservedAnalyses::none() : PreservedAnalyses::all(); |
| 255 | +} |
| 256 | +} // namespace llvm |
0 commit comments