From b08bb11fd23d314360d8407fef2bc34b77d58e7c Mon Sep 17 00:00:00 2001 From: Icohedron Date: Thu, 26 Jun 2025 23:16:56 +0000 Subject: [PATCH 01/11] Simplify flattening of GEP chains This simplification also fixes instances of incorrect flat index computations --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 271 +++++++++--------- 1 file changed, 132 insertions(+), 139 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 0b7cf2f970172..913a8dcb917f4 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -40,18 +40,19 @@ class DXILFlattenArraysLegacy : public ModulePass { static char ID; // Pass identification. }; -struct GEPData { - ArrayType *ParentArrayType; - Value *ParentOperand; - SmallVector Indices; - SmallVector Dims; - bool AllIndicesAreConstInt; +struct GEPInfo { + ArrayType *RootFlattenedArrayType; + Value *RootPointerOperand; + SmallMapVector VariableOffsets; + APInt ConstantOffset; }; class DXILFlattenArraysVisitor : public InstVisitor { public: - DXILFlattenArraysVisitor() {} + DXILFlattenArraysVisitor( + DenseMap &GlobalMap) + : GlobalMap(GlobalMap) {} bool visit(Function &F); // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. @@ -78,7 +79,8 @@ class DXILFlattenArraysVisitor private: SmallVector PotentiallyDeadInstrs; - DenseMap GEPChainMap; + DenseMap GEPChainInfoMap; + DenseMap &GlobalMap; bool finish(); ConstantInt *genConstFlattenIndices(ArrayRef Indices, ArrayRef Dims, @@ -86,23 +88,6 @@ class DXILFlattenArraysVisitor Value *genInstructionFlattenIndices(ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder); - - // Helper function to collect indices and dimensions from a GEP instruction - void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP, - SmallVectorImpl &Indices, - SmallVectorImpl &Dims, - bool &AllIndicesAreConstInt); - - void - recursivelyCollectGEPs(GetElementPtrInst &CurrGEP, - ArrayType *FlattenedArrayType, Value *PtrOperand, - unsigned &GEPChainUseCount, - SmallVector Indices = SmallVector(), - SmallVector Dims = SmallVector(), - bool AllIndicesAreConstInt = true); - bool visitGetElementPtrInstInGEPChain(GetElementPtrInst &GEP); - bool visitGetElementPtrInstInGEPChainBase(GEPData &GEPInfo, - GetElementPtrInst &GEP); }; } // namespace @@ -225,131 +210,139 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) { return true; } -void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP( - GetElementPtrInst &GEP, SmallVectorImpl &Indices, - SmallVectorImpl &Dims, bool &AllIndicesAreConstInt) { - - Type *CurrentType = GEP.getSourceElementType(); +bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { + // Do not visit GEPs more than once + if (GEPChainInfoMap.contains(cast(&GEP))) + return false; - // Note index 0 is the ptr index. - for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) { - Indices.push_back(Index); - AllIndicesAreConstInt &= isa(Index); + // Construct GEPInfo for this GEP + GEPInfo Info; - if (auto *ArrayTy = dyn_cast(CurrentType)) { - Dims.push_back(ArrayTy->getNumElements()); - CurrentType = ArrayTy->getElementType(); - } else { - assert(false && "Expected array type in GEP chain"); - } - } -} + // Obtain the variable and constant byte offsets computed by this GEP + const DataLayout &DL = GEP.getDataLayout(); + unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType()); + Info.ConstantOffset = {BitWidth, 0}; + bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets, + Info.ConstantOffset); + (void)Success; + assert(Success && "Failed to collect offsets for GEP"); -void DXILFlattenArraysVisitor::recursivelyCollectGEPs( - GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, - Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector Indices, - SmallVector Dims, bool AllIndicesAreConstInt) { - // Check if this GEP is already in the map to avoid circular references - if (GEPChainMap.count(&CurrGEP) > 0) - return; + Value *PtrOperand = GEP.getPointerOperand(); - // Collect indices and dimensions from the current GEP - collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt); - bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType()); - if (!IsMultiDimArr) { - assert(GEPChainUseCount < FlattenedArrayType->getNumElements()); - GEPChainMap.insert( - {&CurrGEP, - {std::move(FlattenedArrayType), PtrOperand, std::move(Indices), - std::move(Dims), AllIndicesAreConstInt}}); - return; - } - bool GepUses = false; - for (auto *User : CurrGEP.users()) { - if (GetElementPtrInst *NestedGEP = dyn_cast(User)) { - recursivelyCollectGEPs(*NestedGEP, FlattenedArrayType, PtrOperand, - ++GEPChainUseCount, Indices, Dims, - AllIndicesAreConstInt); - GepUses = true; + // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that + // it can be visited + if (auto *PtrOpGEPCE = dyn_cast(PtrOperand)) + if (PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) { + GetElementPtrInst *OldGEPI = + cast(PtrOpGEPCE->getAsInstruction()); + OldGEPI->insertBefore(GEP.getIterator()); + + IRBuilder<> Builder(&GEP); + SmallVector Indices(GEP.idx_begin(), GEP.idx_end()); + Value *NewGEP = + Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices, + GEP.getName(), GEP.getNoWrapFlags()); + assert(isa(NewGEP) && + "Expected newly-created GEP to not be a ConstantExpr"); + GetElementPtrInst *NewGEPI = cast(NewGEP); + + GEP.replaceAllUsesWith(NewGEPI); + GEP.eraseFromParent(); + visitGetElementPtrInst(*OldGEPI); + visitGetElementPtrInst(*NewGEPI); + return true; } - } - // This case is just incase the gep chain doesn't end with a 1d array. - if (IsMultiDimArr && GEPChainUseCount > 0 && !GepUses) { - GEPChainMap.insert( - {&CurrGEP, - {std::move(FlattenedArrayType), PtrOperand, std::move(Indices), - std::move(Dims), AllIndicesAreConstInt}}); - } -} -bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChain( - GetElementPtrInst &GEP) { - GEPData GEPInfo = GEPChainMap.at(&GEP); - return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); -} -bool DXILFlattenArraysVisitor::visitGetElementPtrInstInGEPChainBase( - GEPData &GEPInfo, GetElementPtrInst &GEP) { - IRBuilder<> Builder(&GEP); - Value *FlatIndex; - if (GEPInfo.AllIndicesAreConstInt) - FlatIndex = genConstFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); - else - FlatIndex = - genInstructionFlattenIndices(GEPInfo.Indices, GEPInfo.Dims, Builder); - - ArrayType *FlattenedArrayType = GEPInfo.ParentArrayType; - - // Don't append '.flat' to an empty string. If the SSA name isn't available - // it could conflict with the ParentOperand's name. - std::string FlatName = GEP.hasName() ? GEP.getName().str() + ".flat" : ""; - - Value *FlatGEP = Builder.CreateGEP(FlattenedArrayType, GEPInfo.ParentOperand, - {Builder.getInt32(0), FlatIndex}, FlatName, - GEP.getNoWrapFlags()); - - // Note: Old gep will become an invalid instruction after replaceAllUsesWith. - // Erase the old GEP in the map before to avoid invalid instructions - // and circular references. - GEPChainMap.erase(&GEP); - - GEP.replaceAllUsesWith(FlatGEP); - GEP.eraseFromParent(); - return true; -} + // If there is a parent GEP, inherit the root array type and pointer, and + // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP + // chain and we need to deterine the root array type + if (auto *PtrOpGEP = dyn_cast(PtrOperand)) { + assert(GEPChainInfoMap.contains(PtrOpGEP) && + "Expected parent GEP to be visited before this GEP"); + GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP]; + Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType; + Info.RootPointerOperand = PGEPInfo.RootPointerOperand; + for (auto &VariableOffset : PGEPInfo.VariableOffsets) + Info.VariableOffsets.insert(VariableOffset); + Info.ConstantOffset += PGEPInfo.ConstantOffset; + } else { + Info.RootPointerOperand = PtrOperand; + + // We should try to determine the type of the root from the pointer rather + // than the GEP's source element type because this could be a scalar GEP + // into a multidimensional array-typed pointer from an Alloca or Global + // Variable. + Type *RootTy = GEP.getSourceElementType(); + if (auto *GlobalVar = dyn_cast(PtrOperand)) { + if (!GlobalMap.contains(GlobalVar)) + return false; + GlobalVariable *NewGlobal = GlobalMap[GlobalVar]; + Info.RootPointerOperand = NewGlobal; + RootTy = NewGlobal->getValueType(); + } else if (auto *Alloca = dyn_cast(PtrOperand)) { + RootTy = Alloca->getAllocatedType(); + } + assert(!isMultiDimensionalArray(RootTy) && + "Expected root array type to be flattened"); -bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { - auto It = GEPChainMap.find(&GEP); - if (It != GEPChainMap.end()) - return visitGetElementPtrInstInGEPChain(GEP); - if (!isMultiDimensionalArray(GEP.getSourceElementType())) - return false; + // If the root type is not an array, we don't need to do any flattening + if (!isa(RootTy)) + return false; - ArrayType *ArrType = cast(GEP.getSourceElementType()); - IRBuilder<> Builder(&GEP); - auto [TotalElements, BaseType] = getElementCountAndType(ArrType); - ArrayType *FlattenedArrayType = ArrayType::get(BaseType, TotalElements); + Info.RootFlattenedArrayType = cast(RootTy); + } - Value *PtrOperand = GEP.getPointerOperand(); + // GEPs without users or GEPs with non-GEP users should be replaced such that + // the chain of GEPs they are a part of are collapsed to a single GEP into a + // flattened array. + bool ReplaceThisGEP = GEP.users().empty(); + for (Value *User : GEP.users()) + if (!isa(User)) + ReplaceThisGEP = true; + + if (ReplaceThisGEP) { + // GEP.collectOffset returns the offset in bytes. So we need to divide its + // offsets by the size in bytes of the element type + unsigned BytesPerElem = Info.RootFlattenedArrayType->getArrayElementType() + ->getPrimitiveSizeInBits() / + 8; + + // Compute the 32-bit index for this flattened GEP from the constant and + // variable byte offsets in the GEPInfo + IRBuilder<> Builder(&GEP); + Value *ZeroIndex = Builder.getInt32(0); + uint64_t ConstantOffset = + Info.ConstantOffset.udiv(BytesPerElem).getZExtValue(); + assert(ConstantOffset < UINT32_MAX && + "Constant byte offset for flat GEP index must fit within 32 bits"); + Value *FlattenedIndex = Builder.getInt32(ConstantOffset); + for (auto [VarIndex, Multiplier] : Info.VariableOffsets) { + uint64_t Mul = Multiplier.udiv(BytesPerElem).getZExtValue(); + assert(Mul < UINT32_MAX && + "Multiplier for flat GEP index must fit within 32 bits"); + assert(VarIndex->getType()->isIntegerTy(32) && + "Expected i32-typed GEP indices"); + Value *ConstIntMul = Builder.getInt32(Mul); + Value *MulVarIndex = Builder.CreateMul(VarIndex, ConstIntMul); + FlattenedIndex = Builder.CreateAdd(FlattenedIndex, MulVarIndex); + } - unsigned GEPChainUseCount = 0; - recursivelyCollectGEPs(GEP, FlattenedArrayType, PtrOperand, GEPChainUseCount); - - // NOTE: hasNUses(0) is not the same as GEPChainUseCount == 0. - // Here recursion is used to get the length of the GEP chain. - // Handle zero uses here because there won't be an update via - // a child in the chain later. - if (GEPChainUseCount == 0) { - SmallVector Indices; - SmallVector Dims; - bool AllIndicesAreConstInt = true; - - // Collect indices and dimensions from the GEP - collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt); - GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand, - std::move(Indices), std::move(Dims), AllIndicesAreConstInt}; - return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); + // Construct a new GEP for the flattened array to replace the current GEP + Value *NewGEP = Builder.CreateGEP( + Info.RootFlattenedArrayType, Info.RootPointerOperand, + {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags()); + + // Replace the current GEP with the new GEP. Store GEPInfo into the map + // for later use in case this GEP was not the end of the chain + GEPChainInfoMap.insert({cast(NewGEP), std::move(Info)}); + GEP.replaceAllUsesWith(NewGEP); + GEP.eraseFromParent(); + return true; } + // This GEP is potentially dead at the end of the pass since it may not have + // any users anymore after GEP chains have been collapsed. + GEPChainInfoMap.insert({cast(&GEP), std::move(Info)}); PotentiallyDeadInstrs.emplace_back(&GEP); return false; } @@ -456,9 +449,9 @@ flattenGlobalArrays(Module &M, static bool flattenArrays(Module &M) { bool MadeChange = false; - DXILFlattenArraysVisitor Impl; DenseMap GlobalMap; flattenGlobalArrays(M, GlobalMap); + DXILFlattenArraysVisitor Impl(GlobalMap); for (auto &F : make_early_inc_range(M.functions())) { if (F.isDeclaration()) continue; From e7b9d2131a327dd352664b0574dac7ddf2d7e6f9 Mon Sep 17 00:00:00 2001 From: Icohedron Date: Fri, 27 Jun 2025 21:23:34 +0000 Subject: [PATCH 02/11] Fix tests with incorrect GEPs A few tests had incorrect GEP indices or types. This commit fixes these GEPs and array types. --- llvm/test/CodeGen/DirectX/flatten-array.ll | 4 +- .../CodeGen/DirectX/flatten-bug-117273.ll | 8 +-- .../DirectX/llc-vector-load-scalarize.ll | 60 +++++++++---------- .../test/CodeGen/DirectX/scalar-bug-117273.ll | 4 +- llvm/test/CodeGen/DirectX/scalarize-alloca.ll | 4 +- 5 files changed, 38 insertions(+), 42 deletions(-) diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll index dc8c5f8421bfe..e256146bb74f4 100644 --- a/llvm/test/CodeGen/DirectX/flatten-array.ll +++ b/llvm/test/CodeGen/DirectX/flatten-array.ll @@ -159,9 +159,9 @@ define void @global_gep_load_index(i32 %row, i32 %col, i32 %timeIndex) { define void @global_incomplete_gep_chain(i32 %row, i32 %col) { ; CHECK-LABEL: define void @global_incomplete_gep_chain( ; CHECK-SAME: i32 [[ROW:%.*]], i32 [[COL:%.*]]) { -; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[COL]], 1 +; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[COL]], 4 ; CHECK-NEXT: [[TMP2:%.*]] = add i32 0, [[TMP1]] -; CHECK-NEXT: [[TMP3:%.*]] = mul i32 [[ROW]], 3 +; CHECK-NEXT: [[TMP3:%.*]] = mul i32 [[ROW]], 12 ; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]] ; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 0, i32 [[TMP4]] ; CHECK-NOT: getelementptr inbounds [2 x [3 x [4 x i32]]]{{.*}} diff --git a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll index c73e5017348d1..930805f0ddc90 100644 --- a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll +++ b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll @@ -8,16 +8,16 @@ define internal void @main() { ; CHECK-LABEL: define internal void @main() { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [24 x float], ptr @ZerroInitArr.1dim, i32 0, i32 1 +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 3 ; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16 -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [24 x float], ptr @ZerroInitArr.1dim, i32 0, i32 2 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 6 ; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16 ; CHECK-NEXT: ret void ; entry: - %0 = getelementptr [8 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 1 + %0 = getelementptr [2 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 1 %.i0 = load float, ptr %0, align 16 - %1 = getelementptr [8 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 2 + %1 = getelementptr [2 x [3 x float]], ptr @ZerroInitArr, i32 0, i32 2 %.i03 = load float, ptr %1, align 16 ret void } diff --git a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll index d5797f6b51348..78550adbe424a 100644 --- a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll +++ b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll @@ -3,43 +3,35 @@ ; Make sure we can load groupshared, static vectors and arrays of vectors -@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <3 x float>] zeroinitializer, align 16 +@"arrayofVecData" = local_unnamed_addr addrspace(3) global [2 x <4 x i32>] zeroinitializer, align 16 @"vecData" = external addrspace(3) global <4 x i32>, align 4 @staticArrayOfVecData = internal global [3 x <4 x i32>] [<4 x i32> , <4 x i32> , <4 x i32> ], align 4 -@"groushared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [ 3 x <4 x i32>]] zeroinitializer, align 16 +@"groupshared2dArrayofVectors" = local_unnamed_addr addrspace(3) global [3 x [3 x <4 x i32>]] zeroinitializer, align 16 -; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [6 x float] zeroinitializer, align 16 +; CHECK: @arrayofVecData.scalarized.1dim = local_unnamed_addr addrspace(3) global [8 x i32] zeroinitializer, align 16 ; CHECK: @vecData.scalarized = external addrspace(3) global [4 x i32], align 4 ; CHECK: @staticArrayOfVecData.scalarized.1dim = internal global [12 x i32] [i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12], align 4 -; CHECK: @groushared2dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [36 x i32] zeroinitializer, align 16 +; CHECK: @groupshared2dArrayofVectors.scalarized.1dim = local_unnamed_addr addrspace(3) global [36 x i32] zeroinitializer, align 16 ; CHECK-NOT: @arrayofVecData ; CHECK-NOT: @arrayofVecData.scalarized ; CHECK-NOT: @vecData ; CHECK-NOT: @staticArrayOfVecData ; CHECK-NOT: @staticArrayOfVecData.scalarized -; CHECK-NOT: @groushared2dArrayofVectors -; CHECK-NOT: @groushared2dArrayofVectors.scalarized +; CHECK-NOT: @groupshared2dArrayofVectors +; CHECK-NOT: @groupshared2dArrayofVectors.scalarized define <4 x i32> @load_array_vec_test() #0 { ; CHECK-LABEL: define <4 x i32> @load_array_vec_test( ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] { -; CHECK-NEXT: [[TMP1:%.*]] = bitcast ptr addrspace(3) @arrayofVecData.scalarized.1dim to ptr addrspace(3) -; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr addrspace(3) [[TMP1]], align 4 -; CHECK-NEXT: [[TMP3:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1) to ptr addrspace(3) -; CHECK-NEXT: [[TMP4:%.*]] = load i32, ptr addrspace(3) [[TMP3]], align 4 -; CHECK-NEXT: [[TMP5:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 2) to ptr addrspace(3) -; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4 -; CHECK-NEXT: [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 3) to ptr addrspace(3) -; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4 -; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1) to ptr addrspace(3) -; CHECK-NEXT: [[TMP10:%.*]] = load i32, ptr addrspace(3) [[TMP9]], align 4 -; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1), i32 1) to ptr addrspace(3) -; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4 -; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1), i32 2) to ptr addrspace(3) -; CHECK-NEXT: [[TMP14:%.*]] = load i32, ptr addrspace(3) [[TMP13]], align 4 -; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([6 x float], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 1), i32 3) to ptr addrspace(3) -; CHECK-NEXT: [[TMP16:%.*]] = load i32, ptr addrspace(3) [[TMP15]], align 4 +; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, align 4 +; CHECK-NEXT: [[TMP4:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 1), align 4 +; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 2), align 4 +; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 3), align 4 +; CHECK-NEXT: [[TMP10:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([8 x i32], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 4), align 4 +; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([8 x i32], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 4), i32 1), align 4 +; CHECK-NEXT: [[TMP14:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([8 x i32], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 4), i32 2), align 4 +; CHECK-NEXT: [[TMP16:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([8 x i32], ptr addrspace(3) @arrayofVecData.scalarized.1dim, i32 0, i32 4), i32 3), align 4 ; CHECK-NEXT: [[DOTI05:%.*]] = add i32 [[TMP2]], [[TMP10]] ; CHECK-NEXT: [[DOTI16:%.*]] = add i32 [[TMP4]], [[TMP12]] ; CHECK-NEXT: [[DOTI27:%.*]] = add i32 [[TMP6]], [[TMP14]] @@ -77,7 +69,9 @@ define <4 x i32> @load_vec_test() #0 { define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 { ; CHECK-LABEL: define <4 x i32> @load_static_array_of_vec_test( ; CHECK-SAME: i32 [[INDEX:%.*]]) #[[ATTR0]] { -; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr inbounds [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 0, i32 [[INDEX]] +; CHECK-NEXT: [[TMP3:%.*]] = mul i32 [[INDEX]], 4 +; CHECK-NEXT: [[TMP2:%.*]] = add i32 0, [[TMP3]] +; CHECK-NEXT: [[DOTFLAT:%.*]] = getelementptr inbounds [12 x i32], ptr @staticArrayOfVecData.scalarized.1dim, i32 0, i32 [[TMP2]] ; CHECK-NEXT: [[DOTI0:%.*]] = load i32, ptr [[DOTFLAT]], align 4 ; CHECK-NEXT: [[DOTFLAT_I1:%.*]] = getelementptr i32, ptr [[DOTFLAT]], i32 1 ; CHECK-NEXT: [[DOTI1:%.*]] = load i32, ptr [[DOTFLAT_I1]], align 4 @@ -99,14 +93,14 @@ define <4 x i32> @load_static_array_of_vec_test(i32 %index) #0 { define <4 x i32> @multid_load_test() #0 { ; CHECK-LABEL: define <4 x i32> @multid_load_test( ; CHECK-SAME: ) #[[ATTR0]] { -; CHECK-NEXT: [[TMP1:%.*]] = load i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, align 4 -; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 1), align 4 -; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 2), align 4 -; CHECK-NEXT: [[TMP4:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 3), align 4 -; CHECK-NEXT: [[TMP5:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), align 4 -; CHECK-NEXT: [[DOTI13:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 1), align 4 -; CHECK-NEXT: [[DOTI25:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 2), align 4 -; CHECK-NEXT: [[DOTI37:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 3), align 4 +; CHECK-NEXT: [[TMP1:%.*]] = load i32, ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, align 4 +; CHECK-NEXT: [[TMP2:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 1), align 4 +; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 2), align 4 +; CHECK-NEXT: [[TMP4:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 3), align 4 +; CHECK-NEXT: [[TMP5:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 0, i32 16), align 4 +; CHECK-NEXT: [[DOTI13:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 0, i32 16), i32 1), align 4 +; CHECK-NEXT: [[DOTI25:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 0, i32 16), i32 2), align 4 +; CHECK-NEXT: [[DOTI37:%.*]] = load i32, ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groupshared2dArrayofVectors.scalarized.1dim, i32 0, i32 16), i32 3), align 4 ; CHECK-NEXT: [[DOTI08:%.*]] = add i32 [[TMP1]], [[TMP5]] ; CHECK-NEXT: [[DOTI19:%.*]] = add i32 [[TMP2]], [[DOTI13]] ; CHECK-NEXT: [[DOTI210:%.*]] = add i32 [[TMP3]], [[DOTI25]] @@ -117,8 +111,8 @@ define <4 x i32> @multid_load_test() #0 { ; CHECK-NEXT: [[TMP6:%.*]] = insertelement <4 x i32> [[DOTUPTO217]], i32 [[DOTI311]], i32 3 ; CHECK-NEXT: ret <4 x i32> [[TMP6]] ; - %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 0, i32 0), align 4 - %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groushared2dArrayofVectors", i32 0, i32 1, i32 1), align 4 + %1 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groupshared2dArrayofVectors", i32 0, i32 0, i32 0), align 4 + %2 = load <4 x i32>, <4 x i32> addrspace(3)* getelementptr inbounds ([3 x [3 x <4 x i32>]], [3 x [3 x <4 x i32>]] addrspace(3)* @"groupshared2dArrayofVectors", i32 0, i32 1, i32 1), align 4 %3 = add <4 x i32> %1, %2 ret <4 x i32> %3 } diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll index a07ce2c24f7ac..9ce2108a03831 100644 --- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll +++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll @@ -8,13 +8,13 @@ define internal void @main() #1 { ; CHECK-LABEL: define internal void @main() { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 1 +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3 ; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16 ; CHECK-NEXT: [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1 ; CHECK-NEXT: [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4 ; CHECK-NEXT: [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2 ; CHECK-NEXT: [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8 -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 2 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6 ; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16 ; CHECK-NEXT: [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1 ; CHECK-NEXT: [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4 diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll index 32e2c3ca2c302..a8557e47b0ea6 100644 --- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll +++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll @@ -33,7 +33,9 @@ define void @alloca_2d_gep_test() { ; FCHECK: [[alloca_val:%.*]] = alloca [4 x i32], align 16 ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [2 x [2 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]] - ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[tid]] + ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 2 + ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]] + ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]] ; CHECK: ret void %1 = alloca [2 x <2 x i32>], align 16 %2 = tail call i32 @llvm.dx.thread.id(i32 0) From 1a09803d75b42e9b70c37e7e92c0ca2f3ddcf3f0 Mon Sep 17 00:00:00 2001 From: Icohedron Date: Sat, 28 Jun 2025 01:39:17 +0000 Subject: [PATCH 03/11] Allow flattening GEPs for Global Variables not in the GlobalMap --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 41 +++++++++---------- llvm/test/CodeGen/DirectX/flatten-array.ll | 6 +-- .../CodeGen/DirectX/flatten-bug-117273.ll | 6 +-- .../test/CodeGen/DirectX/scalar-bug-117273.ll | 18 +++----- 4 files changed, 30 insertions(+), 41 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 913a8dcb917f4..e58cd829d96d8 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -215,18 +215,6 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GEPChainInfoMap.contains(cast(&GEP))) return false; - // Construct GEPInfo for this GEP - GEPInfo Info; - - // Obtain the variable and constant byte offsets computed by this GEP - const DataLayout &DL = GEP.getDataLayout(); - unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType()); - Info.ConstantOffset = {BitWidth, 0}; - bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets, - Info.ConstantOffset); - (void)Success; - assert(Success && "Failed to collect offsets for GEP"); - Value *PtrOperand = GEP.getPointerOperand(); // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that @@ -243,7 +231,7 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices, GEP.getName(), GEP.getNoWrapFlags()); assert(isa(NewGEP) && - "Expected newly-created GEP to not be a ConstantExpr"); + "Expected newly-created GEP to be an instruction"); GetElementPtrInst *NewGEPI = cast(NewGEP); GEP.replaceAllUsesWith(NewGEPI); @@ -253,6 +241,18 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { return true; } + // Construct GEPInfo for this GEP + GEPInfo Info; + + // Obtain the variable and constant byte offsets computed by this GEP + const DataLayout &DL = GEP.getDataLayout(); + unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType()); + Info.ConstantOffset = {BitWidth, 0}; + bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets, + Info.ConstantOffset); + (void)Success; + assert(Success && "Failed to collect offsets for GEP"); + // If there is a parent GEP, inherit the root array type and pointer, and // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP // chain and we need to deterine the root array type @@ -270,15 +270,13 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { // We should try to determine the type of the root from the pointer rather // than the GEP's source element type because this could be a scalar GEP - // into a multidimensional array-typed pointer from an Alloca or Global - // Variable. + // into an array-typed pointer from an Alloca or Global Variable. Type *RootTy = GEP.getSourceElementType(); if (auto *GlobalVar = dyn_cast(PtrOperand)) { - if (!GlobalMap.contains(GlobalVar)) - return false; - GlobalVariable *NewGlobal = GlobalMap[GlobalVar]; - Info.RootPointerOperand = NewGlobal; - RootTy = NewGlobal->getValueType(); + if (GlobalMap.contains(GlobalVar)) + GlobalVar = GlobalMap[GlobalVar]; + Info.RootPointerOperand = GlobalVar; + RootTy = GlobalVar->getValueType(); } else if (auto *Alloca = dyn_cast(PtrOperand)) { RootTy = Alloca->getAllocatedType(); } @@ -341,7 +339,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { } // This GEP is potentially dead at the end of the pass since it may not have - // any users anymore after GEP chains have been collapsed. + // any users anymore after GEP chains have been collapsed. We retain store + // GEPInfo for GEPs down the chain to use to compute their indices. GEPChainInfoMap.insert({cast(&GEP), std::move(Info)}); PotentiallyDeadInstrs.emplace_back(&GEP); return false; diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll index e256146bb74f4..dbb1d95df16f3 100644 --- a/llvm/test/CodeGen/DirectX/flatten-array.ll +++ b/llvm/test/CodeGen/DirectX/flatten-array.ll @@ -123,8 +123,7 @@ define void @gep_4d_test () { @b = internal global [2 x [3 x [4 x i32]]] zeroinitializer, align 16 define void @global_gep_load() { - ; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @a.1dim, i32 0, i32 6 - ; CHECK-NEXT: load i32, ptr [[GEP_PTR]], align 4 + ; CHECK: {{.*}} = load i32, ptr getelementptr inbounds ([24 x i32], ptr @a.1dim, i32 0, i32 6), align 4 ; CHECK-NEXT: ret void %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @a, i32 0, i32 0 %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 1 @@ -177,8 +176,7 @@ define void @global_incomplete_gep_chain(i32 %row, i32 %col) { } define void @global_gep_store() { - ; CHECK: [[GEP_PTR:%.*]] = getelementptr inbounds [24 x i32], ptr @b.1dim, i32 0, i32 13 - ; CHECK-NEXT: store i32 1, ptr [[GEP_PTR]], align 4 + ; CHECK: store i32 1, ptr getelementptr inbounds ([24 x i32], ptr @b.1dim, i32 0, i32 13), align 4 ; CHECK-NEXT: ret void %1 = getelementptr inbounds [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* @b, i32 0, i32 1 %2 = getelementptr inbounds [3 x [4 x i32]], [3 x [4 x i32]]* %1, i32 0, i32 0 diff --git a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll index 930805f0ddc90..78971b8954150 100644 --- a/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll +++ b/llvm/test/CodeGen/DirectX/flatten-bug-117273.ll @@ -8,10 +8,8 @@ define internal void @main() { ; CHECK-LABEL: define internal void @main() { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 3 -; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16 -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 6 -; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16 +; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr getelementptr ([6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 3), align 16 +; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr getelementptr ([6 x float], ptr @ZerroInitArr.1dim, i32 0, i32 6), align 16 ; CHECK-NEXT: ret void ; entry: diff --git a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll index 9ce2108a03831..43bbe9249aac0 100644 --- a/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll +++ b/llvm/test/CodeGen/DirectX/scalar-bug-117273.ll @@ -8,18 +8,12 @@ define internal void @main() #1 { ; CHECK-LABEL: define internal void @main() { ; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3 -; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr [[TMP0]], align 16 -; CHECK-NEXT: [[DOTI1:%.*]] = getelementptr float, ptr [[TMP0]], i32 1 -; CHECK-NEXT: [[DOTI11:%.*]] = load float, ptr [[DOTI1]], align 4 -; CHECK-NEXT: [[DOTI2:%.*]] = getelementptr float, ptr [[TMP0]], i32 2 -; CHECK-NEXT: [[DOTI22:%.*]] = load float, ptr [[DOTI2]], align 8 -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6 -; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr [[TMP1]], align 16 -; CHECK-NEXT: [[DOTI14:%.*]] = getelementptr float, ptr [[TMP1]], i32 1 -; CHECK-NEXT: [[DOTI15:%.*]] = load float, ptr [[DOTI14]], align 4 -; CHECK-NEXT: [[DOTI26:%.*]] = getelementptr float, ptr [[TMP1]], i32 2 -; CHECK-NEXT: [[DOTI27:%.*]] = load float, ptr [[DOTI26]], align 8 +; CHECK-NEXT: [[DOTI0:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3), align 16 +; CHECK-NEXT: [[DOTI11:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3), i32 1), align 4 +; CHECK-NEXT: [[DOTI22:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 3), i32 2), align 8 +; CHECK-NEXT: [[DOTI03:%.*]] = load float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6), align 16 +; CHECK-NEXT: [[DOTI15:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6), i32 1), align 4 +; CHECK-NEXT: [[DOTI27:%.*]] = load float, ptr getelementptr (float, ptr getelementptr inbounds ([24 x float], ptr @StaticArr.scalarized.1dim, i32 0, i32 6), i32 2), align 8 ; CHECK-NEXT: ret void ; entry: From dee77e9f5cf1bed2b13d6078d7255219cac70538 Mon Sep 17 00:00:00 2001 From: Icohedron Date: Sat, 28 Jun 2025 05:33:51 +0000 Subject: [PATCH 04/11] Add test demonstrating the flattening of scalar GEPs, including i8 --- llvm/test/CodeGen/DirectX/flatten-array.ll | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll index dbb1d95df16f3..bd83a3da24cca 100644 --- a/llvm/test/CodeGen/DirectX/flatten-array.ll +++ b/llvm/test/CodeGen/DirectX/flatten-array.ll @@ -255,5 +255,24 @@ define void @gep_4d_index_and_gep_chain_mixed() { ret void } +; This test demonstrates that the collapsing of GEP chains occurs regardless of +; the source element type given to the GEP. As long as the root pointer being +; indexed to is an aggregate data structure, the GEP will be flattened. +define void @gep_scalar_flatten() { + ; CHECK-LABEL: gep_scalar_flatten + ; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [24 x i32] + ; CHECK-NEXT: getelementptr inbounds nuw [24 x i32], ptr [[ALLOCA]], i32 0, i32 17 + ; CHECK-NEXT: getelementptr inbounds nuw [24 x i32], ptr [[ALLOCA]], i32 0, i32 17 + ; CHECK-NEXT: getelementptr inbounds nuw [24 x i32], ptr [[ALLOCA]], i32 0, i32 23 + ; CHECK-NEXT: ret void + %a = alloca [2 x [3 x [4 x i32]]], align 4 + %i8root = getelementptr inbounds nuw i8, [2 x [3 x [4 x i32]]]* %a, i32 68 ; %a[1][1][1] + %i32root = getelementptr inbounds nuw i32, [2 x [3 x [4 x i32]]]* %a, i32 17 ; %a[1][1][1] + %c0 = getelementptr inbounds nuw [2 x [3 x [4 x i32]]], [2 x [3 x [4 x i32]]]* %a, i32 0, i32 1 ; %a[1] + %c1 = getelementptr inbounds nuw i32, [3 x [4 x i32]]* %c0, i32 8 ; %a[1][2] + %c2 = getelementptr inbounds nuw i8, [4 x i32]* %c1, i32 12 ; %a[1][2][3] + ret void +} + ; Make sure we don't try to walk the body of a function declaration. declare void @opaque_function() From 1ccaa986e4e9e024d91ef6a67509af657c7c9ebc Mon Sep 17 00:00:00 2001 From: Icohedron Date: Mon, 30 Jun 2025 21:25:57 +0000 Subject: [PATCH 05/11] Clear GEPChainInfoMap after visiting a function It is possible for GEPOperator* to overlap between function visits, so we have to clear the map or else there could be stale data leftover that causes GEPs to be incorrectly generated or skipped over. --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index e58cd829d96d8..0f1c7673da2cf 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -92,6 +92,7 @@ class DXILFlattenArraysVisitor } // namespace bool DXILFlattenArraysVisitor::finish() { + GEPChainInfoMap.clear(); RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); return true; } From 32806096da0117831839567b18def3b8cd1a474e Mon Sep 17 00:00:00 2001 From: Icohedron Date: Mon, 30 Jun 2025 21:49:14 +0000 Subject: [PATCH 06/11] Fix check for constantexpr GEP instead of an instruction --- llvm/test/CodeGen/DirectX/flatten-array.ll | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll index bd83a3da24cca..97e4d7a709260 100644 --- a/llvm/test/CodeGen/DirectX/flatten-array.ll +++ b/llvm/test/CodeGen/DirectX/flatten-array.ll @@ -202,8 +202,7 @@ define void @two_index_gep() { define void @two_index_gep_const() { ; CHECK-LABEL: define void @two_index_gep_const( - ; CHECK-NEXT: [[GEP_PTR:%.*]] = getelementptr inbounds nuw [4 x float], ptr addrspace(3) @g.1dim, i32 0, i32 3 - ; CHECK-NEXT: load float, ptr addrspace(3) [[GEP_PTR]], align 4 + ; CHECK-NEXT: load float, ptr addrspace(3) getelementptr inbounds nuw ([4 x float], ptr addrspace(3) @g.1dim, i32 0, i32 3), align 4 ; CHECK-NEXT: ret void %1 = getelementptr inbounds nuw [2 x [2 x float]], ptr addrspace(3) @g, i32 0, i32 1, i32 1 %3 = load float, ptr addrspace(3) %1, align 4 From 9ff05d250a061bad0ef8f94d6e0103cbf35dc750 Mon Sep 17 00:00:00 2001 From: Icohedron Date: Mon, 30 Jun 2025 23:42:19 +0000 Subject: [PATCH 07/11] Use .indices() instead of .idx_begin() and .idx_end() --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 0f1c7673da2cf..33a3378aad105 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -227,7 +227,7 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { OldGEPI->insertBefore(GEP.getIterator()); IRBuilder<> Builder(&GEP); - SmallVector Indices(GEP.idx_begin(), GEP.idx_end()); + SmallVector Indices(GEP.indices()); Value *NewGEP = Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices, GEP.getName(), GEP.getNoWrapFlags()); From 3734f57aa50172899cc970b07dd3a18569374e93 Mon Sep 17 00:00:00 2001 From: Icohedron Date: Tue, 1 Jul 2025 02:01:57 +0000 Subject: [PATCH 08/11] Handle cases where Multiplier is not divisible by BytesPerElem in variable index calculation --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 24 ++++++++++++++----- llvm/test/CodeGen/DirectX/flatten-array.ll | 23 ++++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 33a3378aad105..b9898d3411631 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -20,6 +20,7 @@ #include "llvm/IR/InstVisitor.h" #include "llvm/IR/ReplaceConstant.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/Local.h" #include #include @@ -305,6 +306,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { unsigned BytesPerElem = Info.RootFlattenedArrayType->getArrayElementType() ->getPrimitiveSizeInBits() / 8; + assert(isPowerOf2_32(BytesPerElem) && + "Bytes per element should be a power of 2"); // Compute the 32-bit index for this flattened GEP from the constant and // variable byte offsets in the GEPInfo @@ -316,14 +319,23 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { "Constant byte offset for flat GEP index must fit within 32 bits"); Value *FlattenedIndex = Builder.getInt32(ConstantOffset); for (auto [VarIndex, Multiplier] : Info.VariableOffsets) { - uint64_t Mul = Multiplier.udiv(BytesPerElem).getZExtValue(); - assert(Mul < UINT32_MAX && - "Multiplier for flat GEP index must fit within 32 bits"); + assert(Multiplier.getActiveBits() <= 32 && + "The multiplier for a flat GEP index must fit within 32 bits"); assert(VarIndex->getType()->isIntegerTy(32) && "Expected i32-typed GEP indices"); - Value *ConstIntMul = Builder.getInt32(Mul); - Value *MulVarIndex = Builder.CreateMul(VarIndex, ConstIntMul); - FlattenedIndex = Builder.CreateAdd(FlattenedIndex, MulVarIndex); + Value *VI; + if (Multiplier.getZExtValue() % BytesPerElem != 0) { + // This can happen, e.g., with i8 GEPs. To handle this we just divide + // by BytesPerElem using an instruction after multiplying VarIndex by + // Multiplier. + VI = Builder.CreateMul(VarIndex, + Builder.getInt32(Multiplier.getZExtValue())); + VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem))); + } else + VI = Builder.CreateMul( + VarIndex, + Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem)); + FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI); } // Construct a new GEP for the flattened array to replace the current GEP diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll index 97e4d7a709260..2dd6ee2e53c20 100644 --- a/llvm/test/CodeGen/DirectX/flatten-array.ll +++ b/llvm/test/CodeGen/DirectX/flatten-array.ll @@ -273,5 +273,28 @@ define void @gep_scalar_flatten() { ret void } +define void @gep_scalar_flatten_dynamic(i32 %index) { + ; CHECK-LABEL: gep_scalar_flatten_dynamic + ; CHECK-SAME: i32 [[INDEX:%.*]]) { + ; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [6 x i32], align 4 + ; CHECK-NEXT: [[I8INDEX:%.*]] = mul i32 [[INDEX]], 12 + ; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[I8INDEX]], 1 + ; CHECK-NEXT: [[DIV:%.*]] = lshr i32 [[MUL]], 2 + ; CHECK-NEXT: [[ADD:%.*]] = add i32 0, [[DIV]] + ; CHECK-NEXT: getelementptr inbounds nuw [6 x i32], ptr [[ALLOCA]], i32 0, i32 [[ADD]] + ; CHECK-NEXT: [[I32INDEX:%.*]] = mul i32 [[INDEX]], 3 + ; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[I32INDEX]], 1 + ; CHECK-NEXT: [[ADD:%.*]] = add i32 0, [[MUL]] + ; CHECK-NEXT: getelementptr inbounds nuw [6 x i32], ptr [[ALLOCA]], i32 0, i32 [[ADD]] + ; CHECK-NEXT: ret void + ; + %a = alloca [2 x [3 x i32]], align 4 + %i8index = mul i32 %index, 12 + %i8root = getelementptr inbounds nuw i8, [2 x [3 x i32]]* %a, i32 %i8index; + %i32index = mul i32 %index, 3 + %i32root = getelementptr inbounds nuw i32, [2 x [3 x i32]]* %a, i32 %i32index; + ret void +} + ; Make sure we don't try to walk the body of a function declaration. declare void @opaque_function() From 5e0dd955a1cfda707243f7bce219f7782a3f1d94 Mon Sep 17 00:00:00 2001 From: Icohedron Date: Tue, 8 Jul 2025 16:44:01 +0000 Subject: [PATCH 09/11] Use [[maybe_unused]] instead of (void) cast --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index b9898d3411631..5214551c70f5c 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -250,9 +250,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { const DataLayout &DL = GEP.getDataLayout(); unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType()); Info.ConstantOffset = {BitWidth, 0}; - bool Success = GEP.collectOffset(DL, BitWidth, Info.VariableOffsets, - Info.ConstantOffset); - (void)Success; + [[maybe_unused]] bool Success = GEP.collectOffset( + DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset); assert(Success && "Failed to collect offsets for GEP"); // If there is a parent GEP, inherit the root array type and pointer, and From 56ea5b8d94ea45d4678694c45b75c325fbc14f2b Mon Sep 17 00:00:00 2001 From: Icohedron Date: Tue, 8 Jul 2025 18:06:34 +0000 Subject: [PATCH 10/11] Combine PtrOpGEPCE if statements into one --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 5214551c70f5c..88400b7f0312a 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -221,27 +221,27 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that // it can be visited - if (auto *PtrOpGEPCE = dyn_cast(PtrOperand)) - if (PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) { - GetElementPtrInst *OldGEPI = - cast(PtrOpGEPCE->getAsInstruction()); - OldGEPI->insertBefore(GEP.getIterator()); - - IRBuilder<> Builder(&GEP); - SmallVector Indices(GEP.indices()); - Value *NewGEP = - Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices, - GEP.getName(), GEP.getNoWrapFlags()); - assert(isa(NewGEP) && - "Expected newly-created GEP to be an instruction"); - GetElementPtrInst *NewGEPI = cast(NewGEP); - - GEP.replaceAllUsesWith(NewGEPI); - GEP.eraseFromParent(); - visitGetElementPtrInst(*OldGEPI); - visitGetElementPtrInst(*NewGEPI); - return true; - } + if (auto *PtrOpGEPCE = dyn_cast(PtrOperand); + PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) { + GetElementPtrInst *OldGEPI = + cast(PtrOpGEPCE->getAsInstruction()); + OldGEPI->insertBefore(GEP.getIterator()); + + IRBuilder<> Builder(&GEP); + SmallVector Indices(GEP.indices()); + Value *NewGEP = + Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices, + GEP.getName(), GEP.getNoWrapFlags()); + assert(isa(NewGEP) && + "Expected newly-created GEP to be an instruction"); + GetElementPtrInst *NewGEPI = cast(NewGEP); + + GEP.replaceAllUsesWith(NewGEPI); + GEP.eraseFromParent(); + visitGetElementPtrInst(*OldGEPI); + visitGetElementPtrInst(*NewGEPI); + return true; + } // Construct GEPInfo for this GEP GEPInfo Info; From bc5b053fb4be41ff2696219867df4362fdc30923 Mon Sep 17 00:00:00 2001 From: Icohedron Date: Tue, 8 Jul 2025 18:20:37 +0000 Subject: [PATCH 11/11] Use DL.getTypeAllocSize to get BytesPerElem --- llvm/lib/Target/DirectX/DXILFlattenArrays.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index 88400b7f0312a..db9fd31bfbc32 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -300,11 +300,8 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { ReplaceThisGEP = true; if (ReplaceThisGEP) { - // GEP.collectOffset returns the offset in bytes. So we need to divide its - // offsets by the size in bytes of the element type - unsigned BytesPerElem = Info.RootFlattenedArrayType->getArrayElementType() - ->getPrimitiveSizeInBits() / - 8; + unsigned BytesPerElem = + DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType()); assert(isPowerOf2_32(BytesPerElem) && "Bytes per element should be a power of 2");