Skip to content

Commit 461f493

Browse files
svenvhdwoodwor-intel
authored andcommitted
Move pass declarations into separate header files
Prepare for migrating SPIRVWriter to the new PassManager by making sure the new non-legacy pass declarations are visible outside of their .cpp file. Original commit: KhronosGroup/SPIRV-LLVM-Translator@337b171
1 parent 49e2d09 commit 461f493

17 files changed

+1351
-852
lines changed

llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 52 additions & 281 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,16 @@
3838
//===----------------------------------------------------------------------===//
3939
#define DEBUG_TYPE "ocl-to-spv"
4040

41+
#include "OCLToSPIRV.h"
4142
#include "OCLTypeToSPIRV.h"
42-
#include "OCLUtil.h"
4343
#include "SPIRVInternal.h"
4444
#include "libSPIRV/SPIRVDebug.h"
4545

4646
#include "llvm/ADT/StringSwitch.h"
4747
#include "llvm/Analysis/ValueTracking.h"
4848
#include "llvm/IR/IRBuilder.h"
49-
#include "llvm/IR/InstVisitor.h"
5049
#include "llvm/IR/Instruction.h"
5150
#include "llvm/IR/Instructions.h"
52-
#include "llvm/IR/PassManager.h"
53-
#include "llvm/Pass.h"
5451
#include "llvm/Support/Debug.h"
5552

5653
#include <algorithm>
@@ -87,290 +84,64 @@ static Type *getBlockStructType(Value *Parameter) {
8784
return ParamType;
8885
}
8986

