Skip to content

Commit b08bb11

Browse files
committed
Simplify flattening of GEP chains
This simplification also fixes instances of incorrect flat index computations
1 parent 52040b4 commit b08bb11

File tree

1 file changed

+132
-139
lines changed

1 file changed

+132
-139
lines changed

llvm/lib/Target/DirectX/DXILFlattenArrays.cpp

Lines changed: 132 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,19 @@ class DXILFlattenArraysLegacy : public ModulePass {
4040
static char ID; // Pass identification.
4141
};
4242

43-
struct GEPData {
44-
ArrayType *ParentArrayType;
45-
Value *ParentOperand;
46-
SmallVector<Value *> Indices;
47-
SmallVector<uint64_t> Dims;
48-
bool AllIndicesAreConstInt;
43+
struct GEPInfo {
44+
ArrayType *RootFlattenedArrayType;
45+
Value *RootPointerOperand;
46+
SmallMapVector<Value *, APInt, 4> VariableOffsets;
47+
APInt ConstantOffset;
4948
};
5049

5150
class DXILFlattenArraysVisitor
5251
: public InstVisitor<DXILFlattenArraysVisitor, bool> {
5352
public:
54-
DXILFlattenArraysVisitor() {}
53+
DXILFlattenArraysVisitor(
54+
DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
55+
: GlobalMap(GlobalMap) {}
5556
bool visit(Function &F);
5657
// InstVisitor methods. They return true if the instruction was scalarized,
5758
// false if nothing changed.
@@ -78,31 +79,15 @@ class DXILFlattenArraysVisitor
7879

7980
private:
8081
SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
81-
DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
82+
DenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
83+
DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
8284
bool finish();
8385
ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
8486
ArrayRef<uint64_t> Dims,
8587
IRBuilder<> &Builder);
8688
Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
8789
ArrayRef<uint64_t> Dims,
8890
IRBuilder<> &Builder);
89-
90-
// Helper function to collect indices and dimensions from a GEP instruction
91-
void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP,
92-
SmallVectorImpl<Value *> &Indices,
93-
SmallVectorImpl<uint64_t> &Dims,
94-
bool &AllIndicesAreConstInt);
95-
96-
void
97-
recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
98-
ArrayType *FlattenedArrayType, Value *PtrOperand,
99-
unsigned &GEPChainUseCount,
100-
SmallVector<Value *> Indices = SmallVector<Value *>(),
101-
SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
102-
bool AllIndicesAreConstInt = true);
103-
bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
104-
bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
105-
GetElementPtrInst &GEP);
10691
};
10792
} // namespace
10893

@@ -225,131 +210,139 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
225210
return true;
226211
}
227212

228-
void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP(
229-
GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
230-
SmallVectorImpl<uint64_t> &Dims, bool &AllIndicesAreConstInt) {
231-
232-
Type *CurrentType = GEP.getSourceElementType();
213+
bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
214+
// Do not visit GEPs more than once
215+
if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
216+
return false;
233217

234-
// Note index 0 is the ptr index.
235-
for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) {
236-
Indices.push_back(Index);
237-
AllIndicesAreConstInt &= isa<ConstantInt>(Index);
218+
// Construct GEPInfo for this GEP
219+
GEPInfo Info;
238220

239-
if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
240-
Dims.push_back(ArrayTy->getNumElements());
241-
CurrentType = ArrayTy->getElementType();
242-
} else {
243-
assert(false && "Expected array type in GEP chain");
244-
}
245-
}
246-
}
221+
// Obtain the variable and constant byte offsets computed by this GEP
222+
const DataLayout &DL = GEP.getDataLayout();
223+
unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
224+
Info.ConstantOffset = {BitWidth, 0};
225+
bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets,
226+
Info.ConstantOffset);
227+
(void)Success;
228+
assert(Success && "Failed to collect offsets for GEP");
247229

248-
void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
249-
GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
250-
Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
251-
SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
252-
// Check if this GEP is already in the map to avoid circular references
253-
if (GEPChainMap.count(&CurrGEP) > 0)
254-
return;
230+
Value *PtrOperand = GEP.getPointerOperand();
255231

