@@ -40,18 +40,19 @@ class DXILFlattenArraysLegacy : public ModulePass {
40
40
static char ID; // Pass identification.
41
41
};
42
42
43
- struct GEPData {
44
- ArrayType *ParentArrayType;
45
- Value *ParentOperand;
46
- SmallVector<Value *> Indices;
47
- SmallVector<uint64_t > Dims;
48
- bool AllIndicesAreConstInt;
43
+ struct GEPInfo {
44
+ ArrayType *RootFlattenedArrayType;
45
+ Value *RootPointerOperand;
46
+ SmallMapVector<Value *, APInt, 4 > VariableOffsets;
47
+ APInt ConstantOffset;
49
48
};
50
49
51
50
class DXILFlattenArraysVisitor
52
51
: public InstVisitor<DXILFlattenArraysVisitor, bool > {
53
52
public:
54
- DXILFlattenArraysVisitor () {}
53
+ DXILFlattenArraysVisitor (
54
+ DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
55
+ : GlobalMap(GlobalMap) {}
55
56
bool visit (Function &F);
56
57
// InstVisitor methods. They return true if the instruction was scalarized,
57
58
// false if nothing changed.
@@ -78,31 +79,15 @@ class DXILFlattenArraysVisitor
78
79
79
80
private:
80
81
SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
81
- DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
82
+ DenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
83
+ DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
82
84
bool finish ();
83
85
ConstantInt *genConstFlattenIndices (ArrayRef<Value *> Indices,
84
86
ArrayRef<uint64_t > Dims,
85
87
IRBuilder<> &Builder);
86
88
Value *genInstructionFlattenIndices (ArrayRef<Value *> Indices,
87
89
ArrayRef<uint64_t > Dims,
88
90
IRBuilder<> &Builder);
89
-
90
- // Helper function to collect indices and dimensions from a GEP instruction
91
- void collectIndicesAndDimsFromGEP (GetElementPtrInst &GEP,
92
- SmallVectorImpl<Value *> &Indices,
93
- SmallVectorImpl<uint64_t > &Dims,
94
- bool &AllIndicesAreConstInt);
95
-
96
- void
97
- recursivelyCollectGEPs (GetElementPtrInst &CurrGEP,
98
- ArrayType *FlattenedArrayType, Value *PtrOperand,
99
- unsigned &GEPChainUseCount,
100
- SmallVector<Value *> Indices = SmallVector<Value *>(),
101
- SmallVector<uint64_t > Dims = SmallVector<uint64_t >(),
102
- bool AllIndicesAreConstInt = true );
103
- bool visitGetElementPtrInstInGEPChain (GetElementPtrInst &GEP);
104
- bool visitGetElementPtrInstInGEPChainBase (GEPData &GEPInfo,
105
- GetElementPtrInst &GEP);
106
91
};
107
92
} // namespace
108
93
@@ -225,131 +210,139 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
225
210
return true ;
226
211
}
227
212
228
- void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP (
229
- GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
230
- SmallVectorImpl<uint64_t > &Dims, bool &AllIndicesAreConstInt) {
231
-
232
- Type *CurrentType = GEP.getSourceElementType ();
213
+ bool DXILFlattenArraysVisitor::visitGetElementPtrInst (GetElementPtrInst &GEP) {
214
+ // Do not visit GEPs more than once
215
+ if (GEPChainInfoMap.contains (cast<GEPOperator>(&GEP)))
216
+ return false ;
233
217
234
- // Note index 0 is the ptr index.
235
- for (Value *Index : llvm::drop_begin (GEP.indices (), 1 )) {
236
- Indices.push_back (Index);
237
- AllIndicesAreConstInt &= isa<ConstantInt>(Index);
218
+ // Construct GEPInfo for this GEP
219
+ GEPInfo Info;
238
220
239
- if ( auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
240
- Dims. push_back (ArrayTy-> getNumElements () );
241
- CurrentType = ArrayTy-> getElementType ( );
242
- } else {
243
- assert ( false && " Expected array type in GEP chain " );
244
- }
245
- }
246
- }
221
+ // Obtain the variable and constant byte offsets computed by this GEP
222
+ const DataLayout &DL = GEP. getDataLayout ( );
223
+ unsigned BitWidth = DL. getIndexTypeSizeInBits (GEP. getType () );
224
+ Info. ConstantOffset = {BitWidth, 0 };
225
+ bool Success = GEP. collectOffset (DL, BitWidth, Info. VariableOffsets ,
226
+ Info. ConstantOffset );
227
+ ( void )Success;
228
+ assert (Success && " Failed to collect offsets for GEP " );
247
229
248
- void DXILFlattenArraysVisitor::recursivelyCollectGEPs (
249
- GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
250
- Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
251
- SmallVector<uint64_t > Dims, bool AllIndicesAreConstInt) {
252
- // Check if this GEP is already in the map to avoid circular references
253
- if (GEPChainMap.count (&CurrGEP) > 0 )
254
- return ;
230
+ Value *PtrOperand = GEP.getPointerOperand ();
255
231
256
- // Collect indices and dimensions from the current GEP
257
- collectIndicesAndDimsFromGEP (CurrGEP, Indices, Dims, AllIndicesAreConstInt);
258
- bool IsMultiDimArr = isMultiDimensionalArray (CurrGEP.getSourceElementType ());
259
- if (!IsMultiDimArr) {
260
- assert (GEPChainUseCount < FlattenedArrayType->getNumElements ());
261
- GEPChainMap.insert (
262
- {&CurrGEP,
263
- {std::move (FlattenedArrayType), PtrOperand, std::move (Indices),
264
- std::move (Dims), AllIndicesAreConstInt}});
265
- return ;
266
- }
267
- bool GepUses = false ;
268
- for (auto *User : CurrGEP.users ()) {
269
- if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
270
- recursivelyCollectGEPs (*NestedGEP, FlattenedArrayType, PtrOperand,
271
- ++GEPChainUseCount, Indices, Dims,
272
- AllIndicesAreConstInt);
273
- GepUses = true ;
232
+ // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
233
+ // it can be visited
234
+ if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand))
235
+ if (PtrOpGEPCE->getOpcode () == Instruction::GetElementPtr) {
236
+ GetElementPtrInst *OldGEPI =
237
+ cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction ());
238
+ OldGEPI->insertBefore (GEP.getIterator ());
239
+
240
+ IRBuilder<> Builder (&GEP);
241
+ SmallVector<Value *> Indices (GEP.idx_begin (), GEP.idx_end ());
242
+ Value *NewGEP =
243
+ Builder.CreateGEP (GEP.getSourceElementType (), OldGEPI, Indices,
244
+ GEP.getName (), GEP.getNoWrapFlags ());
245
+ assert (isa<GetElementPtrInst>(NewGEP) &&
246
+ " Expected newly-created GEP to not be a ConstantExpr" );
247
+ GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
248
+
249
+ GEP.replaceAllUsesWith (NewGEPI);
250
+ GEP.eraseFromParent ();
251
+ visitGetElementPtrInst (*OldGEPI);
252
+ visitGetElementPtrInst (*NewGEPI);
253
+ return true ;
274
254
}
275
- }
276
- // This case is just incase the gep chain doesn't end with a 1d array.
277
- if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
278
- GEPChainMap.insert (
279
- {&CurrGEP,
280
- {std::move (FlattenedArrayType), PtrOperand, std::move (Indices),
281
- std::move (Dims), AllIndicesAreConstInt}});
282
- }
283
- }
284
255
285
- bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain (
286
- GetElementPtrInst &GEP) {
287
- GEPData GEPInfo = GEPChainMap.at (&GEP);
288
- return visitGetElementPtrInstInGEPChainBase (GEPInfo, GEP);
289
- }
290
- bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase (
291
- GEPData &GEPInfo, GetElementPtrInst &GEP) {
292
- IRBuilder<> Builder (&GEP);
293
- Value *FlatIndex;
294
- if (GEPInfo.AllIndicesAreConstInt )
295
- FlatIndex = genConstFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
296
- else
297
- FlatIndex =
298
- genInstructionFlattenIndices (GEPInfo.Indices , GEPInfo.Dims , Builder);
299
-
300
- ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType ;
301
-
302
- // Don't append '.flat' to an empty string. If the SSA name isn't available
303
- // it could conflict with the ParentOperand's name.
304
- std::string FlatName = GEP.hasName () ? GEP.getName ().str () + " .flat" : " " ;
305
-
306
- Value *FlatGEP = Builder.CreateGEP (FlattenedArrayType, GEPInfo.ParentOperand ,
307
- {Builder.getInt32 (0 ), FlatIndex}, FlatName,
308
- GEP.getNoWrapFlags ());
309
-
310
- // Note: Old gep will become an invalid instruction after replaceAllUsesWith.
311
- // Erase the old GEP in the map before to avoid invalid instructions
312
- // and circular references.
313
- GEPChainMap.erase (&GEP);
314
-
315
- GEP.replaceAllUsesWith (FlatGEP);
316
- GEP.eraseFromParent ();
317
- return true ;
318
- }
256
+ // If there is a parent GEP, inherit the root array type and pointer, and
257
+ // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
258
+ // chain and we need to determine the root array type
259
+ if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
260
+ assert (GEPChainInfoMap.contains (PtrOpGEP) &&
261
+ " Expected parent GEP to be visited before this GEP" );
262
+ GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
263
+ Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType ;
264
+ Info.RootPointerOperand = PGEPInfo.RootPointerOperand ;
265
+ for (auto &VariableOffset : PGEPInfo.VariableOffsets )
266
+ Info.VariableOffsets .insert (VariableOffset);
267
+ Info.ConstantOffset += PGEPInfo.ConstantOffset ;
268
+ } else {
269
+ Info.RootPointerOperand = PtrOperand;
270
+
271
+ // We should try to determine the type of the root from the pointer rather
272
+ // than the GEP's source element type because this could be a scalar GEP
273
+ // into a multidimensional array-typed pointer from an Alloca or Global
274
+ // Variable.
275
+ Type *RootTy = GEP.getSourceElementType ();
276
+ if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
277
+ if (!GlobalMap.contains (GlobalVar))
278
+ return false ;
279
+ GlobalVariable *NewGlobal = GlobalMap[GlobalVar];
280
+ Info.RootPointerOperand = NewGlobal;
281
+ RootTy = NewGlobal->getValueType ();
282
+ } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
283
+ RootTy = Alloca->getAllocatedType ();
284
+ }
285
+ assert (!isMultiDimensionalArray (RootTy) &&
286
+ " Expected root array type to be flattened" );
319
287
320
- bool DXILFlattenArraysVisitor::visitGetElementPtrInst (GetElementPtrInst &GEP) {
321
- auto It = GEPChainMap.find (&GEP);
322
- if (It != GEPChainMap.end ())
323
- return visitGetElementPtrInstInGEPChain (GEP);
324
- if (!isMultiDimensionalArray (GEP.getSourceElementType ()))
325
- return false ;
288
+ // If the root type is not an array, we don't need to do any flattening
289
+ if (!isa<ArrayType>(RootTy))
290
+ return false ;
326
291
327
- ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType ());
328
- IRBuilder<> Builder (&GEP);
329
- auto [TotalElements, BaseType] = getElementCountAndType (ArrType);
330
- ArrayType *FlattenedArrayType = ArrayType::get (BaseType, TotalElements);
292
+ Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
293
+ }
331
294
332
- Value *PtrOperand = GEP.getPointerOperand ();
295
+ // GEPs without users or GEPs with non-GEP users should be replaced such that
296
+ // the chain of GEPs they are a part of is collapsed to a single GEP into a
297
+ // flattened array.
298
+ bool ReplaceThisGEP = GEP.users ().empty ();
299
+ for (Value *User : GEP.users ())
300
+ if (!isa<GetElementPtrInst>(User))
301
+ ReplaceThisGEP = true ;
302
+
303
+ if (ReplaceThisGEP) {
304
+ // GEP.collectOffset returns the offset in bytes. So we need to divide its
305
+ // offsets by the size in bytes of the element type
306
+ unsigned BytesPerElem = Info.RootFlattenedArrayType ->getArrayElementType ()
307
+ ->getPrimitiveSizeInBits () /
308
+ 8 ;
309
+
310
+ // Compute the 32-bit index for this flattened GEP from the constant and
311
+ // variable byte offsets in the GEPInfo
312
+ IRBuilder<> Builder (&GEP);
313
+ Value *ZeroIndex = Builder.getInt32 (0 );
314
+ uint64_t ConstantOffset =
315
+ Info.ConstantOffset .udiv (BytesPerElem).getZExtValue ();
316
+ assert (ConstantOffset < UINT32_MAX &&
317
+ " Constant byte offset for flat GEP index must fit within 32 bits" );
318
+ Value *FlattenedIndex = Builder.getInt32 (ConstantOffset);
319
+ for (auto [VarIndex, Multiplier] : Info.VariableOffsets ) {
320
+ uint64_t Mul = Multiplier.udiv (BytesPerElem).getZExtValue ();
321
+ assert (Mul < UINT32_MAX &&
322
+ " Multiplier for flat GEP index must fit within 32 bits" );
323
+ assert (VarIndex->getType ()->isIntegerTy (32 ) &&
324
+ " Expected i32-typed GEP indices" );
325
+ Value *ConstIntMul = Builder.getInt32 (Mul);
326
+ Value *MulVarIndex = Builder.CreateMul (VarIndex, ConstIntMul);
327
+ FlattenedIndex = Builder.CreateAdd (FlattenedIndex, MulVarIndex);
328
+ }
333
329
334
- unsigned GEPChainUseCount = 0 ;
335
- recursivelyCollectGEPs (GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
336
-
337
- // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
338
- // Here recursion is used to get the length of the GEP chain.
339
- // Handle zero uses here because there won't be an update via
340
- // a child in the chain later.
341
- if (GEPChainUseCount == 0 ) {
342
- SmallVector<Value *> Indices;
343
- SmallVector<uint64_t > Dims;
344
- bool AllIndicesAreConstInt = true ;
345
-
346
- // Collect indices and dimensions from the GEP
347
- collectIndicesAndDimsFromGEP (GEP, Indices, Dims, AllIndicesAreConstInt);
348
- GEPData GEPInfo{std::move (FlattenedArrayType), PtrOperand,
349
- std::move (Indices), std::move (Dims), AllIndicesAreConstInt};
350
- return visitGetElementPtrInstInGEPChainBase (GEPInfo, GEP);
330
+ // Construct a new GEP for the flattened array to replace the current GEP
331
+ Value *NewGEP = Builder.CreateGEP (
332
+ Info.RootFlattenedArrayType , Info.RootPointerOperand ,
333
+ {ZeroIndex, FlattenedIndex}, GEP.getName (), GEP.getNoWrapFlags ());
334
+
335
+ // Replace the current GEP with the new GEP. Store GEPInfo into the map
336
+ // for later use in case this GEP was not the end of the chain
337
+ GEPChainInfoMap.insert ({cast<GEPOperator>(NewGEP), std::move (Info)});
338
+ GEP.replaceAllUsesWith (NewGEP);
339
+ GEP.eraseFromParent ();
340
+ return true ;
351
341
}
352
342
343
+ // This GEP is potentially dead at the end of the pass since it may not have
344
+ // any users anymore after GEP chains have been collapsed.
345
+ GEPChainInfoMap.insert ({cast<GEPOperator>(&GEP), std::move (Info)});
353
346
PotentiallyDeadInstrs.emplace_back (&GEP);
354
347
return false ;
355
348
}
@@ -456,9 +449,9 @@ flattenGlobalArrays(Module &M,
456
449
457
450
static bool flattenArrays (Module &M) {
458
451
bool MadeChange = false ;
459
- DXILFlattenArraysVisitor Impl;
460
452
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
461
453
flattenGlobalArrays (M, GlobalMap);
454
+ DXILFlattenArraysVisitor Impl (GlobalMap);
462
455
for (auto &F : make_early_inc_range (M.functions ())) {
463
456
if (F.isDeclaration ())
464
457
continue ;
0 commit comments