90-
class OCLToSPIRVBase : public InstVisitor<OCLToSPIRVBase> {
91-
public:
92-
OCLToSPIRVBase() : M(nullptr), Ctx(nullptr), CLVer(0) {}
93-
virtual ~OCLToSPIRVBase() {}
94-
bool runOCLToSPIRV(Module &M);
95-
96-
virtual void visitCallInst(CallInst &CI);
97-
98-
/// Transform barrier/work_group_barrier/sub_group_barrier
99-
/// to __spirv_ControlBarrier.
100-
/// barrier(flag) =>
101-
/// __spirv_ControlBarrier(workgroup, workgroup, map(flag))
102-
/// work_group_barrier(scope, flag) =>
103-
/// __spirv_ControlBarrier(workgroup, map(scope), map(flag))
104-
/// sub_group_barrier(scope, flag) =>
105-
/// __spirv_ControlBarrier(subgroup, map(scope), map(flag))
106-
void visitCallBarrier(CallInst *CI);
107-
108-
/// Erase useless convert functions.
109-
/// \return true if the call instruction is erased.
110-
bool eraseUselessConvert(CallInst *Call, StringRef MangledName,
111-
StringRef DeMangledName);
112-
113-
/// Transform convert_ to
114-
/// __spirv_{CastOpName}_R{TargeTyName}{_sat}{_rt[p|n|z|e]}
115-
void visitCallConvert(CallInst *CI, StringRef MangledName,
116-
StringRef DemangledName);
117-
118-
/// Transform async_work_group{_strided}_copy.
119-
/// async_work_group_copy(dst, src, n, event)
120-
/// => async_work_group_strided_copy(dst, src, n, 1, event)
121-
/// async_work_group_strided_copy(dst, src, n, stride, event)
122-
/// => __spirv_AsyncGroupCopy(ScopeWorkGroup, dst, src, n, stride, event)
123-
void visitCallAsyncWorkGroupCopy(CallInst *CI, StringRef DemangledName);
124-
125-
/// Transform OCL builtin function to SPIR-V builtin function.
126-
void transBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info);
127-
128-
/// Transform atomic_work_item_fence/mem_fence to __spirv_MemoryBarrier.
129-
/// func(flag, order, scope) =>
130-
/// __spirv_MemoryBarrier(map(scope), map(flag)|map(order))
131-
void transMemoryBarrier(CallInst *CI, AtomicWorkItemFenceLiterals);
132-
133-
/// Transform all to __spirv_Op(All|Any). Note that the types mismatch so
134-
// some extra code is emitted to convert between the two.
135-
void visitCallAllAny(spv::Op OC, CallInst *CI);
136-
137-
/// Transform atomic_* to __spirv_Atomic*.
138-
/// atomic_x(ptr_arg, args, order, scope) =>
139-
/// __spirv_AtomicY(ptr_arg, map(order), map(scope), args)
140-
void transAtomicBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info);
141-
142-
/// Transform atomic_work_item_fence to __spirv_MemoryBarrier.
143-
/// atomic_work_item_fence(flag, order, scope) =>
144-
/// __spirv_MemoryBarrier(map(scope), map(flag)|map(order))
145-
void visitCallAtomicWorkItemFence(CallInst *CI);
146-
147-
/// Transform atomic_compare_exchange call.
148-
/// In atomic_compare_exchange, the expected value parameter is a pointer.
149-
/// However in SPIR-V it is a value. The transformation adds a load
150-
/// instruction, result of which is passed to atomic_compare_exchange as
151-
/// argument.
152-
/// The transformation adds a store instruction after the call, to update the
153-
/// value in expected with the value pointed to by object. Though, it is not
154-
/// necessary in case they are equal, this approach makes result code simpler.
155-
/// Also ICmp instruction is added, because the call must return result of
156-
/// comparison.
157-
/// \returns the call instruction of atomic_compare_exchange_strong.
158-
CallInst *visitCallAtomicCmpXchg(CallInst *CI);
159-
160-
/// Transform atomic_init.
161-
/// atomic_init(p, x) => store p, x
162-
void visitCallAtomicInit(CallInst *CI);
163-
164-
/// Transform legacy OCL 1.x atomic builtins to SPIR-V builtins for extensions
165-
/// cl_khr_int64_base_atomics
166-
/// cl_khr_int64_extended_atomics
167-
/// Do nothing if the called function is not a legacy atomic builtin.
168-
void visitCallAtomicLegacy(CallInst *CI, StringRef MangledName,
169-
StringRef DemangledName);
170-
171-
/// Transform OCL 2.0 C++11 atomic builtins to SPIR-V builtins.
172-
/// Do nothing if the called function is not a C++11 atomic builtin.
173-
void visitCallAtomicCpp11(CallInst *CI, StringRef MangledName,
174-
StringRef DemangledName);
175-
176-
/// Transform OCL builtin function to SPIR-V builtin function.
177-
/// Assuming there is a simple name mapping without argument changes.
178-
/// Should be called at last.
179-
void visitCallBuiltinSimple(CallInst *CI, StringRef MangledName,
180-
StringRef DemangledName);
181-
182-
/// Transform get_image_{width|height|depth|dim}.
183-
/// get_image_xxx(...) =>
184-
/// dimension = __spirv_ImageQuerySizeLod_R{ReturnType}(...);
185-
/// return dimension.{x|y|z};
186-
void visitCallGetImageSize(CallInst *CI, StringRef DemangledName);
187-
188-
/// Transform {work|sub}_group_x =>
189-
/// __spirv_{OpName}
190-
///
191-
/// Special handling of work_group_broadcast.
192-
/// work_group_broadcast(a, x, y, z)
193-
/// =>
194-
/// __spirv_GroupBroadcast(a, vec3(x, y, z))
195-
196-
void visitCallGroupBuiltin(CallInst *CI, StringRef DemangledName);
197-
198-
/// Transform mem_fence to __spirv_MemoryBarrier.
199-
/// mem_fence(flag) => __spirv_MemoryBarrier(Workgroup, map(flag))
200-
void visitCallMemFence(CallInst *CI, StringRef DemangledName);
201-
202-
void visitCallNDRange(CallInst *CI, StringRef DemangledName);
203-
204-
/// Transform read_image with sampler arguments.
205-
/// read_image(image, sampler, ...) =>
206-
/// sampled_image = __spirv_SampledImage(image, sampler);
207-
/// return __spirv_ImageSampleExplicitLod_R{ReturnType}(sampled_image, ...);
208-
void visitCallReadImageWithSampler(CallInst *CI, StringRef MangledName);
209-
210-
/// Transform read_image with msaa image arguments.
211-
/// Sample argument must be acoded as Image Operand.
212-
void visitCallReadImageMSAA(CallInst *CI, StringRef MangledName);
213-
214-
/// Transform {read|write}_image without sampler arguments.
215-
void visitCallReadWriteImage(CallInst *CI, StringRef DemangledName);
216-
217-
/// Transform to_{global|local|private}.
218-
///
219-
/// T* a = ...;
220-
/// addr T* b = to_addr(a);
221-
/// =>
222-
/// i8* x = cast<i8*>(a);
223-
/// addr i8* y = __spirv_GenericCastToPtr_ToAddr(x);
224-
/// addr T* b = cast<addr T*>(y);
225-
void visitCallToAddr(CallInst *CI, StringRef DemangledName);
226-
227-
/// Transform return type of relatinal built-in functions like isnan, isfinite
228-
/// to boolean values.
229-
void visitCallRelational(CallInst *CI, StringRef DemangledName);
230-
231-
/// Transform vector load/store functions to SPIR-V extended builtin
232-
/// functions
233-
/// {vload|vstore{a}}{_half}{n}{_rte|_rtz|_rtp|_rtn} =>
234-
/// __spirv_ocl_{ExtendedInstructionOpCodeName}__R{ReturnType}
235-
void visitCallVecLoadStore(CallInst *CI, StringRef MangledName,
236-
StringRef DemangledName);
237-
238-
/// Transforms get_mem_fence built-in to SPIR-V function and aligns result
239-
/// values with SPIR 1.2. get_mem_fence(ptr) => __spirv_GenericPtrMemSemantics
240-
/// GenericPtrMemSemantics valid values are 0x100, 0x200 and 0x300, where is
241-
/// SPIR 1.2 defines them as 0x1, 0x2 and 0x3, so this function adjusts
242-
/// GenericPtrMemSemantics results to SPIR 1.2 values.
243-
void visitCallGetFence(CallInst *CI, StringRef DemangledName);
244-
245-
/// Transforms OpDot instructions with a scalar type to a fmul instruction
246-
void visitCallDot(CallInst *CI);
247-
248-
/// Fixes for built-in functions with vector+scalar arguments that are
249-
/// translated to the SPIR-V instructions where all arguments must have the
250-
/// same type.
251-
void visitCallScalToVec(CallInst *CI, StringRef MangledName,
252-
StringRef DemangledName);
253-
254-
/// Transform get_image_channel_{order|data_type} built-in functions to
255-
/// __spirv_ocl_{ImageQueryOrder|ImageQueryFormat}
256-
void visitCallGetImageChannel(CallInst *CI, StringRef DemangledName,
257-
unsigned int Offset);
258-
259-
/// Transform enqueue_kernel and kernel query built-in functions to
260-
/// spirv-friendly format filling arguments, required for device-side enqueue
261-
/// instructions, but missed in the original call
262-
void visitCallEnqueueKernel(CallInst *CI, StringRef DemangledName);
263-
void visitCallKernelQuery(CallInst *CI, StringRef DemangledName);
264-
265-
/// For cl_intel_subgroups block read built-ins:
266-
void visitSubgroupBlockReadINTEL(CallInst *CI);
267-
268-
/// For cl_intel_subgroups block write built-ins:
269-
void visitSubgroupBlockWriteINTEL(CallInst *CI);
270-
271-
/// For cl_intel_media_block_io built-ins:
272-
void visitSubgroupImageMediaBlockINTEL(CallInst *CI, StringRef DemangledName);
273-
// For cl_intel_device_side_avc_motion_estimation built-ins
274-
void visitSubgroupAVCBuiltinCall(CallInst *CI, StringRef DemangledName);
275-
void visitSubgroupAVCWrapperBuiltinCall(CallInst *CI, Op WrappedOC,
276-
StringRef DemangledName);
277-
void visitSubgroupAVCBuiltinCallWithSampler(CallInst *CI,
278-
StringRef DemangledName);
279-
280-
void visitCallLdexp(CallInst *CI, StringRef MangledName,
281-
StringRef DemangledName);
282-
283-
/// For cl_intel_convert_bfloat16_as_ushort
284-
void visitCallConvertBFloat16AsUshort(CallInst *CI, StringRef DemangledName);
285-
/// For cl_intel_convert_as_bfloat16_float
286-
void visitCallConvertAsBFloat16Float(CallInst *CI, StringRef DemangledName);
287-
288-
void setOCLTypeToSPIRV(OCLTypeToSPIRVBase *OCLTypeToSPIRV) {
289-
OCLTypeToSPIRVPtr = OCLTypeToSPIRV;
290-
}
291-
OCLTypeToSPIRVBase *getOCLTypeToSPIRV() { return OCLTypeToSPIRVPtr; }
292-
293-
private:
294-
Module *M;
295-
LLVMContext *Ctx;
296-
unsigned CLVer; /// OpenCL version as major*10+minor
297-
std::set<Value *> ValuesToDelete;
298-
OCLTypeToSPIRVBase *OCLTypeToSPIRVPtr;
299-
300-
ConstantInt *addInt32(int I) { return getInt32(M, I); }
301-
ConstantInt *addSizet(uint64_t I) { return getSizet(M, I); }
302-
303-
/// Get vector width from OpenCL vload* function name.
304-
SPIRVWord getVecLoadWidth(const std::string &DemangledName) {
305-
SPIRVWord Width = 0;
306-
if (DemangledName == "vloada_half")
307-
Width = 1;
308-
else {
309-
unsigned Loc = 5;
310-
if (DemangledName.find("vload_half") == 0)
311-
Loc = 10;
312-
else if (DemangledName.find("vloada_half") == 0)
313-
Loc = 11;
314-
315-
std::stringstream SS(DemangledName.substr(Loc));
316-
SS >> Width;
317-
}
318-
return Width;
319-
}
87+
bool OCLToSPIRVLegacy::runOnModule(Module &M) {
88+
setOCLTypeToSPIRV(&getAnalysis<OCLTypeToSPIRVLegacy>());
89+
return runOCLToSPIRV(M);
90+
}
32091

