Skip to content

Commit 3734f57

Browse files
committed
Handle cases where Multiplier is not divisible by BytesPerElem in variable index calculation
1 parent 9ff05d2 commit 3734f57

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

llvm/lib/Target/DirectX/DXILFlattenArrays.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/IR/InstVisitor.h"
2121
#include "llvm/IR/ReplaceConstant.h"
2222
#include "llvm/Support/Casting.h"
23+
#include "llvm/Support/MathExtras.h"
2324
#include "llvm/Transforms/Utils/Local.h"
2425
#include <cassert>
2526
#include <cstddef>
@@ -305,6 +306,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
305306
unsigned BytesPerElem = Info.RootFlattenedArrayType->getArrayElementType()
306307
->getPrimitiveSizeInBits() /
307308
8;
309+
assert(isPowerOf2_32(BytesPerElem) &&
310+
"Bytes per element should be a power of 2");
308311

309312
// Compute the 32-bit index for this flattened GEP from the constant and
310313
// variable byte offsets in the GEPInfo
@@ -316,14 +319,23 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
316319
"Constant byte offset for flat GEP index must fit within 32 bits");
317320
Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
318321
for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
319-
uint64_t Mul = Multiplier.udiv(BytesPerElem).getZExtValue();
320-
assert(Mul < UINT32_MAX &&
321-
"Multiplier for flat GEP index must fit within 32 bits");
322+
assert(Multiplier.getActiveBits() <= 32 &&
323+
"The multiplier for a flat GEP index must fit within 32 bits");
322324
assert(VarIndex->getType()->isIntegerTy(32) &&
323325
"Expected i32-typed GEP indices");
324-
Value *ConstIntMul = Builder.getInt32(Mul);
325-
Value *MulVarIndex = Builder.CreateMul(VarIndex, ConstIntMul);
326-
FlattenedIndex = Builder.CreateAdd(FlattenedIndex, MulVarIndex);
326+
Value *VI;
327+
if (Multiplier.getZExtValue() % BytesPerElem != 0) {
328+
// This can happen, e.g., with i8 GEPs. To handle this, first multiply
329+
// VarIndex by Multiplier, then divide by BytesPerElem with a logical shift
330+
// right (valid because BytesPerElem is asserted to be a power of 2).
331+
VI = Builder.CreateMul(VarIndex,
332+
Builder.getInt32(Multiplier.getZExtValue()));
333+
VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem)));
334+
} else
335+
VI = Builder.CreateMul(
336+
VarIndex,
337+
Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem));
338+
FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI);
327339
}
328340

329341
// Construct a new GEP for the flattened array to replace the current GEP

llvm/test/CodeGen/DirectX/flatten-array.ll

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,5 +273,28 @@ define void @gep_scalar_flatten() {
273273
ret void
274274
}
275275

276+
define void @gep_scalar_flatten_dynamic(i32 %index) {
277+
; CHECK-LABEL: gep_scalar_flatten_dynamic
278+
; CHECK-SAME: i32 [[INDEX:%.*]]) {
279+
; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [6 x i32], align 4
280+
; CHECK-NEXT: [[I8INDEX:%.*]] = mul i32 [[INDEX]], 12
281+
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[I8INDEX]], 1
282+
; CHECK-NEXT: [[DIV:%.*]] = lshr i32 [[MUL]], 2
283+
; CHECK-NEXT: [[ADD:%.*]] = add i32 0, [[DIV]]
284+
; CHECK-NEXT: getelementptr inbounds nuw [6 x i32], ptr [[ALLOCA]], i32 0, i32 [[ADD]]
285+
; CHECK-NEXT: [[I32INDEX:%.*]] = mul i32 [[INDEX]], 3
286+
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[I32INDEX]], 1
287+
; CHECK-NEXT: [[ADD:%.*]] = add i32 0, [[MUL]]
288+
; CHECK-NEXT: getelementptr inbounds nuw [6 x i32], ptr [[ALLOCA]], i32 0, i32 [[ADD]]
289+
; CHECK-NEXT: ret void
290+
;
291+
%a = alloca [2 x [3 x i32]], align 4
292+
%i8index = mul i32 %index, 12
293+
%i8root = getelementptr inbounds nuw i8, [2 x [3 x i32]]* %a, i32 %i8index;
294+
%i32index = mul i32 %index, 3
295+
%i32root = getelementptr inbounds nuw i32, [2 x [3 x i32]]* %a, i32 %i32index;
296+
ret void
297+
}
298+
276299
; Make sure we don't try to walk the body of a function declaration.
277300
declare void @opaque_function()

0 commit comments

Comments
 (0)