Skip to content

[DirectX] Simplify and correct the flattening of GEPs in DXILFlattenArrays #146173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

Icohedron
Copy link
Contributor

@Icohedron Icohedron commented Jun 27, 2025

In tandem with #146800, this PR fixes #145370

This PR simplifies the logic for collapsing GEP chains and replacing GEPs to multidimensional arrays with GEPs to flattened arrays. This implementation avoids unnecessary recursion and more robustly computes the index to the flattened array by using the GEPOperator's collectOffset function, which has the side effect of allowing "i8 GEPs" and other types of GEPs to be handled naturally in the flattening / collapsing of GEP chains.

Furthermore, a handful of LLVM DirectX CodeGen tests have been edited to fix incorrect GEP offsets, mismatched types (e.g., loading i32s from a an array of floats), and typos.

This simplification also fixes instances of incorrect flat index
computations
A few tests had incorrect GEP indices or types.
This commit fixes these GEPs and array types.
@llvmbot
Copy link
Member

llvmbot commented Jun 27, 2025

@llvm/pr-subscribers-backend-directx

Author: Deric C. (Icohedron)

Changes

In tandem with #145780 or the changing of pass order to resolve #145924, this PR fixes #145370

This PR simplifies the logic for collapsing GEP chains and replacing GEPs to multidimensional arrays with GEPs flattened arrays, avoiding unnecessary recursion and more robustly computing index to the flattened array by using the GEPOperator's collectOffset function.

Furthermore, a handful of LLVM DirectX CodeGen have been edited to fix incorrect GEP offsets, mismatched types (i.e., loading i32s from a an array of floats), and typos.


Patch is 29.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146173.diff

6 Files Affected:

  • (modified) llvm/lib/Target/DirectX/DXILFlattenArrays.cpp (+132-139)
  • (modified) llvm/test/CodeGen/DirectX/flatten-array.ll (+2-2)
  • (modified) llvm/test/CodeGen/DirectX/flatten-bug-117273.ll (+4-4)
  • (modified) llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll (+27-33)
  • (modified) llvm/test/CodeGen/DirectX/scalar-bug-117273.ll (+2-2)
  • (modified) llvm/test/CodeGen/DirectX/scalarize-alloca.ll (+3-1)
diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
index 0b7cf2f970172..913a8dcb917f4 100644
--- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
+++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp
@@ -40,18 +40,19 @@ class DXILFlattenArraysLegacy : public ModulePass {
   static char ID; // Pass identification.
 };
 
-struct GEPData {
-  ArrayType *ParentArrayType;
-  Value *ParentOperand;
-  SmallVector<Value *> Indices;
-  SmallVector<uint64_t> Dims;
-  bool AllIndicesAreConstInt;
+struct GEPInfo {
+  ArrayType *RootFlattenedArrayType;
+  Value *RootPointerOperand;
+  SmallMapVector<Value *, APInt, 4> VariableOffsets;
+  APInt ConstantOffset;
 };
 
 class DXILFlattenArraysVisitor
     : public InstVisitor<DXILFlattenArraysVisitor, bool> {
 public:
-  DXILFlattenArraysVisitor() {}
+  DXILFlattenArraysVisitor(
+      DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
+      : GlobalMap(GlobalMap) {}
   bool visit(Function &F);
   // InstVisitor methods.  They return true if the instruction was scalarized,
   // false if nothing changed.
@@ -78,7 +79,8 @@ class DXILFlattenArraysVisitor
 
 private:
   SmallVector<WeakTrackingVH> PotentiallyDeadInstrs;
-  DenseMap<GetElementPtrInst *, GEPData> GEPChainMap;
+  DenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap;
+  DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap;
   bool finish();
   ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices,
                                       ArrayRef<uint64_t> Dims,
@@ -86,23 +88,6 @@ class DXILFlattenArraysVisitor
   Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices,
                                       ArrayRef<uint64_t> Dims,
                                       IRBuilder<> &Builder);
-
-  // Helper function to collect indices and dimensions from a GEP instruction
-  void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP,
-                                    SmallVectorImpl<Value *> &Indices,
-                                    SmallVectorImpl<uint64_t> &Dims,
-                                    bool &AllIndicesAreConstInt);
-
-  void
-  recursivelyCollectGEPs(GetElementPtrInst &CurrGEP,
-                         ArrayType *FlattenedArrayType, Value *PtrOperand,
-                         unsigned &GEPChainUseCount,
-                         SmallVector<Value *> Indices = SmallVector<Value *>(),
-                         SmallVector<uint64_t> Dims = SmallVector<uint64_t>(),
-                         bool AllIndicesAreConstInt = true);
-  bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP);
-  bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo,
-                                            GetElementPtrInst &GEP);
 };
 } // namespace
 