321-
/// Transform OpenCL vload/vstore function name.
322-
void transVecLoadStoreName(std::string &DemangledName,
323-
const std::string &Stem, bool AlwaysN) {
324-
auto HalfStem = Stem + "_half";
325-
auto HalfStemR = HalfStem + "_r";
326-
if (!AlwaysN && DemangledName == HalfStem)
327-
return;
328-
if (!AlwaysN && DemangledName.find(HalfStemR) == 0) {
329-
DemangledName = HalfStemR;
330-
return;
331-
}
332-
if (DemangledName.find(HalfStem) == 0) {
333-
auto OldName = DemangledName;
334-
DemangledName = HalfStem + "n";
335-
if (OldName.find("_r") != std::string::npos)
336-
DemangledName += "_r";
337-
return;
338-
}
339-
if (DemangledName.find(Stem) == 0) {
340-
DemangledName = Stem + "n";
341-
return;
342-
}
343-
}
344-
};
92+
void OCLToSPIRVLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
93+
AU.addRequired<OCLTypeToSPIRVLegacy>();
94+
}
34595

346-
class OCLToSPIRVLegacy : public OCLToSPIRVBase, public llvm::ModulePass {
347-
public:
348-
OCLToSPIRVLegacy() : ModulePass(ID) {
349-
initializeOCLToSPIRVLegacyPass(*PassRegistry::getPassRegistry());
350-
}
96+
llvm::PreservedAnalyses OCLToSPIRVPass::run(llvm::Module &M,
97+
llvm::ModuleAnalysisManager &MAM) {
98+
setOCLTypeToSPIRV(&MAM.getResult<OCLTypeToSPIRVPass>(M));
99+
return runOCLToSPIRV(M) ? llvm::PreservedAnalyses::none()
100+
: llvm::PreservedAnalyses::all();
101+
}
351102

352-
bool runOnModule(Module &M) override {
353-
setOCLTypeToSPIRV(&getAnalysis<OCLTypeToSPIRVLegacy>());
354-
return runOCLToSPIRV(M);
103+
/// Get vector width from OpenCL vload* function name.
104+
SPIRVWord OCLToSPIRVBase::getVecLoadWidth(const std::string &DemangledName) {
105+
SPIRVWord Width = 0;
106+
if (DemangledName == "vloada_half")
107+
Width = 1;
108+
else {
109+
unsigned Loc = 5;
110+
if (DemangledName.find("vload_half") == 0)
111+
Loc = 10;
112+
else if (DemangledName.find("vloada_half") == 0)
113+
Loc = 11;
114+
115+
std::stringstream SS(DemangledName.substr(Loc));
116+
SS >> Width;
355117
}
118+
return Width;
119+
}
356120

357-
void getAnalysisUsage(AnalysisUsage &AU) const override {
358-
AU.addRequired<OCLTypeToSPIRVLegacy>();
121+
/// Transform OpenCL vload/vstore function name.
122+
void OCLToSPIRVBase::transVecLoadStoreName(std::string &DemangledName,
123+
const std::string &Stem,
124+
bool AlwaysN) {
125+
auto HalfStem = Stem + "_half";
126+
auto HalfStemR = HalfStem + "_r";
127+
if (!AlwaysN && DemangledName == HalfStem)
128+
return;
129+
if (!AlwaysN && DemangledName.find(HalfStemR) == 0) {
130+
DemangledName = HalfStemR;
131+
return;
359132
}
360-
361-
static char ID;
362-
};
363-
364-
class OCLToSPIRVPass : public OCLToSPIRVBase,
365-
public llvm::PassInfoMixin<OCLToSPIRVBase> {
366-
public:
367-
llvm::PreservedAnalyses run(llvm::Module &M,
368-
llvm::ModuleAnalysisManager &MAM) {
369-
setOCLTypeToSPIRV(&MAM.getResult<OCLTypeToSPIRVPass>(M));
370-
return runOCLToSPIRV(M) ? llvm::PreservedAnalyses::none()
371-
: llvm::PreservedAnalyses::all();
133+
if (DemangledName.find(HalfStem) == 0) {
134+
auto OldName = DemangledName;
135+
DemangledName = HalfStem + "n";
136+
if (OldName.find("_r") != std::string::npos)
137+
DemangledName += "_r";
138+
return;
372139
}
373-
};
140+
if (DemangledName.find(Stem) == 0) {
141+
DemangledName = Stem + "n";
142+
return;
143+
}
144+
}
374145

375146
char OCLToSPIRVLegacy::ID = 0;
376147

0 commit comments

Comments
 (0)