Skip to content

Commit 463aede

Browse files
committed
Implement lowering of llvm.umul.with.overflow.* intrinsic (#743)
Since the intrinsic can't be mapped to SPIR-V directly, its semantics has been implemented as a function in LLVM IR.
1 parent bb5da8a commit 463aede

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ class SPIRVRegularizeLLVM : public ModulePass {
9797
void lowerFunnelShiftLeft(IntrinsicInst *FSHLIntrinsic);
9898
void buildFunnelShiftLeftFunc(Function *FSHLFunc);
9999

100+
void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic);
101+
void buildUMulWithOverflowFunc(Function *UMulFunc);
102+
100103
static std::string lowerLLVMIntrinsicName(IntrinsicInst *II);
101104

102105
static char ID;
@@ -225,6 +228,45 @@ void SPIRVRegularizeLLVM::lowerFunnelShiftLeft(IntrinsicInst *FSHLIntrinsic) {
225228
FSHLIntrinsic->setCalledFunction(FSHLFunc);
226229
}
227230

231+
void SPIRVRegularizeLLVM::buildUMulWithOverflowFunc(Function *UMulFunc) {
232+
if (!UMulFunc->empty())
233+
return;
234+
235+
BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc);
236+
IRBuilder<> Builder(EntryBB);
237+
// Build the actual unsigned multiplication logic with the overflow
238+
// indication.
239+
auto *FirstArg = UMulFunc->getArg(0);
240+
auto *SecondArg = UMulFunc->getArg(1);
241+
242+
// Do unsigned multiplication Mul = A * B.
243+
// Then check if unsigned division Div = Mul / A is not equal to B.
244+
// If so, then overflow has happened.
245+
auto *Mul = Builder.CreateNUWMul(FirstArg, SecondArg);
246+
auto *Div = Builder.CreateUDiv(Mul, FirstArg);
247+
auto *Overflow = Builder.CreateICmpNE(FirstArg, Div);
248+
249+
// umul.with.overflow intrinsic return a structure, where the first element
250+
// is the multiplication result, and the second is an overflow bit.
251+
auto *StructTy = UMulFunc->getReturnType();
252+
auto *Agg = Builder.CreateInsertValue(UndefValue::get(StructTy), Mul, {0});
253+
auto *Res = Builder.CreateInsertValue(Agg, Overflow, {1});
254+
Builder.CreateRet(Res);
255+
}
256+
257+
void SPIRVRegularizeLLVM::lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
258+
// Get a separate function - otherwise, we'd have to rework the CFG of the
259+
// current one. Then simply replace the intrinsic uses with a call to the new
260+
// function.
261+
FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
262+
Type *FSHLRetTy = UMulFuncTy->getReturnType();
263+
const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
264+
Function *UMulFunc =
265+
getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
266+
buildUMulWithOverflowFunc(UMulFunc);
267+
UMulIntrinsic->setCalledFunction(UMulFunc);
268+
}
269+
228270
bool SPIRVRegularizeLLVM::runOnModule(Module &Module) {
229271
M = &Module;
230272
Ctx = &M->getContext();
@@ -263,6 +305,8 @@ bool SPIRVRegularizeLLVM::regularize() {
263305
lowerMemset(MSI);
264306
else if (II->getIntrinsicID() == Intrinsic::fshl)
265307
lowerFunnelShiftLeft(II);
308+
else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow)
309+
lowerUMulWithOverflow(II);
266310
}
267311
}
268312

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -spirv-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV
3+
; RUN: llvm-spirv %t.bc -o %t.spv
4+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
5+
6+
; On LLVM level, we'll check that the intrinsics were generated again in reverse
7+
; translation, replacing the SPIR-V level implementations.
8+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM \
9+
; RUN: "--implicit-check-not=declare {{.*}} @spirv.llvm_umul_with_overflow_{{.*}}"
10+
11+
; CHECK-SPIRV: Name [[NAME_UMUL_FUNC_8:[0-9]+]] "spirv.llvm_umul_with_overflow_i8"
12+
; CHECK-SPIRV: Name [[NAME_UMUL_FUNC_32:[0-9]+]] "spirv.llvm_umul_with_overflow_i32"
13+
; CHECK-SPIRV: Name [[NAME_UMUL_FUNC_VEC_I64:[0-9]+]] "spirv.llvm_umul_with_overflow_v2i64"
14+
15+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
16+
target triple = "spir"
17+
18+
; CHECK-LLVM: [[UMUL_8_TY:%structtype]] = type { i8, i1 }
19+
; CHECK-LLVM: [[UMUL_32_TY:%structtype.[0-9]+]] = type { i32, i1 }
20+
; CHECK-LLVM: [[UMUL_VEC64_TY:%structtype.[0-9]+]] = type { <2 x i64>, <2 x i1> }
21+
22+
; Function Attrs: nofree nounwind writeonly
23+
define dso_local spir_func void @_Z4foo8hhPh(i8 zeroext %a, i8 zeroext %b, i8* nocapture %c) local_unnamed_addr #0 {
24+
entry:
25+
; CHECK-LLVM: call [[UMUL_8_TY]] @llvm.umul.with.overflow.i8
26+
; CHECK-SPIRV: FunctionCall [[#]] [[#]] [[NAME_UMUL_FUNC_8]]
27+
%umul = tail call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 %b)
28+
%cmp = extractvalue { i8, i1 } %umul, 1
29+
%umul.value = extractvalue { i8, i1 } %umul, 0
30+
%storemerge = select i1 %cmp, i8 0, i8 %umul.value
31+
store i8 %storemerge, i8* %c, align 1, !tbaa !2
32+
ret void
33+
}
34+
35+
; CHECK-SPIRV: Function [[#]] [[NAME_UMUL_FUNC_8]]
36+
; CHECK-SPIRV: FunctionParameter [[#]] [[VAR_A:[0-9]+]]
37+
; CHECK-SPIRV: FunctionParameter [[#]] [[VAR_B:[0-9]+]]
38+
; CHECK-SPIRV: IMul [[#]] [[MUL_RES:[0-9]+]] [[VAR_A]] [[VAR_B]]
39+
; CHECK-SPIRV: UDiv [[#]] [[DIV_RES:[0-9]+]] [[MUL_RES]] [[VAR_A]]
40+
; CHECK-SPIRV: INotEqual [[#]] [[CMP_RES:[0-9]+]] [[VAR_A]] [[DIV_RES]]
41+
; CHECK-SPIRV: CompositeInsert [[#]] [[INSERT_RES:[0-9]+]] [[MUL_RES]]
42+
; CHECK-SPIRV: CompositeInsert [[#]] [[INSERT_RES_1:[0-9]+]] [[CMP_RES]] [[INSERT_RES]]
43+
; CHECK-SPIRV: ReturnValue [[INSERT_RES_1]]
44+
45+
; Function Attrs: nofree nounwind writeonly
46+
define dso_local spir_func void @_Z5foo32jjPj(i32 %a, i32 %b, i32* nocapture %c) local_unnamed_addr #0 {
47+
entry:
48+
; CHECK-LLVM: call [[UMUL_32_TY]] @llvm.umul.with.overflow.i32
49+
; CHECK-SPIRV: FunctionCall [[#]] [[#]] [[NAME_UMUL_FUNC_32]]
50+
%umul = tail call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %b, i32 %a)
51+
%umul.val = extractvalue { i32, i1 } %umul, 0
52+
%umul.ov = extractvalue { i32, i1 } %umul, 1
53+
%spec.select = select i1 %umul.ov, i32 0, i32 %umul.val
54+
store i32 %spec.select, i32* %c, align 4, !tbaa !5
55+
ret void
56+
}
57+
58+
; Function Attrs: nofree nounwind writeonly
59+
define dso_local spir_func void @umulo_v2i64(<2 x i64> %a, <2 x i64> %b, <2 x i64>* %p) nounwind {
60+
; CHECK-LLVM: call [[UMUL_VEC64_TY]] @llvm.umul.with.overflow.v2i64
61+
; CHECK-SPIRV: FunctionCall [[#]] [[#]] [[NAME_UMUL_FUNC_VEC_I64]]
62+
%umul = call {<2 x i64>, <2 x i1>} @llvm.umul.with.overflow.v2i64(<2 x i64> %a, <2 x i64> %b)
63+
%umul.val = extractvalue {<2 x i64>, <2 x i1>} %umul, 0
64+
%umul.ov = extractvalue {<2 x i64>, <2 x i1>} %umul, 1
65+
%zero = alloca <2 x i64>, align 16
66+
%spec.select = select <2 x i1> %umul.ov, <2 x i64> <i64 0, i64 0>, <2 x i64> %umul.val
67+
store <2 x i64> %spec.select, <2 x i64>* %p
68+
ret void
69+
}
70+
71+
; Function Attrs: nounwind readnone speculatable willreturn
72+
declare { i8, i1 } @llvm.umul.with.overflow.i8(i8, i8) #1
73+
74+
; Function Attrs: nounwind readnone speculatable willreturn
75+
declare { i32, i1 } @llvm.umul.with.overflow.i32(i32, i32) #1
76+
77+
; Function Attrs: nounwind readnone speculatable willreturn
78+
declare {<2 x i64>, <2 x i1>} @llvm.umul.with.overflow.v2i64(<2 x i64>, <2 x i64>) #1
79+
80+
attributes #0 = { nofree nounwind writeonly "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
81+
attributes #1 = { nounwind readnone speculatable willreturn }
82+
83+
!llvm.module.flags = !{!0}
84+
!llvm.ident = !{!1}
85+
86+
!0 = !{i32 1, !"wchar_size", i32 4}
87+
!1 = !{!"clang version 12.0.0 (https://github.com/llvm/llvm-project.git ddcc7ce59150c9ebc6b0b2d61e7ef4f2525c11f4)"}
88+
!2 = !{!3, !3, i64 0}
89+
!3 = !{!"omnipotent char", !4, i64 0}
90+
!4 = !{!"Simple C++ TBAA"}
91+
!5 = !{!6, !6, i64 0}
92+
!6 = !{!"int", !3, i64 0}

0 commit comments

Comments
 (0)