Skip to content

Commit 881d181

Browse files
KornevNikita authored and vmaksimo committed
Extend bool arg of shift operations
This is a patch to regularize the lshr instruction with i1 arguments. According to the SPIR-V specification, the operands of OpShiftRightLogical must be of integer type. Co-authored-by: Alexey Sachkov <alexey.sachkov@intel.com> Original commit: KhronosGroup/SPIRV-LLVM-Translator@c65f190
1 parent ecb25c8 commit 881d181

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ class SPIRVRegularizeLLVMBase {
117117
void expandVEDWithSYCLTypeSRetArg(Function *F);
118118
void expandVIDWithSYCLTypeByValComp(Function *F);
119119

120+
// According to the specification, the operands of a shift instruction must be
121+
// a scalar or vector of integers. When LLVM IR contains a shift instruction with
122+
// i1 operands, they are treated as a bool. We need to extend them to i32 to
123+
// comply with the specification. For example: "%shift = lshr i1 0, 1";
124+
// The bit instruction should be changed to the extended version
125+
// "%shift = lshr i32 0, 1" so the args are treated as int operands.
126+
Value *extendBitInstBoolArg(Instruction *OldInst);
127+
120128
static std::string lowerLLVMIntrinsicName(IntrinsicInst *II);
121129
void adaptStructTypes(StructType *ST);
122130
static char ID;
@@ -412,6 +420,31 @@ void SPIRVRegularizeLLVMBase::expandSYCLTypeUsing(Module *M) {
412420
expandVIDWithSYCLTypeByValComp(F);
413421
}
414422

423+
// Builds a 32-bit replacement for a shift instruction whose operands are
// integer (or vector-of-integer) values.
//
// Both operands of II are zero-extended to i32 (or to a fixed vector of i32
// with the same element count) and a new shift of the same opcode is created
// on the widened operands.
//
// \param II the shift instruction to widen; it is not modified or erased
//           here, that is left to the caller.
// \returns the newly created LShr/Shl value, or II itself when the opcode is
//          neither LShr nor Shl.
Value *SPIRVRegularizeLLVMBase::extendBitInstBoolArg(Instruction *II) {
424+
// The builder is positioned at II, so the new instructions are inserted
// right before the instruction being replaced.
IRBuilder<> Builder(II);
425+
auto *ArgTy = II->getOperand(0)->getType();
426+
Type *NewArgType = nullptr;
427+
// Any integer width is accepted by this check; the caller visible in this
// change only passes i1 / <N x i1> operands.
if (ArgTy->isIntegerTy()) {
428+
NewArgType = Builder.getInt32Ty();
429+
} else if (ArgTy->isVectorTy() &&
430+
cast<VectorType>(ArgTy)->getElementType()->isIntegerTy()) {
431+
// NOTE(review): cast<FixedVectorType> assumes the vector is not scalable;
// confirm scalable vectors can never reach this point.
unsigned NumElements = cast<FixedVectorType>(ArgTy)->getNumElements();
432+
// Same element count as the original vector, non-scalable (last argument).
NewArgType = VectorType::get(Builder.getInt32Ty(), NumElements, false);
433+
} else {
434+
llvm_unreachable("Unexpected type");
435+
}
436+
// Both the shifted value and the shift amount are widened to the same type.
auto *NewBase = Builder.CreateZExt(II->getOperand(0), NewArgType);
437+
auto *NewShift = Builder.CreateZExt(II->getOperand(1), NewArgType);
438+
switch (II->getOpcode()) {
439+
case Instruction::LShr:
440+
return Builder.CreateLShr(NewBase, NewShift);
441+
case Instruction::Shl:
442+
return Builder.CreateShl(NewBase, NewShift);
443+
default:
444+
// Any other opcode: the original instruction is handed back unchanged.
// Note that the two zero-extensions above have already been created.
return II;
445+
}
446+
}
447+
415448
void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) {
416449
if (!ST->hasName())
417450
return;
@@ -556,6 +589,21 @@ bool SPIRVRegularizeLLVMBase::regularize() {
556589
}
557590
}
558591