@@ -225,131 +210,139 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
   return true;
 }
 
-void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP(
-    GetElementPtrInst &GEP, SmallVectorImpl<Value *> &Indices,
-    SmallVectorImpl<uint64_t> &Dims, bool &AllIndicesAreConstInt) {
-
-  Type *CurrentType = GEP.getSourceElementType();
+bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+  // Do not visit GEPs more than once
+  if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP)))
+    return false;
 
-  // Note index 0 is the ptr index.
-  for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) {
-    Indices.push_back(Index);
-    AllIndicesAreConstInt &= isa<ConstantInt>(Index);
+  // Construct GEPInfo for this GEP
+  GEPInfo Info;
 
-    if (auto *ArrayTy = dyn_cast<ArrayType>(CurrentType)) {
-      Dims.push_back(ArrayTy->getNumElements());
-      CurrentType = ArrayTy->getElementType();
-    } else {
-      assert(false && "Expected array type in GEP chain");
-    }
-  }
-}
+  // Obtain the variable and constant byte offsets computed by this GEP
+  const DataLayout &DL = GEP.getDataLayout();
+  unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType());
+  Info.ConstantOffset = {BitWidth, 0};
+  bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets,
+                                   Info.ConstantOffset);
+  (void)Success;
+  assert(Success && "Failed to collect offsets for GEP");
 
-void DXILFlattenArraysVisitor::recursivelyCollectGEPs(
-    GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType,
-    Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector<Value *> Indices,
-    SmallVector<uint64_t> Dims, bool AllIndicesAreConstInt) {
-  // Check if this GEP is already in the map to avoid circular references
-  if (GEPChainMap.count(&CurrGEP) > 0)
-    return;
+  Value *PtrOperand = GEP.getPointerOperand();
 
-  // Collect indices and dimensions from the current GEP
-  collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt);
-  bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType());
-  if (!IsMultiDimArr) {
-    assert(GEPChainUseCount < FlattenedArrayType->getNumElements());
-    GEPChainMap.insert(
-        {&CurrGEP,
-         {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
-          std::move(Dims), AllIndicesAreConstInt}});
-    return;
-  }
-  bool GepUses = false;
-  for (auto *User : CurrGEP.users()) {
-    if (GetElementPtrInst *NestedGEP = dyn_cast<GetElementPtrInst>(User)) {
-      recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand,
-                             ++GEPChainUseCount, Indices, Dims,
-                             AllIndicesAreConstInt);
-      GepUses = true;
+  // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that
+  // it can be visited
+  if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand))
+    if (PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
+      GetElementPtrInst *OldGEPI =
+          cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction());
+      OldGEPI->insertBefore(GEP.getIterator());
+
+      IRBuilder<> Builder(&GEP);
+      SmallVector<Value *> Indices(GEP.idx_begin(), GEP.idx_end());
+      Value *NewGEP =
+          Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices,
+                            GEP.getName(), GEP.getNoWrapFlags());
+      assert(isa<GetElementPtrInst>(NewGEP) &&
+             "Expected newly-created GEP to not be a ConstantExpr");
+      GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP);
+
+      GEP.replaceAllUsesWith(NewGEPI);
+      GEP.eraseFromParent();
+      visitGetElementPtrInst(*OldGEPI);
+      visitGetElementPtrInst(*NewGEPI);
+      return true;
     }
-  }
-  // This case is just incase the gep chain doesn't end with a 1d array.
-  if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) {
-    GEPChainMap.insert(
-        {&CurrGEP,
-         {std::move(FlattenedArrayType), PtrOperand, std::move(Indices),
-          std::move(Dims), AllIndicesAreConstInt}});
-  }
-}
 
-bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain(
-    GetElementPtrInst &GEP) {
-  GEPData GEPInfo = GEPChainMap.at(&GEP);
-  return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
-}
-bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase(
-    GEPData &GEPInfo, GetElementPtrInst &GEP) {
-  IRBuilder<> Builder(&GEP);
-  Value *FlatIndex;
-  if (GEPInfo.AllIndicesAreConstInt)
-    FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
-  else
-    FlatIndex =
-        genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder);
-
-  ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType;
-
-  // Don't append '.flat' to an empty string. If the SSA name isn't available
-  // it could conflict with the ParentOperand's name.
-  std::string FlatName = GEP.hasName() ? GEP.getName().str() + ".flat" : "";
-
-  Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParentOperand,
-                                     {Builder.getInt32(0), FlatIndex}, FlatName,
-                                     GEP.getNoWrapFlags());
-
-  // Note: Old gep will become an invalid instruction after replaceAllUsesWith.
-  // Erase the old GEP in the map before to avoid invalid instructions
-  // and circular references.
-  GEPChainMap.erase(&GEP);
-
-  GEP.replaceAllUsesWith(FlatGEP);
-  GEP.eraseFromParent();
-  return true;
-}
+  // If there is a parent GEP, inherit the root array type and pointer, and
+  // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP
+  // chain and we need to deterine the root array type
+  if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) {
+    assert(GEPChainInfoMap.contains(PtrOpGEP) &&
+           "Expected parent GEP to be visited before this GEP");
+    GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
+    Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
+    Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
+    for (auto &VariableOffset : PGEPInfo.VariableOffsets)
+      Info.VariableOffsets.insert(VariableOffset);
+    Info.ConstantOffset += PGEPInfo.ConstantOffset;
+  } else {
+    Info.RootPointerOperand = PtrOperand;
+
+    // We should try to determine the type of the root from the pointer rather
+    // than the GEP's source element type because this could be a scalar GEP
+    // into a multidimensional array-typed pointer from an Alloca or Global
+    // Variable.
+    Type *RootTy = GEP.getSourceElementType();
+    if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
+      if (!GlobalMap.contains(GlobalVar))
+        return false;
+      GlobalVariable *NewGlobal = GlobalMap[GlobalVar];
+      Info.RootPointerOperand = NewGlobal;
+      RootTy = NewGlobal->getValueType();
+    } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
+      RootTy = Alloca->getAllocatedType();
+    }
+    assert(!isMultiDimensionalArray(RootTy) &&
+           "Expected root array type to be flattened");
 
-bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) {
-  auto It = GEPChainMap.find(&GEP);
-  if (It != GEPChainMap.end())
-    return visitGetElementPtrInstInGEPChain(GEP);
-  if (!isMultiDimensionalArray(GEP.getSourceElementType()))
-    return false;
+    // If the root type is not an array, we don't need to do any flattening
+    if (!isa<ArrayType>(RootTy))
+      return false;
 
-  ArrayType *ArrType = cast<ArrayType>(GEP.getSourceElementType());
-  IRBuilder<> Builder(&GEP);
-  auto [TotalElements, BaseType] = getElementCountAndType(ArrType);
-  ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements);
+    Info.RootFlattenedArrayType = cast<ArrayType>(RootTy);
+  }
 
-  Value *PtrOperand = GEP.getPointerOperand();
+  // GEPs without users or GEPs with non-GEP users should be replaced such that
+  // the chain of GEPs they are a part of are collapsed to a single GEP into a
+  // flattened array.
+  bool ReplaceThisGEP = GEP.users().empty();
+  for (Value *User : GEP.users())
+    if (!isa<GetElementPtrInst>(User))
+      ReplaceThisGEP = true;
+
+  if (ReplaceThisGEP) {
+    // GEP.collectOffset returns the offset in bytes. So we need to divide its
+    // offsets by the size in bytes of the element type
+    unsigned BytesPerElem = Info.RootFlattenedArrayType->getArrayElementType()
+                                ->getPrimitiveSizeInBits() /
+                            8;
+
+    // Compute the 32-bit index for this flattened GEP from the constant and
+    // variable byte offsets in the GEPInfo
+    IRBuilder<> Builder(&GEP);
+    Value *ZeroIndex = Builder.getInt32(0);
+    uint64_t ConstantOffset =
+        Info.ConstantOffset.udiv(BytesPerElem).getZExtValue();
+    assert(ConstantOffset < UINT32_MAX &&
+           "Constant byte offset for flat GEP index must fit within 32 bits");
+    Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
+    for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
+      uint64_t Mul = Multiplier.udiv(BytesPerElem).getZExtValue();
+      assert(Mul < UINT32_MAX &&
+             "Multiplier for flat GEP index must fit within 32 bits");
+      assert(VarIndex->getType()->isIntegerTy(32) &&
+             "Expected i32-typed GEP indices");
+      Value *ConstIntMul = Builder.getInt32(Mul);
+      Value *MulVarIndex = Builder.CreateMul(VarIndex, ConstIntMul);
+      FlattenedIndex = Builder.CreateAdd(FlattenedIndex, MulVarIndex);
+    }
 
