From 1714e7f2dec96b9a5bc2f43f3b08359ddfcf9e5c Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 9 Aug 2024 15:45:40 +0000 Subject: [PATCH 1/2] [mlir][vector] Use `DenseI64ArrayAttr` in vector.multi_reduction This prevents some unnecessary conversions to/from int64_t and IntegerAttr. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 6 +++--- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 7 +++---- .../Vector/Transforms/LowerVectorMultiReduction.cpp | 5 +---- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 925eb80dbe71e..b96f5c2651bce 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -286,7 +286,7 @@ def Vector_MultiDimReductionOp : Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$source, AnyType:$acc, - I64ArrayAttr:$reduction_dims)>, + DenseI64ArrayAttr:$reduction_dims)>, Results<(outs AnyType:$dest)> { let summary = "Multi-dimensional reduction operation"; let description = [{ @@ -325,8 +325,8 @@ def Vector_MultiDimReductionOp : SmallVector getReductionMask() { SmallVector res(getSourceVectorType().getRank(), false); - for (auto ia : getReductionDims().getAsRange()) - res[ia.getInt()] = true; + for (int64_t dim : getReductionDims()) + res[dim] = true; return res; } static SmallVector getReductionMask( diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ab4485c37e5e7..60b4f93a53ad4 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -445,8 +445,7 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder, for (const auto &en : llvm::enumerate(reductionMask)) if (en.value()) reductionDims.push_back(en.index()); - build(builder, result, kind, source, acc, - builder.getI64ArrayAttr(reductionDims)); + build(builder, result, kind, source, acc, reductionDims); } OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) { @@ -467,8 +466,8 @@ LogicalResult MultiDimReductionOp::verify() { Type inferredReturnType; auto sourceScalableDims = getSourceVectorType().getScalableDims(); for (auto it : llvm::enumerate(getSourceVectorType().getShape())) - if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) { - return llvm::cast(attr).getValue() == it.index(); + if (!llvm::any_of(getReductionDims(), [&](int64_t dim) { + return dim == static_cast(it.index()); })) { targetShape.push_back(it.value()); scalableDims.push_back(sourceScalableDims[it.index()]); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index ac576ed0b4f09..716da55ba09ae 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -67,10 +67,7 @@ class InnerOuterDimReductionConversion auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Separate reduction and parallel dims - auto reductionDimsRange = - multiReductionOp.getReductionDims().getAsValueRange(); - auto reductionDims = llvm::to_vector<4>(llvm::map_range( - reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); })); + ArrayRef reductionDims = multiReductionOp.getReductionDims(); llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), reductionDims.end()); int64_t reductionSize = reductionDims.size(); From 24780a2d410d1a61ecc8b47fefdd668d7b47d184 Mon Sep 17 00:00:00 2001 From: MacDue Date: Sat, 10 Aug 2024 13:37:24 +0100 Subject: [PATCH 2/2] Fix up --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 60b4f93a53ad4..44bd4aa76ffbd 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -465,12 +465,14 @@ LogicalResult MultiDimReductionOp::verify() { SmallVector scalableDims; Type inferredReturnType; auto sourceScalableDims = getSourceVectorType().getScalableDims(); - for (auto it : llvm::enumerate(getSourceVectorType().getShape())) - if (!llvm::any_of(getReductionDims(), [&](int64_t dim) { - return dim == static_cast(it.index()); - })) { - targetShape.push_back(it.value()); - scalableDims.push_back(sourceScalableDims[it.index()]); + for (auto [dimIdx, dimSize] : + llvm::enumerate(getSourceVectorType().getShape())) + if (!llvm::any_of(getReductionDims(), + [dimIdx = dimIdx](int64_t reductionDimIdx) { + return reductionDimIdx == static_cast(dimIdx); + })) { + targetShape.push_back(dimSize); + scalableDims.push_back(sourceScalableDims[dimIdx]); } // TODO: update to also allow 0-d vectors when available. if (targetShape.empty())