256-
// Collect indices and dimensions from the current GEP
257-
collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt);
258-
bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
259-
if (!IsMultiDimArr) {
260-
assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
261-
GEPChainMap.insert(
262-
{&CurrGEP,
263-
{std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
264-
std::move(Dims), AllIndicesAreConstInt}});
265-
return;
266-
}
267-
bool GepUses = false;
268-
for (auto *User : CurrGEP.users()) {
269-
if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
270-
recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
271-
++GEPChainUseCount, Indices, Dims,
272-
AllIndicesAreConstInt);
273-
GepUses = true;
232+
// Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
233+
// it can be visited
234+
if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand))
235+
if (PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
236+
GetElementPtrInst *OldGEPI =
237+
cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
238+
OldGEPI->insertBefore(GEP.getIterator());
239+
240+
IRBuilder<> Builder(&GEP);
241+
SmallVector<Value *> Indices(GEP.idx_begin(), GEP.idx_end());
242+
Value *NewGEP =
243+
Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
244+
GEP.getName(), GEP.getNoWrapFlags());
245+
assert(isa<GetElementPtrInst>(NewGEP) &&
246+
"Expected newly-created GEP to not be a ConstantExpr");
247+
GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
248+
249+
GEP.replaceAllUsesWith(NewGEPI);
250+
GEP.eraseFromParent();
251+
visitGetElementPtrInst(*OldGEPI);
252+
visitGetElementPtrInst(*NewGEPI);
253+
return true;
274254
}
275-
}
276-
// This case is just incase the gep chain doesn't end with a 1d array.
277-
if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
278-
GEPChainMap.insert(
279-
{&CurrGEP,
280-
{std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
281-
std::move(Dims), AllIndicesAreConstInt}});
282-
}
283-
}
284255

285-
bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
286-
GetElementPtrInst &GEP) {
287-
GEPData GEPInfo = GEPChainMap.at(&GEP);
288-
return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
289-
}
290-
bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
291-
GEPData &GEPInfo, GetElementPtrInst &GEP) {
292-
IRBuilder<> Builder(&GEP);
293-
Value *FlatIndex;
294-
if (GEPInfo.AllIndicesAreConstInt)
295-
FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
296-
else
297-
FlatIndex =
298-
genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
299-
300-
ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
301-
302-
// Don't append '.flat' to an empty string. If the SSA name isn't available
303-
// it could conflict with the ParentOperand's name.
304-
std::string FlatName = GEP.hasName() ? GEP.getName().str() + ".flat" : "";
305-
306-
Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParentOperand,
307-
{Builder.getInt32(0), FlatIndex}, FlatName,
308-
GEP.getNoWrapFlags());
309-
310-
// Note: Old gep will become an invalid instruction after replaceAllUsesWith.
311-
// Erase the old GEP in the map before to avoid invalid instructions
312-
// and circular references.
313-
GEPChainMap.erase(&GEP);
314-
315-
GEP.replaceAllUsesWith(FlatGEP);
316-
GEP.eraseFromParent();
317-
return true;
318-
}
256+
// If there is a parent GEP, inherit the root array type and pointer, and
257+
// merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
258+
// chain and we need to deterine the root array type
259+
if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
260+
assert(GEPChainInfoMap.contains(PtrOpGEP) &&
261+
"Expected parent GEP to be visited before this GEP");
262+
GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
263+
Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
264+
Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
265+
for (auto &VariableOffset : PGEPInfo.VariableOffsets)
266+
Info.VariableOffsets.insert(VariableOffset);
267+
Info.ConstantOffset += PGEPInfo.ConstantOffset;
268+
} else {
269+
Info.RootPointerOperand = PtrOperand;
270+
271+
// We should try to determine the type of the root from the pointer rather
272+
// than the GEP's source element type because this could be a scalar GEP
273+
// into a multidimensional array-typed pointer from an Alloca or Global
274+
// Variable.
275+
Type *RootTy = GEP.getSourceElementType();
276+
if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
277+
if (!GlobalMap.contains(GlobalVar))
278+
return false;
279+
GlobalVariable *NewGlobal = GlobalMap[GlobalVar];
280+
Info.RootPointerOperand = NewGlobal;
281+
RootTy = NewGlobal->getValueType();
282+
} else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
283+
RootTy = Alloca->getAllocatedType();
284+
}
285+
assert(!isMultiDimensionalArray(RootTy) &&
286+
"Expected root array type to be flattened");
319287