592+
// Translator treats i1 as boolean, but bit instructions take
593+
// scalar/vector integers, so we have to extend such arguments
594+
if (II.isLogicalShift() &&
595+
II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) {
596+
auto *NewInst = extendBitInstBoolArg(&II);
597+
for (auto *U : II.users()) {
598+
if (cast<Instruction>(U)->getOpcode() == Instruction::ZExt) {
599+
U->dropAllReferences();
600+
U->replaceAllUsesWith(NewInst);
601+
ToErase.push_back(cast<Instruction>(U));
602+
}
603+
}
604+
ToErase.push_back(&II);
605+
}
606+
559607
// Remove optimization info not supported by SPIRV
560608
if (auto BO = dyn_cast<BinaryOperator>(&II)) {
561609
if (isa<PossiblyExactOperator>(BO) && BO->isExact())

llvm-spirv/test/ExtendBitBoolArg.ll

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv -s %t.bc -o %t.regulzarized.bc
3+
; RUN: llvm-dis %t.regulzarized.bc -o %t.regulzarized.ll
4+
; RUN: FileCheck < %t.regulzarized.ll %s
5+
6+
; Translation cycle should be successful:
7+
; RUN: llvm-spirv %t.regulzarized.bc -o %t.spv
8+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
9+
10+
; CHECK: %[[#Base:]] = load i1, i1 addrspace(4)*{{.*}}, align 1
11+
; CHECK: %[[#LoadShift:]] = load i32, i32 addrspace(4)*{{.*}} align 4
12+
; CHECK: %[[#AndShift:]] = and i32 %[[#LoadShift]], 1
13+
; CHECK: %[[#CmpShift:]] = icmp ne i32 %[[#AndShift]], 0
14+
; CHECK: %[[#ExtBase:]] = select i1 %[[#Base]], i32 1, i32 0
15+
; CHECK: %[[#ExtShift:]] = select i1 %[[#CmpShift]], i32 1, i32 0
16+
; CHECK: %[[#LSHR:]] = lshr i32 %[[#ExtBase]], %[[#ExtShift]]
17+
; CHECK: and i32 %[[#LSHR]], 1
18+
19+
; CHECK: %[[#ExtVecBase:]] = select <2 x i1> %vec1, <2 x i32> <i32 1, i32 1>, <2 x i32> zeroinitializer
20+
; CHECK: %[[#ExtVecShift:]] = select <2 x i1> %vec2, <2 x i32> <i32 1, i32 1>, <2 x i32> zeroinitializer
21+
; CHECK: lshr <2 x i32> %[[#ExtVecBase]], %[[#ExtVecShift]]
22+
23+
; ModuleID = 'source.bc'
24+
source_filename = "source.cpp"
25+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
26+
target triple = "spir64-unknown-unknown"
27+
28+
%"class.ac" = type { i1 }
29+
30+
; Function Attrs: convergent mustprogress norecurse nounwind
31+
define linkonce_odr dso_local spir_func void @foo(<2 x i1> %vec1, <2 x i1> %vec2) align 2 {
32+
%1 = alloca %"class.ac" addrspace(4)*, align 8
33+
%2 = alloca i32, align 4
34+
%3 = addrspacecast %"class.ac" addrspace(4)** %1 to %"class.ac" addrspace(4)* addrspace(4)*
35+
%4 = addrspacecast i32* %2 to i32 addrspace(4)*
36+
%5 = load %"class.ac" addrspace(4)*, %"class.ac" addrspace(4)* addrspace(4)* %3, align 8
37+
%6 = getelementptr inbounds %"class.ac", %"class.ac" addrspace(4)* %5, i32 0, i32 0
38+
%7 = load i1, i1 addrspace(4)* %6, align 1
39+
%8 = load i32, i32 addrspace(4)* %4, align 4
40+
%9 = trunc i32 %8 to i1
41+
%10 = lshr i1 %7, %9
42+
%11 = zext i1 %10 to i32
43+
%12 = and i32 %11, 1
44+
%13 = lshr <2 x i1> %vec1, %vec2
45+
ret void
46+
}

0 commit comments

Comments
 (0)