From c252a3dd92ab1145dc80a847729423366a8f1dea Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 13 May 2025 16:06:45 -0700 Subject: [PATCH 01/10] first commit --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 58 +++++++++++++++ mlir/test/Dialect/Vector/canonicalize.mlir | 69 ------------------ .../canonicalize/vector-from-elements.mlir | 72 +++++++++++++++++++ 3 files changed, 130 insertions(+), 69 deletions(-) create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f6c3c6a61afb6..71844e62baba7 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2385,6 +2385,64 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, return success(); } +static LogicalResult +rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, + PatternRewriter &rewriter) { + + mlir::OperandRange elements = fromElementsOp.getElements(); + const size_t nbElements = elements.size(); + assert(nbElements > 0 && "must be at least one element"); + + // https://en.wikipedia.org/wiki/List_of_prime_numbers + const int prime = 5387; + bool pseudoRandomOrder = nbElements < prime; + + Value source; + ArrayRef shape; + for (size_t elementIndex = 0ULL; elementIndex < nbElements; elementIndex++) { + + // Rather than iterating through the elements in ascending order, we might + // be able to exit quickly if we go through in pseudo-random order. Use + // fact that (i * p) % a is a bijection for i in [0, a) if p is prime and + // a < p. + int currentIndex = + pseudoRandomOrder ? elementIndex : (elementIndex * prime) % nbElements; + Value element = elements[currentIndex]; + + // From an extract on the same source as the other elements. 
+ auto extractOp = + dyn_cast_if_present(element.getDefiningOp()); + if (!extractOp) + return failure(); + Value currentSource = extractOp.getVector(); + if (!source) { + source = currentSource; + shape = extractOp.getSourceVectorType().getShape(); + } else if (currentSource != source) { + return failure(); + } + + ArrayRef position = extractOp.getStaticPosition(); + assert(position.size() == shape.size()); + + int64_t stride{1}; + int64_t offset{0}; + for (auto [pos, size] : + llvm::zip(llvm::reverse(position), llvm::reverse(shape))) { + if (pos == ShapedType::kDynamic) + return failure(); + offset += pos * stride; + stride *= size; + } + if (offset != currentIndex) + return failure(); + } + + // Can replace with a shape_cast. + rewriter.replaceOpWithNewOp(fromElementsOp, + fromElementsOp.getType(), source); +} + void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(rewriteFromElementsAsSplat); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 99f0850000a16..6af517d988360 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector, // ----- -// CHECK-LABEL: func @extract_scalar_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) -func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { - // Extract from 0D. - %0 = vector.from_elements %a : vector - %1 = vector.extract %0[] : f32 from vector - - // Extract from 1D. - %2 = vector.from_elements %a : vector<1xf32> - %3 = vector.extract %2[0] : f32 from vector<1xf32> - %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32> - %5 = vector.extract %4[4] : f32 from vector<5xf32> - - // Extract from 2D. 
- %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> - %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32> - %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> - %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32> - %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32> - - // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]] - return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32 -} - -// ----- - -// CHECK-LABEL: func @extract_1d_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) -func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { - %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> - // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32> - %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> - // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32> - %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> - // CHECK: return %[[splat1]], %[[splat2]] - return %1, %2 : vector<3xf32>, vector<3xf32> -} - -// ----- - -// CHECK-LABEL: func @extract_2d_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) -func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) { - %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32> - // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32> - %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32> - // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32> - %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32> - // CHECK: return %[[splat1]], %[[splat2]] - return %1, %2 : vector<2x2xf32>, vector<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @from_elements_to_splat( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) -func.func @from_elements_to_splat(%a: f32, %b: f32) -> 
(vector<2x3xf32>, vector<2x3xf32>, vector) { - // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32> - %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> - // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> - %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> - // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector - %2 = vector.from_elements %a : vector - // CHECK: return %[[splat]], %[[from_el]], %[[splat2]] - return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector -} - -// ----- - // CHECK-LABEL: func @vector_insert_const_regression( // CHECK: llvm.mlir.undef // CHECK: vector.insert diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir new file mode 100644 index 0000000000000..21ce71473a3cd --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s + +// This file contains some tests of folding/canonicalizing vector.from_elements + +// CHECK-LABEL: func @extract_scalar_from_from_elements( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { + // Extract from 0D. + %0 = vector.from_elements %a : vector + %1 = vector.extract %0[] : f32 from vector + + // Extract from 1D. + %2 = vector.from_elements %a : vector<1xf32> + %3 = vector.extract %2[0] : f32 from vector<1xf32> + %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32> + %5 = vector.extract %4[4] : f32 from vector<5xf32> + + // Extract from 2D. 
+ %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> + %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32> + %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> + %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32> + %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32> + + // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]] + return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: func @extract_1d_from_from_elements( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { + %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> + // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32> + %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> + // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32> + %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> + // CHECK: return %[[splat1]], %[[splat2]] + return %1, %2 : vector<3xf32>, vector<3xf32> +} + +// ----- + +// CHECK-LABEL: func @extract_2d_from_from_elements( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) { + %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32> + // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32> + %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32> + // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32> + %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32> + // CHECK: return %[[splat1]], %[[splat2]] + return %1, %2 : vector<2x2xf32>, vector<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @from_elements_to_splat( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @from_elements_to_splat(%a: f32, %b: f32) -> 
(vector<2x3xf32>, vector<2x3xf32>, vector) { + // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32> + %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> + // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> + %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> + // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector + %2 = vector.from_elements %a : vector + // CHECK: return %[[splat]], %[[from_el]], %[[splat2]] + return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector +} + +// ----- From 7f40da6728bb5b197548d3466818869d45ce9720 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 13 May 2025 17:33:41 -0700 Subject: [PATCH 02/10] improvements --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 46 ++++++------ .../canonicalize/vector-from-elements.mlir | 73 +++++++++++++++++++ 2 files changed, 98 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 71844e62baba7..e0ce41e5d6245 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2385,45 +2385,49 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, return success(); } + +/// Rewrite a vecor.from_elements as a vector.shape_cast, if possible. 
+/// +/// Example: +/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8> +/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8> +/// %2 = vector.from_elements %0, %1 : vector<2xi8> +/// +/// becomes +/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8> static LogicalResult rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, PatternRewriter &rewriter) { - mlir::OperandRange elements = fromElementsOp.getElements(); - const size_t nbElements = elements.size(); - assert(nbElements > 0 && "must be at least one element"); - - // https://en.wikipedia.org/wiki/List_of_prime_numbers - const int prime = 5387; - bool pseudoRandomOrder = nbElements < prime; - + // The common source of vector.extract operations (if one exists), as well + // as its shape and rank. Set in the first iteration of the loop over the + // operands of `fromElementsOp`. Value source; ArrayRef shape; - for (size_t elementIndex = 0ULL; elementIndex < nbElements; elementIndex++) { + int64_t rank; - // Rather than iterating through the elements in ascending order, we might - // be able to exit quickly if we go through in pseudo-random order. Use - // fact that (i * p) % a is a bijection for i in [0, a) if p is prime and - // a < p. - int currentIndex = - pseudoRandomOrder ? elementIndex : (elementIndex * prime) % nbElements; - Value element = elements[currentIndex]; + for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) { - // From an extract on the same source as the other elements. + // Check that the element is defined by an extract operation, and that + // the extract is on the same vector as all preceding elements. 
auto extractOp = dyn_cast_if_present(element.getDefiningOp()); if (!extractOp) return failure(); Value currentSource = extractOp.getVector(); - if (!source) { + if (index == 0) { source = currentSource; shape = extractOp.getSourceVectorType().getShape(); + rank = shape.size(); } else if (currentSource != source) { return failure(); } + // Check that the (linearized) index of extraction is the same as the index + // in the result of `fromElementsOp`. ArrayRef position = extractOp.getStaticPosition(); - assert(position.size() == shape.size()); + if (position.size() != rank) + return failure(); int64_t stride{1}; int64_t offset{0}; @@ -2434,11 +2438,10 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, offset += pos * stride; stride *= size; } - if (offset != currentIndex) + if (offset != index) return failure(); } - // Can replace with a shape_cast. rewriter.replaceOpWithNewOp(fromElementsOp, fromElementsOp.getType(), source); } @@ -2446,6 +2449,7 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(rewriteFromElementsAsSplat); + results.add(rewriteFromElementsAsShapeCast); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index 21ce71473a3cd..fafac4419d719 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -2,6 +2,10 @@ // This file contains some tests of folding/canonicalizing vector.from_elements +///===----------------------------------------------===// +/// Tests of `rewriteFromElementsAsSplat` +///===----------------------------------------------===// + // CHECK-LABEL: func @extract_scalar_from_from_elements( // CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) 
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { @@ -70,3 +74,72 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< } // ----- + + +///===----------------------------------------------===// +/// Tests of `rewriteFromElementsAsShapeCast` +///===----------------------------------------------===// + +// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( +// CHECK-SAME: %[[a:.*]]: vector<1x2xi8>) +// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8> +// CHECK: return %[[shape_cast]] : vector<2xi8> +func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> + %4 = vector.from_elements %0, %1 : vector<2xi8> + return %4 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3( +// CHECK-SAME: %[[a:.*]]: vector<8xi8>) +// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8> +// CHECK: return %[[shape_cast]] : vector<2x2x2xi8> +func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> { + %0 = vector.extract %arg0[0] : i8 from vector<8xi8> + %1 = vector.extract %arg0[1] : i8 from vector<8xi8> + %2 = vector.extract %arg0[2] : i8 from vector<8xi8> + %3 = vector.extract %arg0[3] : i8 from vector<8xi8> + %4 = vector.extract %arg0[4] : i8 from vector<8xi8> + %5 = vector.extract %arg0[5] : i8 from vector<8xi8> + %6 = vector.extract %arg0[6] : i8 from vector<8xi8> + %7 = vector.extract %arg0[7] : i8 from vector<8xi8> + %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8> + return %8 : vector<2x2x2xi8> +} + +// ----- + +// The extracted elements are recombined into a single vector, but in a new order. 
+// CHECK-LABEL: func @negative_nonascending_order( +// CHECK-NOT: shape_cast +func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_nonstatic_extract( +// CHECK-NOT: shape_cast +func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> { + %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_different_sources( +// CHECK-NOT: shape_cast +func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} From 3ce0713d33fef03141239a00033f5bf5c896fc72 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 13 May 2025 17:41:11 -0700 Subject: [PATCH 03/10] apply some polish --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index e0ce41e5d6245..e0b406ce0bc46 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2385,8 +2385,7 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, return success(); } - -/// Rewrite a vecor.from_elements as a vector.shape_cast, if possible. +/// Rewrite vector.from_elements as vector.shape_cast, if possible. 
/// /// Example: /// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8> @@ -2400,8 +2399,8 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, PatternRewriter &rewriter) { // The common source of vector.extract operations (if one exists), as well - // as its shape and rank. Set in the first iteration of the loop over the - // operands of `fromElementsOp`. + // as its shape and rank. These are set in the first iteration of the loop + // over the operands (elements) of `fromElementsOp`. Value source; ArrayRef shape; int64_t rank; @@ -2426,9 +2425,8 @@ rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, // Check that the (linearized) index of extraction is the same as the index // in the result of `fromElementsOp`. ArrayRef position = extractOp.getStaticPosition(); - if (position.size() != rank) - return failure(); - + assert(position.size() == rank && + "scalar extract must have full rank position"); int64_t stride{1}; int64_t offset{0}; for (auto [pos, size] : From 94c5d8c7b3919975f46437c481f9f49e076814f8 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 14 May 2025 11:28:40 -0700 Subject: [PATCH 04/10] fix blindspot where source is larger --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 116 +++++++++++------- .../canonicalize/vector-from-elements.mlir | 11 ++ 2 files changed, 84 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index e0b406ce0bc46..7b7f014480ccd 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -33,6 +33,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/SubsetOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Support/LLVM.h" @@ -2394,60 +2395,89 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, /// /// 
becomes /// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8> -static LogicalResult -rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, - PatternRewriter &rewriter) { +/// +/// The requirements for this to be valid are +/// i) all elements are extracted from the same vector (source), +/// ii) source and from_elements result have the same number of elements, +/// iii) the elements are extracted in ascending order. +/// +/// It might be possible to rewrite vector.from_elements as a single +/// vector.extract if (ii) is not satisifed, or in some cases as a +/// a single vector_extract_strided_slice if (ii) and (iii) are not satisfied, +/// this is left for future consideration. +class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> { +public: + using OpRewritePattern::OpRewritePattern; - // The common source of vector.extract operations (if one exists), as well - // as its shape and rank. These are set in the first iteration of the loop - // over the operands (elements) of `fromElementsOp`. - Value source; - ArrayRef<int64_t> shape; - int64_t rank; + LogicalResult matchAndRewrite(FromElementsOp fromElements, + PatternRewriter &rewriter) const override { - for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) { + mlir::OperandRange elements = fromElements.getElements(); + assert(!elements.empty() && "must be at least 1 element"); - // Check that the element is defined by an extract operation, and that - // the extract is on the same vector as all preceding elements. 
- auto extractOp = - dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); - if (!extractOp) - return failure(); - Value currentSource = extractOp.getVector(); - if (index == 0) { - source = currentSource; - shape = extractOp.getSourceVectorType().getShape(); - rank = shape.size(); - } else if (currentSource != source) { - return failure(); + Value firstElement = elements.front(); + ExtractOp extractOp = + dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp()); + if (!extractOp) { + return rewriter.notifyMatchFailure( + fromElements, "first element not from vector.extract"); } + VectorType sourceType = extractOp.getSourceVectorType(); + Value source = extractOp.getVector(); - // Check that the (linearized) index of extraction is the same as the index - // in the result of `fromElementsOp`. - ArrayRef<int64_t> position = extractOp.getStaticPosition(); - assert(position.size() == rank && - "scalar extract must have full rank position"); - int64_t stride{1}; - int64_t offset{0}; - for (auto [pos, size] : - llvm::zip(llvm::reverse(position), llvm::reverse(shape))) { - if (pos == ShapedType::kDynamic) - return failure(); - offset += pos * stride; - stride *= size; + // Check condition (ii). + if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) { + return rewriter.notifyMatchFailure(fromElements, + "number of elements differ"); } - if (offset != index) - return failure(); - } - rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp, - fromElementsOp.getType(), source); -} + for (auto [indexMinusOne, element] : + llvm::enumerate(elements.drop_front(1))) { + + extractOp = + dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); + if (!extractOp) { + return rewriter.notifyMatchFailure(fromElements, + "element not from vector.extract"); + } + Value currentSource = extractOp.getVector(); + // Check condition (i). 
+ if (currentSource != source) { + return rewriter.notifyMatchFailure(fromElements, + "element from different vector"); + } + + ArrayRef<int64_t> position = extractOp.getStaticPosition(); + assert(position.size() == static_cast<size_t>(sourceType.getRank()) && + "scalar extract must have full rank position"); + int64_t stride{1}; + int64_t offset{0}; + for (auto [pos, size] : llvm::zip(llvm::reverse(position), + llvm::reverse(sourceType.getShape()))) { + if (pos == ShapedType::kDynamic) { + return rewriter.notifyMatchFailure( + fromElements, "elements not in ascending order (dynamic order)"); + } + offset += pos * stride; + stride *= size; + } + // Check condition (iii). + if (offset != static_cast<int64_t>(indexMinusOne + 1)) { + return rewriter.notifyMatchFailure( + fromElements, "elements not in ascending order (static order)"); + } + } + + rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements, + fromElements.getType(), source); + return success(); + } +}; void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(rewriteFromElementsAsSplat); - results.add(rewriteFromElementsAsShapeCast); + results.add<FromElementsToShapCast>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index fafac4419d719..2899abb07c97c 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -143,3 +143,14 @@ func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi %2 = vector.from_elements %0, %1 : vector<2xi8> return %2 : vector<2xi8> } + +// ----- + +// CHECK-LABEL: func @negative_source_too_large( +// CHECK-NOT: shape_cast +func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from 
vector<1x3xi8> + %1 = vector.extract %arg0[0, 1] : i8 
from vector<1x3xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} From 28fccebe260d0ddf9bdb48ebde5efe36d7967516 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 14 May 2025 12:40:11 -0700 Subject: [PATCH 05/10] spacing nit --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 7b7f014480ccd..1080263ed3eb6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2414,8 +2414,8 @@ class FromElementsToShapCast : public OpRewritePattern { mlir::OperandRange elements = fromElements.getElements(); assert(!elements.empty() && "must be at least 1 element"); - Value firstElement = elements.front(); + ExtractOp extractOp = dyn_cast_if_present(firstElement.getDefiningOp()); if (!extractOp) { From e985a7ef3fcfda4fc435218fba6349611e5deae0 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 14 May 2025 12:47:25 -0700 Subject: [PATCH 06/10] fix test grouping title --- .../Vector/canonicalize/vector-from-elements.mlir | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index 2899abb07c97c..14bf5d9df4783 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s -// This file contains some tests of folding/canonicalizing vector.from_elements +// This file contains some tests of folding/canonicalizing vector.from_elements ///===----------------------------------------------===// /// Tests of `rewriteFromElementsAsSplat` @@ -75,9 +75,8 @@ func.func @from_elements_to_splat(%a: f32, %b: 
f32) -> (vector<2x3xf32>, vector< // ----- - ///===----------------------------------------------===// -/// Tests of `rewriteFromElementsAsShapeCast` +/// Tests of `FromElementsToShapeCast` ///===----------------------------------------------===// // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( @@ -112,7 +111,7 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> // ----- -// The extracted elements are recombined into a single vector, but in a new order. +// The extracted elements are recombined into a single vector, but in a new order. // CHECK-LABEL: func @negative_nonascending_order( // CHECK-NOT: shape_cast func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { @@ -122,7 +121,7 @@ func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { return %2 : vector<2xi8> } -// ----- +// ----- // CHECK-LABEL: func @negative_nonstatic_extract( // CHECK-NOT: shape_cast From 74d74f9596c86754675e35509f2aaa2106d5be80 Mon Sep 17 00:00:00 2001 From: James Newling Date: Fri, 16 May 2025 09:12:59 -0700 Subject: [PATCH 07/10] initial pass of review comments. 
Add additional test --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 81 ++++++++++--------- .../canonicalize/vector-from-elements.mlir | 56 ++++++++----- 2 files changed, 79 insertions(+), 58 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1080263ed3eb6..e671dddad3a8f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2397,13 +2397,13 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, /// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8> /// /// The requirements for this to be valid are -/// i) all elements are extracted from the same vector (source), -/// ii) source and from_elements result have the same number of elements, +/// i) source and from_elements result have the same number of elements, +/// ii) all elements are extracted from the same vector (%source), /// iii) the elements are extracted in ascending order. /// /// It might be possible to rewrite vector.from_elements as a single -/// vector.extract if (ii) is not satisifed, or in some cases as a -/// a single vector_extract_strided_slice if (ii) and (iii) are not satisfied, +/// vector.extract if (i) is not satisifed, or in some cases as a +/// a single vector_extract_strided_slice if (i) and (iii) are not satisfied, /// this is left for future consideration. class FromElementsToShapCast : public OpRewritePattern { public: @@ -2412,64 +2412,71 @@ class FromElementsToShapCast : public OpRewritePattern { LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { - mlir::OperandRange elements = fromElements.getElements(); - assert(!elements.empty() && "must be at least 1 element"); - Value firstElement = elements.front(); + // The source of the first element. This is initialized in the first + // iteration of the loop over elements. 
+ TypedValue firstElementSource; - ExtractOp extractOp = - dyn_cast_if_present(firstElement.getDefiningOp()); - if (!extractOp) { - return rewriter.notifyMatchFailure( - fromElements, "first element not from vector.extract"); - } - VectorType sourceType = extractOp.getSourceVectorType(); - Value source = extractOp.getVector(); - - // Check condition (ii). - if (static_cast(sourceType.getNumElements()) != elements.size()) { - return rewriter.notifyMatchFailure(fromElements, - "number of elements differ"); - } + for (auto [insertIndex, element] : + llvm::enumerate(fromElements.getElements())) { - for (auto [indexMinusOne, element] : - llvm::enumerate(elements.drop_front(1))) { - - extractOp = + // Check that the element is from a vector.extract operation. + auto extractOp = dyn_cast_if_present(element.getDefiningOp()); if (!extractOp) { return rewriter.notifyMatchFailure(fromElements, "element not from vector.extract"); } + + // Check condition (i) on the first element. As we will check that all + // elements have the same source, we don't need to check condition (i) for + // any other elements. + if (insertIndex == 0) { + firstElementSource = extractOp.getVector(); + if (static_cast( + firstElementSource.getType().getNumElements()) != + fromElements.getType().getNumElements()) { + return rewriter.notifyMatchFailure(fromElements, + "number of elements differ"); + } + } + + // Check condition (ii), by checking that all elements have same source as + // the first element. Value currentSource = extractOp.getVector(); - // Check condition (i). - if (currentSource != source) { + if (currentSource != firstElementSource) { return rewriter.notifyMatchFailure(fromElements, "element from different vector"); } + // Check condition (iii). + // First, get the index that the element is extracted from. 
+ int64_t extractIndex{0}; + int64_t stride{1}; ArrayRef position = extractOp.getStaticPosition(); - assert(position.size() == static_cast(sourceType.getRank()) && + assert(position.size() == + static_cast(firstElementSource.getType().getRank()) && "scalar extract must have full rank position"); - int64_t stride{1}; - int64_t offset{0}; - for (auto [pos, size] : llvm::zip(llvm::reverse(position), - llvm::reverse(sourceType.getShape()))) { + for (auto [pos, size] : + llvm::zip(llvm::reverse(position), + llvm::reverse(firstElementSource.getType().getShape()))) { if (pos == ShapedType::kDynamic) { return rewriter.notifyMatchFailure( fromElements, "elements not in ascending order (dynamic order)"); } - offset += pos * stride; + extractIndex += pos * stride; stride *= size; } - // Check condition (iii). - if (offset != static_cast(indexMinusOne + 1)) { + + // Second, check that the index of extraction from source and insertion in + // from_elements are the same. + if (extractIndex != static_cast(insertIndex)) { return rewriter.notifyMatchFailure( fromElements, "elements not in ascending order (static order)"); } } - rewriter.replaceOpWithNewOp(fromElements, - fromElements.getType(), source); + rewriter.replaceOpWithNewOp( + fromElements, fromElements.getType(), firstElementSource); return success(); } }; diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index 14bf5d9df4783..49641eced607f 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -7,7 +7,7 @@ ///===----------------------------------------------===// // CHECK-LABEL: func @extract_scalar_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { // Extract from 0D. 
%0 = vector.from_elements %a : vector @@ -26,50 +26,50 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32 %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32> %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32> - // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]] + // CHECK: return %[[A]], %[[A]], %[[B]], %[[A]], %[[A]], %[[B]], %[[B]] return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32 } // ----- // CHECK-LABEL: func @extract_1d_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> - // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32> + // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32> %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> - // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32> + // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32> %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> - // CHECK: return %[[splat1]], %[[splat2]] + // CHECK: return %[[SPLAT1]], %[[SPLAT2]] return %1, %2 : vector<3xf32>, vector<3xf32> } // ----- // CHECK-LABEL: func @extract_2d_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) { %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32> - // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32> + // CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32> %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32> - // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], 
%[[b]], %[[a]] : vector<2x2xf32> + // CHECK: %[[SPLAT2:.*]] = vector.from_elements %[[B]], %[[B]], %[[B]], %[[A]] : vector<2x2xf32> %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32> - // CHECK: return %[[splat1]], %[[splat2]] + // CHECK: return %[[SPLAT1]], %[[SPLAT2]] return %1, %2 : vector<2x2xf32>, vector<2x2xf32> } // ----- // CHECK-LABEL: func @from_elements_to_splat( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector) { - // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32> + // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32> %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> - // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> + // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> - // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector + // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector %2 = vector.from_elements %a : vector - // CHECK: return %[[splat]], %[[from_el]], %[[splat2]] + // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]] return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector } @@ -80,9 +80,9 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< ///===----------------------------------------------===// // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( -// CHECK-SAME: %[[a:.*]]: vector<1x2xi8>) -// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8> -// CHECK: return %[[shape_cast]] : vector<2xi8> +// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<2xi8> func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { %0 = 
vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> @@ -93,9 +93,9 @@ func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { // ----- // CHECK-LABEL: func @to_shape_cast_rank1_to_rank3( -// CHECK-SAME: %[[a:.*]]: vector<8xi8>) -// CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8> -// CHECK: return %[[shape_cast]] : vector<2x2x2xi8> +// CHECK-SAME: %[[A:.*]]: vector<8xi8>) +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<2x2x2xi8> func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> { %0 = vector.extract %arg0[0] : i8 from vector<8xi8> %1 = vector.extract %arg0[1] : i8 from vector<8xi8> @@ -153,3 +153,17 @@ func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> { %2 = vector.from_elements %0, %1 : vector<2xi8> return %2 : vector<2xi8> } + +// ----- + +// The inserted elements are are a subset of the extracted elements. 
+// [0, 1, 2] -> [1, 1, 2] +// CHECK-LABEL: func @negative_nobijection_order( +// CHECK-NOT: shape_cast +func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> + %1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8> + %2 = vector.from_elements %0, %0, %1 : vector<3xi8> + return %2 : vector<3xi8> +} + From c584ac8b13076132af105f592a91875b5cb3fc6d Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 19 May 2025 09:24:40 -0700 Subject: [PATCH 08/10] additional tests --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 108 ++++++++++++------ .../canonicalize/vector-from-elements.mlir | 36 +++++- 2 files changed, 104 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index e671dddad3a8f..ddc01dfd9f1fa 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2379,6 +2379,7 @@ std::optional> FMAOp::getShapeForUnroll() { /// ==> rewrite to vector.splat %a : vector<3xf32> static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter) { + if (!llvm::all_equal(fromElementsOp.getElements())) return failure(); rewriter.replaceOpWithNewOp(fromElementsOp, fromElementsOp.getType(), @@ -2386,35 +2387,44 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, return success(); } -/// Rewrite vector.from_elements as vector.shape_cast, if possible. +/// Rewrite vector.from_elements(vector.extract, vector.extract, ...) as +/// vector.shape_cast(vector.extact) if possible. 
/// /// Example: -/// %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8> -/// %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8> +/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8> +/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8> /// %2 = vector.from_elements %0, %1 : vector<2xi8> /// /// becomes -/// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8> +/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8> +/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8> /// /// The requirements for this to be valid are -/// i) source and from_elements result have the same number of elements, -/// ii) all elements are extracted from the same vector (%source), -/// iii) the elements are extracted in ascending order. /// -/// It might be possible to rewrite vector.from_elements as a single -/// vector.extract if (i) is not satisifed, or in some cases as a -/// a single vector_extract_strided_slice if (i) and (iii) are not satisfied, -/// this is left for future consideration. -class FromElementsToShapCast : public OpRewritePattern { +/// i) all elements are extracted from the same vector (%source) +/// ii) the elements form a suffix of %source +/// iii) the elements are extracted contiguously in ascending order + +class FromElementsToShapeCast + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { - // The source of the first element. This is initialized in the first - // iteration of the loop over elements. + // Left for `rewriteFromElementsAsSplat` to avoid divergent + // canonicalizations + if (fromElements.getType().getNumElements() == 1) { + return failure(); + } + + // The source of the first element, the position (N-d vector) that the first + // element is extracted from, and the flattened position (index). 
These are + // all obtained in the first iteration of the loop over elements. TypedValue firstElementSource; + ArrayRef firstElementExtractPosition; + int64_t firstElementExtractIndex; for (auto [insertIndex, element] : llvm::enumerate(fromElements.getElements())) { @@ -2427,32 +2437,19 @@ class FromElementsToShapCast : public OpRewritePattern { "element not from vector.extract"); } - // Check condition (i) on the first element. As we will check that all - // elements have the same source, we don't need to check condition (i) for - // any other elements. + // Check condition (i) by checking that all elements have same source as + // the first element. if (insertIndex == 0) { firstElementSource = extractOp.getVector(); - if (static_cast( - firstElementSource.getType().getNumElements()) != - fromElements.getType().getNumElements()) { - return rewriter.notifyMatchFailure(fromElements, - "number of elements differ"); - } - } - - // Check condition (ii), by checking that all elements have same source as - // the first element. - Value currentSource = extractOp.getVector(); - if (currentSource != firstElementSource) { + } else if (extractOp.getVector() != firstElementSource) { return rewriter.notifyMatchFailure(fromElements, "element from different vector"); } - // Check condition (iii). - // First, get the index that the element is extracted from. + // Obtain the flattened index of extraction from the N-d position. + ArrayRef position = extractOp.getStaticPosition(); int64_t extractIndex{0}; int64_t stride{1}; - ArrayRef position = extractOp.getStaticPosition(); assert(position.size() == static_cast(firstElementSource.getType().getRank()) && "scalar extract must have full rank position"); @@ -2467,16 +2464,51 @@ class FromElementsToShapCast : public OpRewritePattern { stride *= size; } - // Second, check that the index of extraction from source and insertion in - // from_elements are the same. 
- if (extractIndex != static_cast(insertIndex)) { + // Check condition (ii) using the extraction index of the first element. + // We check that the position that the first element is extracted + // from has sufficient trailing 0s. For example, in + // ``` + // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8> + // [...] + // %n = vector.from_elements %elm0, [...] : vector<12xi8> + // ``` + // The 2 trailing 0s in the position of extraction of %0 cover 3*4 = 12 + // elements, which is the number of elements of %n, so this is valid. + if (insertIndex == 0) { + const int64_t numFinalElements = + fromElements.getType().getNumElements(); + int64_t numElementsInSourceSuffix = 1; + int index = position.size(); + while (index > 0 && position[index - 1] == 0 && + numElementsInSourceSuffix < numFinalElements) { + numElementsInSourceSuffix *= + firstElementSource.getType().getDimSize(index - 1); + --index; + } + if (numElementsInSourceSuffix != numFinalElements) { + return rewriter.notifyMatchFailure( + fromElements, "elements do not form a suffix of source"); + } + firstElementExtractIndex = extractIndex; + firstElementExtractPosition = + position.drop_back(position.size() - index); + } + + // Check condition (iii) by checking the index of extraction relative + // the first element. 
+ else if (static_cast(insertIndex) + firstElementExtractIndex != + extractIndex) { return rewriter.notifyMatchFailure( fromElements, "elements not in ascending order (static order)"); } } - rewriter.replaceOpWithNewOp( - fromElements, fromElements.getType(), firstElementSource); + auto extracted = rewriter.createOrFold( + fromElements.getLoc(), firstElementSource, firstElementExtractPosition); + + rewriter.replaceOpWithNewOp( + fromElements, fromElements.getType(), extracted); + return success(); } }; @@ -2484,7 +2516,7 @@ class FromElementsToShapCast : public OpRewritePattern { void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(rewriteFromElementsAsSplat); - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index 49641eced607f..ef7bfacdf1a5e 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt --mlir-disable-threading %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s // This file contains some tests of folding/canonicalizing vector.from_elements @@ -109,6 +109,39 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> return %8 : vector<2x2x2xi8> } + +// ----- + +// func.func @bar(%arg0: vector<2x3x4xi8>) -> vector<12xi8> { +// %0 = vector.extract %arg0[1] : vector<3x4xi8> from vector<2x3x4xi8> +// %1 = vector.shape_cast %0 : vector<3x4xi8> to vector<12xi8> +// return %1 : vector<12xi8> + +// CHECK-LABEL: func @source_larger_than_out( +// CHECK-SAME: %[[A:.*]]: 
vector<2x3x4xi8>) +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]] [1] : vector<3x4xi8> from vector<2x3x4xi8> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<12xi8> + +func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> { + %0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8> + %1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8> + %2 = vector.extract %arg0[1, 0, 2] : i8 from vector<2x3x4xi8> + %3 = vector.extract %arg0[1, 0, 3] : i8 from vector<2x3x4xi8> + %4 = vector.extract %arg0[1, 1, 0] : i8 from vector<2x3x4xi8> + %5 = vector.extract %arg0[1, 1, 1] : i8 from vector<2x3x4xi8> + %6 = vector.extract %arg0[1, 1, 2] : i8 from vector<2x3x4xi8> + %7 = vector.extract %arg0[1, 1, 3] : i8 from vector<2x3x4xi8> + %8 = vector.extract %arg0[1, 2, 0] : i8 from vector<2x3x4xi8> + %9 = vector.extract %arg0[1, 2, 1] : i8 from vector<2x3x4xi8> + %10 = vector.extract %arg0[1, 2, 2] : i8 from vector<2x3x4xi8> + %11 = vector.extract %arg0[1, 2, 3] : i8 from vector<2x3x4xi8> + %12 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11 : vector<12xi8> + return %12 : vector<12xi8> +} + +// TODO(newling) add more tests where the source is not the same size as out. + // ----- // The extracted elements are recombined into a single vector, but in a new order. 
@@ -166,4 +199,3 @@ func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> { %2 = vector.from_elements %0, %0, %1 : vector<3xi8> return %2 : vector<3xi8> } - From e154b2248a89d34a6936476761be07e606a226b0 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 20 May 2025 11:29:12 -0700 Subject: [PATCH 09/10] generalize to extract cast --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 117 ++++++++--------- .../canonicalize/vector-from-elements.mlir | 121 ++++++++++++++---- 2 files changed, 149 insertions(+), 89 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ddc01dfd9f1fa..311c2b6387433 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2387,10 +2387,8 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, return success(); } -/// Rewrite vector.from_elements(vector.extract, vector.extract, ...) as -/// vector.shape_cast(vector.extact) if possible. -/// -/// Example: +/// Rewrite from_elements on multiple scalar extracts as a shape_cast +/// on a single extract. Example: /// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8> /// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8> /// %2 = vector.from_elements %0, %1 : vector<2xi8> @@ -2401,30 +2399,32 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, /// /// The requirements for this to be valid are /// -/// i) all elements are extracted from the same vector (%source) -/// ii) the elements form a suffix of %source -/// iii) the elements are extracted contiguously in ascending order +/// i) The elements are extracted from the same vector (%source). +/// +/// ii) The elements form a suffix of %source. Specifically, the number +/// of elements is the same as the product of the last N dimension sizes +/// of %source, for some N. +/// +/// iii) The elements are extracted contiguously in ascending order. 
+ +class FromElementsToShapeCast : public OpRewritePattern { -class FromElementsToShapeCast - : public OpRewritePattern { -public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { - // Left for `rewriteFromElementsAsSplat` to avoid divergent - // canonicalizations - if (fromElements.getType().getNumElements() == 1) { + // Handled by `rewriteFromElementsAsSplat` + if (fromElements.getType().getNumElements() == 1) return failure(); - } - // The source of the first element, the position (N-d vector) that the first - // element is extracted from, and the flattened position (index). These are - // all obtained in the first iteration of the loop over elements. - TypedValue firstElementSource; - ArrayRef firstElementExtractPosition; - int64_t firstElementExtractIndex; + // The common source that all elements are extracted from, if one exists. + TypedValue source; + // The position of the combined extract operation, if one is created. + ArrayRef combinedPosition; + // The expected index of extraction of the current element in the loop, if + // elements are extracted contiguously in ascending order. + SmallVector expectedPosition; for (auto [insertIndex, element] : llvm::enumerate(fromElements.getElements())) { @@ -2440,77 +2440,70 @@ class FromElementsToShapeCast // Check condition (i) by checking that all elements have same source as // the first element. if (insertIndex == 0) { - firstElementSource = extractOp.getVector(); - } else if (extractOp.getVector() != firstElementSource) { + source = extractOp.getVector(); + } else if (extractOp.getVector() != source) { return rewriter.notifyMatchFailure(fromElements, "element from different vector"); } - // Obtain the flattened index of extraction from the N-d position. 
ArrayRef position = extractOp.getStaticPosition(); - int64_t extractIndex{0}; - int64_t stride{1}; - assert(position.size() == - static_cast(firstElementSource.getType().getRank()) && + int64_t rank = position.size(); + assert(rank == source.getType().getRank() && "scalar extract must have full rank position"); - for (auto [pos, size] : - llvm::zip(llvm::reverse(position), - llvm::reverse(firstElementSource.getType().getShape()))) { - if (pos == ShapedType::kDynamic) { - return rewriter.notifyMatchFailure( - fromElements, "elements not in ascending order (dynamic order)"); - } - extractIndex += pos * stride; - stride *= size; - } - // Check condition (ii) using the extraction index of the first element. - // We check that the position that the first element is extracted - // from has sufficient trailing 0s. For example, in - // ``` - // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8> - // [...] - // %n = vector.from_elements %elm0, [...] : vector<12xi8> - // ``` - // The 2 trailing 0s in the position of extraction of %0 cover 3*4 = 12 + // Check condition (ii) by checking that the position that the first + // element is extracted from has sufficient trailing 0s. For example, in + // + // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8> + // [...] + // %elms = vector.from_elements %elm0, [...] : vector<12xi8> + // + // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12 // elements, which is the number of elements of %n, so this is valid. 
if (insertIndex == 0) { - const int64_t numFinalElements = - fromElements.getType().getNumElements(); - int64_t numElementsInSourceSuffix = 1; - int index = position.size(); + const int64_t numElms = fromElements.getType().getNumElements(); + int64_t numSuffixElms = 1; + int64_t index = rank; while (index > 0 && position[index - 1] == 0 && - numElementsInSourceSuffix < numFinalElements) { - numElementsInSourceSuffix *= - firstElementSource.getType().getDimSize(index - 1); + numSuffixElms < numElms) { + numSuffixElms *= source.getType().getDimSize(index - 1); --index; } - if (numElementsInSourceSuffix != numFinalElements) { + if (numSuffixElms != numElms) { return rewriter.notifyMatchFailure( fromElements, "elements do not form a suffix of source"); } - firstElementExtractIndex = extractIndex; - firstElementExtractPosition = - position.drop_back(position.size() - index); + expectedPosition = llvm::to_vector(position); + combinedPosition = position.drop_back(rank - index); } - // Check condition (iii) by checking the index of extraction relative - // the first element. - else if (static_cast(insertIndex) + firstElementExtractIndex != - extractIndex) { + // Check condition (iii). + else if (expectedPosition != position) { return rewriter.notifyMatchFailure( fromElements, "elements not in ascending order (static order)"); } + increment(expectedPosition, source.getType().getShape()); } auto extracted = rewriter.createOrFold( - fromElements.getLoc(), firstElementSource, firstElementExtractPosition); + fromElements.getLoc(), source, combinedPosition); rewriter.replaceOpWithNewOp( fromElements, fromElements.getType(), extracted); return success(); } + + /// Increments n-D `indices` by 1 starting from the innermost dimension. 
+ static void increment(MutableArrayRef indices, + ArrayRef shape) { + for (int dim : llvm::reverse(llvm::seq(0, indices.size()))) { + indices[dim] += 1; + if (indices[dim] < shape[dim]) + break; + indices[dim] = 0; + } + } }; void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index ef7bfacdf1a5e..fdab2a8918a2e 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --mlir-disable-threading %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s // This file contains some tests of folding/canonicalizing vector.from_elements @@ -7,7 +7,7 @@ ///===----------------------------------------------===// // CHECK-LABEL: func @extract_scalar_from_from_elements( -// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { // Extract from 0D. 
%0 = vector.from_elements %a : vector @@ -33,7 +33,7 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32 // ----- // CHECK-LABEL: func @extract_1d_from_from_elements( -// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32> @@ -47,7 +47,7 @@ func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, ve // ----- // CHECK-LABEL: func @extract_2d_from_from_elements( -// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) { %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32> // CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32> @@ -61,7 +61,7 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, // ----- // CHECK-LABEL: func @from_elements_to_splat( -// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector) { // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32> %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> @@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( // CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) -// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8> -// CHECK: return %[[SHAPE_CAST]] : vector<2xi8> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8> +// CHECK: return %[[EXTRACT]] : 
vector<2xi8> func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> @@ -109,20 +109,13 @@ func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> return %8 : vector<2x2x2xi8> } - // ----- -// func.func @bar(%arg0: vector<2x3x4xi8>) -> vector<12xi8> { -// %0 = vector.extract %arg0[1] : vector<3x4xi8> from vector<2x3x4xi8> -// %1 = vector.shape_cast %0 : vector<3x4xi8> to vector<12xi8> -// return %1 : vector<12xi8> - // CHECK-LABEL: func @source_larger_than_out( -// CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>) -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]] [1] : vector<3x4xi8> from vector<2x3x4xi8> -// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8> -// CHECK: return %[[SHAPE_CAST]] : vector<12xi8> - +// CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>) +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<12xi8> func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> { %0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8> %1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8> @@ -140,13 +133,70 @@ func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> { return %12 : vector<12xi8> } -// TODO(newling) add more tests where the source is not the same size as out. +// ----- + +// This test is similar to `source_larger_than_out` except here the number of elements +// extracted contiguously starting from the first position [0,0] could be 6 instead of 3 +// and the pattern would still match. 
+// CHECK-LABEL: func @suffix_with_excess_zeros( +// CHECK: %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8> +// CHECK: return %[[EXT]] : vector<3xi8> +func.func @suffix_with_excess_zeros(%arg0: vector<2x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> + %1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8> + %2 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8> + %3 = vector.from_elements %0, %1, %2 : vector<3xi8> + return %3 : vector<3xi8> +} + +// ----- + +// CHECK-LABEL: func @large_source_with_shape_cast_required( +// CHECK-SAME: %[[A:.*]]: vector<2x2x2x2xi8>) +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<1x4x1xi8> +func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> vector<1x4x1xi8> { + %0 = vector.extract %arg0[0, 1, 0, 0] : i8 from vector<2x2x2x2xi8> + %1 = vector.extract %arg0[0, 1, 0, 1] : i8 from vector<2x2x2x2xi8> + %2 = vector.extract %arg0[0, 1, 1, 0] : i8 from vector<2x2x2x2xi8> + %3 = vector.extract %arg0[0, 1, 1, 1] : i8 from vector<2x2x2x2xi8> + %4 = vector.from_elements %0, %1, %2, %3 : vector<1x4x1xi8> + return %4 : vector<1x4x1xi8> +} + +// ----- + +// Could match, but handled by `rewriteFromElementsAsSplat`. 
+// CHECK-LABEL: func @extract_single_elm( +// CHECK-NEXT: vector.extract +// CHECK-NEXT: vector.splat +// CHECK-NEXT: return +func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> + %1 = vector.from_elements %0 : vector<1xi8> + return %1 : vector<1xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_source_contiguous_but_not_suffix( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_contiguous_but_not_suffix(%arg0: vector<2x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8> + %1 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8> + %2 = vector.extract %arg0[1, 0] : i8 from vector<2x3xi8> + %3 = vector.from_elements %0, %1, %2 : vector<3xi8> + return %3 : vector<3xi8> +} // ----- // The extracted elements are recombined into a single vector, but in a new order. // CHECK-LABEL: func @negative_nonascending_order( -// CHECK-NOT: shape_cast +// CHECK-NOT: shape_cast +// CHECK: from_elements func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> @@ -157,7 +207,8 @@ func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { // ----- // CHECK-LABEL: func @negative_nonstatic_extract( -// CHECK-NOT: shape_cast +// CHECK-NOT: shape_cast +// CHECK: from_elements func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> { %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8> %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8> @@ -168,7 +219,8 @@ func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : // ----- // CHECK-LABEL: func @negative_different_sources( -// CHECK-NOT: shape_cast +// CHECK-NOT: shape_cast +// CHECK: from_elements func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: 
vector<1x2xi8>) -> vector<2xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8> @@ -178,9 +230,10 @@ func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi // ----- -// CHECK-LABEL: func @negative_source_too_large( -// CHECK-NOT: shape_cast -func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> { +// CHECK-LABEL: func @negative_source_not_suffix( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_not_suffix(%arg0: vector<1x3xi8>) -> vector<2xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8> %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> %2 = vector.from_elements %0, %1 : vector<2xi8> @@ -189,13 +242,27 @@ func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> { // ----- -// The inserted elements are are a subset of the extracted elements. +// The inserted elements are a subset of the extracted elements. 
// [0, 1, 2] -> [1, 1, 2] // CHECK-LABEL: func @negative_nobijection_order( -// CHECK-NOT: shape_cast +// CHECK-NOT: shape_cast +// CHECK: from_elements func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> { %0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> %1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8> %2 = vector.from_elements %0, %0, %1 : vector<3xi8> return %2 : vector<3xi8> } + +// ----- + +// CHECK-LABEL: func @negative_source_too_small( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> { + %0 = vector.extract %arg0[0] : i8 from vector<2xi8> + %1 = vector.extract %arg0[1] : i8 from vector<2xi8> + %2 = vector.from_elements %0, %1, %1, %1 : vector<4xi8> + return %2 : vector<4xi8> +} + From dd3e231eefc4a9c6e9bf000af149537d87fa8b4d Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 29 May 2025 09:57:43 -0700 Subject: [PATCH 10/10] uber refinement (empty line removal and definite article use..) --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 311c2b6387433..9b6e17cd5abbc 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2379,7 +2379,6 @@ std::optional> FMAOp::getShapeForUnroll() { /// ==> rewrite to vector.splat %a : vector<3xf32> static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, PatternRewriter &rewriter) { - if (!llvm::all_equal(fromElementsOp.getElements())) return failure(); rewriter.replaceOpWithNewOp(fromElementsOp, fromElementsOp.getType(), @@ -2437,8 +2436,8 @@ class FromElementsToShapeCast : public OpRewritePattern { "element not from vector.extract"); } - // Check condition (i) by checking that all elements have same source as - // the first element. 
+ // Check condition (i) by checking that all elements have the same source + // as the first element. if (insertIndex == 0) { source = extractOp.getVector(); } else if (extractOp.getVector() != source) {