320-
bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
321-
auto It = GEPChainMap.find(&GEP);
322-
if (It != GEPChainMap.end())
323-
return visitGetElementPtrInstInGEPChain(GEP);
324-
if (!isMultiDimensionalArray(GEP.getSourceElementType()))
325-
return false;
288+
// If the root type is not an array, we don't need to do any flattening
289+
if (!isa<ArrayType>(RootTy))
290+
return false;
326291

327-
ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
328-
IRBuilder<> Builder(&GEP);
329-
auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
330-
ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
292+
Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
293+
}
331294

332-
Value *PtrOperand = GEP.getPointerOperand();
295+
// GEPs without users or GEPs with non-GEP users should be replaced such that
296+
// the chain of GEPs they are a part of are collapsed to a single GEP into a
297+
// flattened array.
298+
bool ReplaceThisGEP = GEP.users().empty();
299+
for (Value *User : GEP.users())
300+
if (!isa<GetElementPtrInst>(User))
301+
ReplaceThisGEP = true;
302+
303+
if (ReplaceThisGEP) {
304+
// GEP.collectOffset returns the offset in bytes. So we need to divide its
305+
// offsets by the size in bytes of the element type
306+
unsigned BytesPerElem = Info.RootFlattenedArrayType->getArrayElementType()
307+
->getPrimitiveSizeInBits() /
308+
8;
309+
310+
// Compute the 32-bit index for this flattened GEP from the constant and
311+
// variable byte offsets in the GEPInfo
312+
IRBuilder<> Builder(&GEP);
313+
Value *ZeroIndex = Builder.getInt32(0);
314+
uint64_t ConstantOffset =
315+
Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
316+
assert(ConstantOffset < UINT32_MAX &&
317+
"Constant byte offset for flat GEP index must fit within 32 bits");
318+
Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
319+
for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
320+
uint64_t Mul = Multiplier.udiv(BytesPerElem).getZExtValue();
321+
assert(Mul < UINT32_MAX &&
322+
"Multiplier for flat GEP index must fit within 32 bits");
323+
assert(VarIndex->getType()->isIntegerTy(32) &&
324+
"Expected i32-typed GEP indices");
325+
Value *ConstIntMul = Builder.getInt32(Mul);
326+
Value *MulVarIndex = Builder.CreateMul(VarIndex, ConstIntMul);
327+
FlattenedIndex = Builder.CreateAdd(FlattenedIndex, MulVarIndex);
328+
}
333329

334-
unsigned GEPChainUseCount = 0;
335-
recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
336-
337-
// NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
338-
// Here recursion is used to get the length of the GEP chain.
339-
// Handle zero uses here because there won't be an update via
340-
// a child in the chain later.
341-
if (GEPChainUseCount == 0) {
342-
SmallVector<Value *> Indices;
343-
SmallVector<uint64_t> Dims;
344-
bool AllIndicesAreConstInt = true;
345-
346-
// Collect indices and dimensions from the GEP
347-
collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt);
348-
GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
349-
std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
350-
return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
330+
// Construct a new GEP for the flattened array to replace the current GEP
331+
Value *NewGEP = Builder.CreateGEP(
332+
Info.RootFlattenedArrayType, Info.RootPointerOperand,
333+
{ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
334+
335+
// Replace the current GEP with the new GEP. Store GEPInfo into the map
336+
// for later use in case this GEP was not the end of the chain
337+
GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
338+
GEP.replaceAllUsesWith(NewGEP);
339+
GEP.eraseFromParent();
340+
return true;
351341
}
352342

343+
// This GEP is potentially dead at the end of the pass since it may not have
344+
// any users anymore after GEP chains have been collapsed.
345+
GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
353346
PotentiallyDeadInstrs.emplace_back(&GEP);
354347
return false;
355348
}
@@ -456,9 +449,9 @@ flattenGlobalArrays(Module &M,
456449

457450
static bool flattenArrays(Module &M) {
458451
bool MadeChange = false;
459-
DXILFlattenArraysVisitor Impl;
460452
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
461453
flattenGlobalArrays(M, GlobalMap);
454+
DXILFlattenArraysVisitor Impl(GlobalMap);
462455
for (auto &F : make_early_inc_range(M.functions())) {
463456
if (F.isDeclaration())
464457
continue;

0 commit comments

Comments
 (0)