diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 890a5e9e5c9b4..fcfb401fd9867 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -33,6 +33,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/SubsetOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Support/LLVM.h" @@ -2387,9 +2388,129 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, return success(); } +/// Rewrite from_elements on multiple scalar extracts as a shape_cast +/// on a single extract. Example: +/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8> +/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8> +/// %2 = vector.from_elements %0, %1 : vector<2xi8> +/// +/// becomes +/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8> +/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8> +/// +/// The requirements for this to be valid are +/// +/// i) The elements are extracted from the same vector (%source). +/// +/// ii) The elements form a suffix of %source. Specifically, the number +/// of elements is the same as the product of the last N dimension sizes +/// of %source, for some N. +/// +/// iii) The elements are extracted contiguously in ascending order. + +class FromElementsToShapeCast : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FromElementsOp fromElements, + PatternRewriter &rewriter) const override { + + // Handled by `rewriteFromElementsAsSplat` + if (fromElements.getType().getNumElements() == 1) + return failure(); + + // The common source that all elements are extracted from, if one exists. + TypedValue source; + // The position of the combined extract operation, if one is created. + ArrayRef combinedPosition; + // The expected index of extraction of the current element in the loop, if + // elements are extracted contiguously in ascending order. + SmallVector expectedPosition; + + for (auto [insertIndex, element] : + llvm::enumerate(fromElements.getElements())) { + + // Check that the element is from a vector.extract operation. + auto extractOp = + dyn_cast_if_present(element.getDefiningOp()); + if (!extractOp) { + return rewriter.notifyMatchFailure(fromElements, + "element not from vector.extract"); + } + + // Check condition (i) by checking that all elements have the same source + // as the first element. + if (insertIndex == 0) { + source = extractOp.getVector(); + } else if (extractOp.getVector() != source) { + return rewriter.notifyMatchFailure(fromElements, + "element from different vector"); + } + + ArrayRef position = extractOp.getStaticPosition(); + int64_t rank = position.size(); + assert(rank == source.getType().getRank() && + "scalar extract must have full rank position"); + + // Check condition (ii) by checking that the position that the first + // element is extracted from has sufficient trailing 0s. For example, in + // + // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8> + // [...] + // %elms = vector.from_elements %elm0, [...] : vector<12xi8> + // + // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12 + // elements, which is the number of elements of %n, so this is valid. + if (insertIndex == 0) { + const int64_t numElms = fromElements.getType().getNumElements(); + int64_t numSuffixElms = 1; + int64_t index = rank; + while (index > 0 && position[index - 1] == 0 && + numSuffixElms < numElms) { + numSuffixElms *= source.getType().getDimSize(index - 1); + --index; + } + if (numSuffixElms != numElms) { + return rewriter.notifyMatchFailure( + fromElements, "elements do not form a suffix of source"); + } + expectedPosition = llvm::to_vector(position); + combinedPosition = position.drop_back(rank - index); + } + + // Check condition (iii). + else if (expectedPosition != position) { + return rewriter.notifyMatchFailure( + fromElements, "elements not in ascending order (static order)"); + } + increment(expectedPosition, source.getType().getShape()); + } + + auto extracted = rewriter.createOrFold( + fromElements.getLoc(), source, combinedPosition); + + rewriter.replaceOpWithNewOp( + fromElements, fromElements.getType(), extracted); + + return success(); + } + + /// Increments n-D `indices` by 1 starting from the innermost dimension. + static void increment(MutableArrayRef indices, + ArrayRef shape) { + for (int dim : llvm::reverse(llvm::seq(0, indices.size()))) { + indices[dim] += 1; + if (indices[dim] < shape[dim]) + break; + indices[dim] = 0; + } + } +}; + void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(rewriteFromElementsAsSplat); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index a6543aafd1c77..a06a9f67d54dc 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2943,75 +2943,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector, // ----- -// CHECK-LABEL: func @extract_scalar_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) -func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { - // Extract from 0D. - %0 = vector.from_elements %a : vector - %1 = vector.extract %0[] : f32 from vector - - // Extract from 1D. - %2 = vector.from_elements %a : vector<1xf32> - %3 = vector.extract %2[0] : f32 from vector<1xf32> - %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32> - %5 = vector.extract %4[4] : f32 from vector<5xf32> - - // Extract from 2D. - %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> - %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32> - %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> - %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32> - %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32> - - // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]] - return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32 -} - -// ----- - -// CHECK-LABEL: func @extract_1d_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) -func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { - %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> - // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32> - %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> - // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32> - %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> - // CHECK: return %[[splat1]], %[[splat2]] - return %1, %2 : vector<3xf32>, vector<3xf32> -} - -// ----- - -// CHECK-LABEL: func @extract_2d_from_from_elements( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) -func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) { - %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32> - // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32> - %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32> - // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32> - %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32> - // CHECK: return %[[splat1]], %[[splat2]] - return %1, %2 : vector<2x2xf32>, vector<2x2xf32> -} - -// ----- - -// CHECK-LABEL: func @from_elements_to_splat( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) -func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector) { - // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32> - %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> - // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> - %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> - // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector - %2 = vector.from_elements %a : vector - // CHECK: return %[[splat]], %[[from_el]], %[[splat2]] - return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector -} - -// ----- - // CHECK-LABEL: func @vector_insert_const_regression( // CHECK: llvm.mlir.undef // CHECK: vector.insert diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir new file mode 100644 index 0000000000000..fdab2a8918a2e --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -0,0 +1,268 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s + +// This file contains some tests of folding/canonicalizing vector.from_elements + +///===----------------------------------------------===// +/// Tests of `rewriteFromElementsAsSplat` +///===----------------------------------------------===// + +// CHECK-LABEL: func @extract_scalar_from_from_elements( +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) +func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { + // Extract from 0D. + %0 = vector.from_elements %a : vector + %1 = vector.extract %0[] : f32 from vector + + // Extract from 1D. + %2 = vector.from_elements %a : vector<1xf32> + %3 = vector.extract %2[0] : f32 from vector<1xf32> + %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32> + %5 = vector.extract %4[4] : f32 from vector<5xf32> + + // Extract from 2D. + %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> + %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32> + %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> + %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32> + %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32> + + // CHECK: return %[[A]], %[[A]], %[[B]], %[[A]], %[[A]], %[[B]], %[[B]] + return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: func @extract_1d_from_from_elements( +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) +func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { + %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> + // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32> + %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> + // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32> + %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> + // CHECK: return %[[SPLAT1]], %[[SPLAT2]] + return %1, %2 : vector<3xf32>, vector<3xf32> +} + +// ----- + +// CHECK-LABEL: func @extract_2d_from_from_elements( +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) +func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) { + %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32> + // CHECK: %[[SPLAT1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32> + %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32> + // CHECK: %[[SPLAT2:.*]] = vector.from_elements %[[B]], %[[B]], %[[B]], %[[A]] : vector<2x2xf32> + %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32> + // CHECK: return %[[SPLAT1]], %[[SPLAT2]] + return %1, %2 : vector<2x2xf32>, vector<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @from_elements_to_splat( +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) +func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector) { + // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32> + %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> + // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> + %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> + // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector + %2 = vector.from_elements %a : vector + // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]] + return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector +} + +// ----- + +///===----------------------------------------------===// +/// Tests of `FromElementsToShapeCast` +///===----------------------------------------------===// + +// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( +// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8> +// CHECK: return %[[EXTRACT]] : vector<2xi8> +func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> + %4 = vector.from_elements %0, %1 : vector<2xi8> + return %4 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3( +// CHECK-SAME: %[[A:.*]]: vector<8xi8>) +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<2x2x2xi8> +func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> { + %0 = vector.extract %arg0[0] : i8 from vector<8xi8> + %1 = vector.extract %arg0[1] : i8 from vector<8xi8> + %2 = vector.extract %arg0[2] : i8 from vector<8xi8> + %3 = vector.extract %arg0[3] : i8 from vector<8xi8> + %4 = vector.extract %arg0[4] : i8 from vector<8xi8> + %5 = vector.extract %arg0[5] : i8 from vector<8xi8> + %6 = vector.extract %arg0[6] : i8 from vector<8xi8> + %7 = vector.extract %arg0[7] : i8 from vector<8xi8> + %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8> + return %8 : vector<2x2x2xi8> +} + +// ----- + +// CHECK-LABEL: func @source_larger_than_out( +// CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>) +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<12xi8> +func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> { + %0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8> + %1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8> + %2 = vector.extract %arg0[1, 0, 2] : i8 from vector<2x3x4xi8> + %3 = vector.extract %arg0[1, 0, 3] : i8 from vector<2x3x4xi8> + %4 = vector.extract %arg0[1, 1, 0] : i8 from vector<2x3x4xi8> + %5 = vector.extract %arg0[1, 1, 1] : i8 from vector<2x3x4xi8> + %6 = vector.extract %arg0[1, 1, 2] : i8 from vector<2x3x4xi8> + %7 = vector.extract %arg0[1, 1, 3] : i8 from vector<2x3x4xi8> + %8 = vector.extract %arg0[1, 2, 0] : i8 from vector<2x3x4xi8> + %9 = vector.extract %arg0[1, 2, 1] : i8 from vector<2x3x4xi8> + %10 = vector.extract %arg0[1, 2, 2] : i8 from vector<2x3x4xi8> + %11 = vector.extract %arg0[1, 2, 3] : i8 from vector<2x3x4xi8> + %12 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11 : vector<12xi8> + return %12 : vector<12xi8> +} + +// ----- + +// This test is similar to `source_larger_than_out` except here the number of elements +// extracted contigously starting from the first position [0,0] could be 6 instead of 3 +// and the pattern would still match. +// CHECK-LABEL: func @suffix_with_excess_zeros( +// CHECK: %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8> +// CHECK: return %[[EXT]] : vector<3xi8> +func.func @suffix_with_excess_zeros(%arg0: vector<2x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> + %1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8> + %2 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8> + %3 = vector.from_elements %0, %1, %2 : vector<3xi8> + return %3 : vector<3xi8> +} + +// ----- + +// CHECK-LABEL: func @large_source_with_shape_cast_required( +// CHECK-SAME: %[[A:.*]]: vector<2x2x2x2xi8>) +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<1x4x1xi8> +func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> vector<1x4x1xi8> { + %0 = vector.extract %arg0[0, 1, 0, 0] : i8 from vector<2x2x2x2xi8> + %1 = vector.extract %arg0[0, 1, 0, 1] : i8 from vector<2x2x2x2xi8> + %2 = vector.extract %arg0[0, 1, 1, 0] : i8 from vector<2x2x2x2xi8> + %3 = vector.extract %arg0[0, 1, 1, 1] : i8 from vector<2x2x2x2xi8> + %4 = vector.from_elements %0, %1, %2, %3 : vector<1x4x1xi8> + return %4 : vector<1x4x1xi8> +} + +// ----- + +// Could match, but handled by `rewriteFromElementsAsSplat`. +// CHECK-LABEL: func @extract_single_elm( +// CHECK-NEXT: vector.extract +// CHECK-NEXT: vector.splat +// CHECK-NEXT: return +func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> + %1 = vector.from_elements %0 : vector<1xi8> + return %1 : vector<1xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_source_contiguous_but_not_suffix( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_contiguous_but_not_suffix(%arg0: vector<2x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8> + %1 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8> + %2 = vector.extract %arg0[1, 0] : i8 from vector<2x3xi8> + %3 = vector.from_elements %0, %1, %2 : vector<3xi8> + return %3 : vector<3xi8> +} + +// ----- + +// The extracted elements are recombined into a single vector, but in a new order. +// CHECK-LABEL: func @negative_nonascending_order( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_nonstatic_extract( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> { + %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_different_sources( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_source_not_suffix( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_not_suffix(%arg0: vector<1x3xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8> + %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// The inserted elements are a subset of the extracted elements. +// [0, 1, 2] -> [1, 1, 2] +// CHECK-LABEL: func @negative_nobijection_order( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> + %1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8> + %2 = vector.from_elements %0, %0, %1 : vector<3xi8> + return %2 : vector<3xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_source_too_small( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> { + %0 = vector.extract %arg0[0] : i8 from vector<2xi8> + %1 = vector.extract %arg0[1] : i8 from vector<2xi8> + %2 = vector.from_elements %0, %1, %1, %1 : vector<4xi8> + return %2 : vector<4xi8> +} +