-  unsigned GEPChainUseCount = 0;
-  recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount);
-
-  // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0.
-  // Here recursion is used to get the length of the GEP chain.
-  // Handle zero uses here because there won't be an update via
-  // a child in the chain later.
-  if (GEPChainUseCount == 0) {
-    SmallVector<Value *> Indices;
-    SmallVector<uint64_t> Dims;
-    bool AllIndicesAreConstInt = true;
-
-    // Collect indices and dimensions from the GEP
-    collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt);
-    GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand,
-                    std::move(Indices), std::move(Dims), AllIndicesAreConstInt};
-    return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP);
+    // Construct a new GEP for the flattened array to replace the current GEP
+    Value *NewGEP = Builder.CreateGEP(
+        Info.RootFlattenedArrayType, Info.RootPointerOperand,
+        {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags());
+
+    // Replace the current GEP with the new GEP. Store GEPInfo into the map
+    // for later use in case this GEP was not the end of the chain
+    GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)});
+    GEP.replaceAllUsesWith(NewGEP);
+    GEP.eraseFromParent();
+    return true;
   }
 
+  // This GEP is potentially dead at the end of the pass since it may not have
+  // any users anymore after GEP chains have been collapsed.
+  GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)});
   PotentiallyDeadInstrs.emplace_back(&GEP);
   return false;
 }
@@ -456,9 +449,9 @@ flattenGlobalArrays(Module &M,
 
 static bool flattenArrays(Module &M) {
   bool MadeChange = false;
-  DXILFlattenArraysVisitor Impl;
   DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
   flattenGlobalArrays(M, GlobalMap);
+  DXILFlattenArraysVisitor Impl(GlobalMap);
   for (auto &F : make_early_inc_range(M.functions())) {
     if (F.isDeclaration())
       continue;
diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll
index dc8c5f8421bfe..e256146bb74f4 100644
--- a/llvm/test/CodeGen/DirectX/flatten-array.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-array.ll
@@ -159,9 +159,9 @@ define void @global_gep_load_index(i32 %row, i32 %col, i32 %timeIndex) {
 define void @global_incomplete_gep_chain(i32 %row, i32 %col) {
 ; CHECK-LABEL: define void @global_incomplete_gep_chain(
 ; CHECK-SAME: i32 [[ROW:%.*]], i32 [[COL:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[COL]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[COL]], 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i32 [[ROW]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i32 [[ROW]], 12
 ; CHECK-NEXT:    [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]]
 ; CHECK-NEXT:    [[DOTFLAT:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 0, i32 [[TMP4]]
 ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}}
diff --git a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
index c73e5017348d1..930805f0ddc90 100644
--- a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
+++ b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll
@@ -8,16 +8,16 @@
 define internal void @main() {
 ; CHECK-LABEL: define internal void @main() {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [24 x float], ptr @ZerroInitArr.1dim, i32 0, i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 3
 ; CHECK-NEXT:    [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [24 x float], ptr @ZerroInitArr.1dim, i32 0, i32 2
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 6
 ; CHECK-NEXT:    [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16
 ; CHECK-NEXT:    ret void
 ;
 entry:
-  %0 = getelementptr [8 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 1
+  %0 = getelementptr [2 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 1
   %.i0 = load float, ptr %0, align 16
-  %1 = getelementptr [8 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 2
+  %1 = getelementptr [2 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 2
   %.i03 = load float, ptr %1, align 16
   ret void
 }
diff --git a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
index d5797f6b51348..78550adbe424a 100644
--- a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
+++ b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll
@@ -3,43 +3,35 @@
 
 ; Make sure we can load groupshared, static vectors and arrays of vectors
 
-@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16
+@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <4 x i32>] zeroinitializer, align 16
 @"vecData" = external addrspace(3) global <4 x i32>, align 4
 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>, <4 x i32> <i32 9, i32 10, i32 11, i32 12>], align 4
-@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16
+@"groupshared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [3 x <4 x i32>]] zeroinitializer, align 16
 
-; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [6 x float] zeroinitializer, align 16
+; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [8 x i32] zeroinitializer, align 16
 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4
 ; CHECK: @staticArrayOfVecData.scalarized.1dim = internal global [12 x i32] [i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12], align 4
-; CHECK: @groushared2dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [36 x i32] zeroinitializer, align 16
+; CHECK: @groupshared2dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [36 x i32] zeroinitializer, align 16
 
 ; CHECK-NOT: @arrayofVecData
 ; CHECK-NOT: @arrayofVecData.scalarized
 ; CHECK-NOT: @vecData
 ; CHECK-NOT: @staticArrayOfVecData
 ; CHECK-NOT: @staticArrayOfVecData.scalarized
-; CHECK-NOT: @groushared2dArrayofVectors
-; CHECK-NOT: @groushared2dArrayofVectors.scalarized
+; CHECK-NOT: @groupshared2dArrayofVectors
+; CHECK-NOT: @groupshared2dArrayofVectors.scalarized
 
 define <4 x i32> @load_array_vec_test() #0 {
 ; CHECK-LABEL: define <4 x i32> @load_array_vec_test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast ptr addrspace(3) @arrayofVecData.scalarized.1dim to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr addrspace(3) [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP4:%.*]] = load i32, ptr addrspace(3) [[TMP3]], align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 2) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 3) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1) to ptr addrspace(3)
-; CHECK-NEXT:    [[TMP10:%.*]] = load i32, ptr addrspace(3) [[TMP9]], align 4
-; CHECK-NEXT:    [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32...
[truncated]

@Icohedron Icohedron force-pushed the dxil-flatten-array-geps branch from 17778df to 1a09803 Compare June 28, 2025 01:48
It is possible for GEPOperator* to overlap between function visits, so
we have to clear the map or else there could be stale data leftover that
causes GEPs to be incorrectly generated or skipped over.
@Icohedron Icohedron force-pushed the dxil-flatten-array-geps branch from 2c5db9a to 3280609 Compare June 30, 2025 21:52
};

class DXILFlattenArraysVisitor
: public InstVisitor<DXILFlattenArraysVisitor, bool> {
public:
DXILFlattenArraysVisitor() {}
DXILFlattenArraysVisitor(
DenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap)
Copy link
Member

@farzonl farzonl Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't clear why this was neccessary. Why do we need a reference of GlobalMap in the DXILFlattenArraysVisitor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's necessary because otherwise the DXILFlattenArraysVisitor won't have visibility of the GlobalMap which comes from a stack-allocated variable here

static bool flattenArrays(Module &M) {
bool MadeChange = false;
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
flattenGlobalArrays(M, GlobalMap);
DXILFlattenArraysVisitor Impl(GlobalMap);

which is necessary to determine the type of the GEP

// We should try to determine the type of the root from the pointer rather
// than the GEP's source element type because this could be a scalar GEP
// into an array-typed pointer from an Alloca or Global Variable.
Type *RootTy = GEP.getSourceElementType();
if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) {
if (GlobalMap.contains(GlobalVar))
GlobalVar = GlobalMap[GlobalVar];
Info.RootPointerOperand = GlobalVar;
RootTy = GlobalVar->getValueType();
} else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
RootTy = Alloca->getAllocatedType();
}

};
} // namespace

bool DXILFlattenArraysVisitor::finish() {
GEPChainInfoMap.clear();
Copy link
Member

@farzonl farzonl Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine. not sure if its necessary though. The GEPChainInfoMap will get recreated on every construction of th DXILFlattenArraysVisitor which should be the same as clearing it. That said the idea makes sense, there are going to be many invalid gep chains the longer we don't clear because we are flattening as we walk the chains. So I think it makes sense, but we should probably be clearing the GEPChainInfoMap after each function, while I think this clears it after each module.

Copy link
Contributor Author

@Icohedron Icohedron Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine. not sure if its necessary though.

This is a change I added later because I found that tests with multiple functions had GEPs that were incorrectly flattened because GEPOperator* from previous functions happened to collide with GEPOperator* in the GEPChainInfoMap while processing the current function.

The GEPChainInfoMap will get recreated on every construction of th DXILFlattenArraysVisitor which should be the same as clearing it.

The DXILFlattenArraysVisitor is only created once and it's for the module.

static bool flattenArrays(Module &M) {
bool MadeChange = false;
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
flattenGlobalArrays(M, GlobalMap);
DXILFlattenArraysVisitor Impl(GlobalMap);
for (auto &F : make_early_inc_range(M.functions())) {
if (F.isDeclaration())
continue;
MadeChange |= Impl.visit(F);
}
for (auto &[Old, New] : GlobalMap) {
Old->replaceAllUsesWith(New);
Old->eraseFromParent();
MadeChange = true;
}
return MadeChange;
}
PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) {
bool MadeChanges = flattenArrays(M);
if (!MadeChanges)
return PreservedAnalyses::all();
PreservedAnalyses PA;
return PA;
}
bool DXILFlattenArraysLegacy::runOnModule(Module &M) {
return flattenArrays(M);
}

we should probably be clearing the GEPChainInfoMap after each function

The finish() function is called at the end of DXILFlattenArraysVisitor::visit(Function &F), so the GEPChainInfoMap is indeed cleared after each function.

bool DXILFlattenArraysVisitor::visit(Function &F) {
bool MadeChange = false;
ReversePostOrderTraversal<Function *> RPOT(&F);
for (BasicBlock *BB : make_early_inc_range(RPOT)) {
for (Instruction &I : make_early_inc_range(*BB))
MadeChange |= InstVisitor::visit(I);
}
finish();
return MadeChange;
}


GEP.replaceAllUsesWith(NewGEPI);
GEP.eraseFromParent();
visitGetElementPtrInst(*OldGEPI);
Copy link
Member

@farzonl farzonl Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some concerns about this if this can be recursive in the constexpr case more than one level deep. Your calls to visitGetElementPtrInst and you are doing eraseFromParent before you do the recursions. We usualy can only get away with doing one instruction erasure via the visitor pattern because of our use of make_early_inc_range To be able to erase more you typically need to start by erasing the leaf geps in the gep chain.

Copy link
Contributor Author

@Icohedron Icohedron Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is fine because there can only be one ConstantExpr GEP in a chain and it has to be the root/start of the chain.
A ConstantExpr GEP can not take virtual registers as any of its operands, which would imply that the ptr operand has to be a global variable and the indices must be constant ints.

// GEPs without users or GEPs with non-GEP users should be replaced such that
// the chain of GEPs they are a part of are collapsed to a single GEP into a
// flattened array.
bool ReplaceThisGEP = GEP.users().empty();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems unnecessary, why can't you just do

Suggested change
bool ReplaceThisGEP = GEP.users().empty();
bool ReplaceThisGEP = false;

Copy link
Contributor Author

@Icohedron Icohedron Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was to preserve the behavior of the previous implementation in the llvm/test/CodeGen/DirectX tests. There are many tests that create GEPs that are not used. So I replace unused GEPs to keep the tests working.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can set bool ReplaceThisGEP = false; and then update the tests so that all the GEPs are used.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose we can leave this as is for now.

"Constant byte offset for flat GEP index must fit within 32 bits");
Value *FlattenedIndex = Builder.getInt32(ConstantOffset);
for (auto [VarIndex, Multiplier] : Info.VariableOffsets) {
assert(Multiplier.getActiveBits() <= 32 &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't there a 32 bit constant we can use?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, Multiplier.getActiveBits() <= 32 is already clear enough that it's testing the number held fits within an unsigned 32 bit integer.

GEP.replaceAllUsesWith(NewGEPI);
GEP.eraseFromParent();
visitGetElementPtrInst(*OldGEPI);
visitGetElementPtrInst(*NewGEPI);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to visit the new Gep?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GEPChainInfoMap entry for the new GEP needs to be filled in case the new GEP is not the leaf/end of a GEP chain.
The new GEP may also be the leaf/end of a GEP chain itself and need to be replaced with a fattened GEP.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it just needs to be for book keeping purposes can we take the book keeping part out and make it a helper that we can call here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not just for book keeping; the new GEP can be replaced/flattened in that visit.

Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: No status
3 participants