From 8192475e41b4d57d361860410e1235a32ba718a1 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Sat, 9 Nov 2024 22:11:20 -0800 Subject: [PATCH 1/2] [mlir][vector] Allow integer indices in vector.extract/insert ops `vector.extract` and `vector.insert` can currently take an `i64` constant or an `index` type value as indices. The `index` type will usually lower to an `i32` or `i64` type. However, we are often indexing really small vector dimensions where smaller integers could be used. This PR extends both ops to accept any integer value as indices. For example: ``` %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32> %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32> %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32> ``` This led to some changes to the ops' parser and printer. When a value index is provided, the index type is printed as part of the index list. All the value indices provided must match that type. When no value index is provided, no index type is printed. --- .../mlir/Dialect/Vector/IR/VectorOps.td | 23 +++-- mlir/include/mlir/IR/OpImplementation.h | 21 +++- .../mlir/Interfaces/ViewLikeInterface.h | 29 +++++- mlir/lib/AsmParser/AsmParserImpl.h | 23 ++++- mlir/lib/AsmParser/Parser.cpp | 11 ++- mlir/lib/AsmParser/Parser.h | 9 +- .../VectorToArmSME/VectorToArmSME.cpp | 10 +- .../Conversion/VectorToSCF/VectorToSCF.cpp | 8 +- mlir/lib/Interfaces/ViewLikeInterface.cpp | 46 +++++++-- .../VectorToArmSME/unsupported.mlir | 10 +- .../VectorToArmSME/vector-to-arm-sme.mlir | 98 +++++++++---------- .../VectorToLLVM/vector-to-llvm.mlir | 16 +-- .../Conversion/VectorToSCF/vector-to-scf.mlir | 8 +- .../VectorToSPIRV/vector-to-spirv.mlir | 12 +-- .../Dialect/ArmSME/outer-product-fusion.mlir | 4 +- .../Dialect/ArmSME/vector-legalization.mlir | 26 ++--- mlir/test/Dialect/Linalg/hoisting.mlir | 4 +- .../Dialect/Linalg/transform-ops-invalid.mlir | 2 +- mlir/test/Dialect/Vector/canonicalize.mlir | 16 +-- mlir/test/Dialect/Vector/invalid.mlir | 65 ++++++++++++ mlir/test/Dialect/Vector/ops.mlir | 51 +++++++--- .../vector-emulate-narrow-type-unaligned.mlir | 24 ++--- 22 files changed, 353 insertions(+), 163 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index c5b08d6aa022b..dad08305b2a64 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -695,14 +695,14 @@ def Vector_ExtractOp : %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32> %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32> %3 = vector.extract %1[]: vector from vector - %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> - %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> + %4 = vector.extract %0[%a, %b, %c : index] : f32 from vector<4x8x16xf32> + %5 = vector.extract %0[2, %b : index] : vector<16xf32> from vector<4x8x16xf32> ``` }]; let arguments = (ins AnyVectorOfAnyRank:$vector, - Variadic:$dynamic_position, + Variadic:$dynamic_position, DenseI64ArrayAttr:$static_position ); let results = (outs AnyType:$result); @@ -737,7 +737,8 @@ def Vector_ExtractOp : let assemblyFormat = [{ $vector `` - custom($dynamic_position, $static_position) + custom($dynamic_position, $static_position, + type($dynamic_position)) attr-dict `:` type($result) `from` type($vector) }]; @@ -883,15 +884,15 @@ def Vector_InsertOp : %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32> %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32> %8 = vector.insert %6, %7[] : f32 into vector - %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> - %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> + %11 = vector.insert %9, %10[%a, %b, %c : index] : vector into vector<4x8x16xf32> + %12 = vector.insert %4, %10[2, %b : index] : vector<16xf32> into vector<4x8x16xf32> ``` }]; let arguments = (ins AnyType:$source, AnyVectorOfAnyRank:$dest, - Variadic:$dynamic_position, + Variadic:$dynamic_position, DenseI64ArrayAttr:$static_position ); let results = (outs AnyVectorOfAnyRank:$result); @@ -926,7 +927,9 @@ def Vector_InsertOp : }]; let assemblyFormat = [{ - $source `,` $dest custom($dynamic_position, $static_position) + $source `,` $dest + custom($dynamic_position, $static_position, + type($dynamic_position)) attr-dict `:` type($source) `into` type($dest) }]; @@ -1344,7 +1347,7 @@ def Vector_TransferReadOp : %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref // Update the temporary gathered slice with the individual element %slice = memref.load %tmp : memref> -> vector<3x4x5xf32> - %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32> + %updated = vector.insert %a, %slice[%i, %j, %k : index] : f32 into vector<3x4x5xf32> memref.store %updated, %tmp : memref> }}} // At this point we gathered the elements from the original @@ -1367,7 +1370,7 @@ def Vector_TransferReadOp : %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref %slice = memref.load %tmp : memref> -> vector<3x4x5xf32> // Here we only store to the first element in dimension one - %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32> + %updated = vector.insert %a, %slice[%i, 0, %k : index] : f32 into vector<3x4x5xf32> memref.store %updated, %tmp : memref> }} // At this point we gathered the elements from the original diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index a7222794f320b..699dd1da863b6 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -794,16 +794,26 @@ class AsmParser { }; /// Parse a list of comma-separated items with an optional delimiter. If a - /// delimiter is provided, then an empty list is allowed. If not, then at + /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. /// + /// `parseSuffixFn` is an optional function to parse any suffix that can be + /// appended to the comma separated list within the delimiter. + /// /// contextMessage is an optional message appended to "expected '('" sorts of /// diagnostics when parsing the delimeters. - virtual ParseResult + virtual ParseResult parseCommaSeparatedList( + Delimiter delimiter, function_ref parseElementFn, + std::optional> parseSuffixFn = std::nullopt, + StringRef contextMessage = StringRef()) = 0; + ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref parseElementFn, - StringRef contextMessage = StringRef()) = 0; - + StringRef contextMessage) { + return parseCommaSeparatedList(delimiter, parseElementFn, + /*parseSuffixFn=*/std::nullopt, + contextMessage); + } /// Parse a comma separated list of elements that must have at least one entry /// in it. ParseResult @@ -1319,6 +1329,9 @@ class AsmParser { virtual ParseResult parseOptionalColonTypeList(SmallVectorImpl &result) = 0; + /// Parse an optional colon followed by a type. + virtual ParseResult parseOptionalColonType(Type &result) = 0; + /// Parse a keyword followed by a type. ParseResult parseKeywordType(const char *keyword, Type &result) { return failure(parseKeyword(keyword) || parseType(result)); diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index 3dcbd2f1af193..1971c25a8f20b 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final /// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes` /// is non-empty, it is expected to contain as many elements as `values` /// indicating their types. This allows idiomatic printing of mixed value and -/// integer attributes in a list. E.g. -/// `[%arg0 : index, 7, 42, %arg42 : i32]`. +/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`. +/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the +/// same and only one type is printed at the end of the list. E.g., +/// `[0, %arg2, 3, %arg42, 2 : i8]`. /// /// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. /// This notation is similar to how scalable dims are marked when defining @@ -108,7 +110,8 @@ void printDynamicIndexList( OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, ArrayRef scalables, TypeRange valueTypes = TypeRange(), - AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square, + bool hasSameTypeDynamicValues = false); inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, @@ -123,6 +126,13 @@ inline void printDynamicIndexList( return printDynamicIndexList(printer, op, values, integers, {}, valueTypes, delimiter); } +inline void printSameTypeDynamicIndexList( + OpAsmPrinter &printer, Operation *op, OperandRange values, + ArrayRef integers, TypeRange valueTypes = TypeRange(), + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { + return printDynamicIndexList(printer, op, values, integers, {}, valueTypes, + delimiter, /*hasSameTypeDynamicValues=*/true); +} /// Parser hook for custom directive in assemblyFormat. /// @@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList( SmallVectorImpl &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl *valueTypes = nullptr, - AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square, + bool hasSameTypeDynamicValues = false); inline ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl &values, @@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList( return parseDynamicIndexList(parser, values, integers, scalableVals, &valueTypes, delimiter); } +inline ParseResult parseSameTypeDynamicIndexList( + OpAsmParser &parser, + SmallVectorImpl &values, + DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { + DenseBoolArrayAttr scalableVals = {}; + return parseDynamicIndexList(parser, values, integers, scalableVals, + &valueTypes, delimiter, + /*hasSameTypeDynamicValues=*/true); +} /// Verify that a the `values` has as many elements as the number of entries in /// `attr` for which `isDynamic` evaluates to true. diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 04250f63dcd25..4d5b93ec09d17 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT { /// Parse a list of comma-separated items with an optional delimiter. If a /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. - ParseResult parseCommaSeparatedList(Delimiter delimiter, - function_ref parseElt, - StringRef contextMessage) override { - return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); + ParseResult parseCommaSeparatedList( + Delimiter delimiter, function_ref parseElt, + std::optional> parseSuffix, + StringRef contextMessage) override { + return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix, + contextMessage); } + using BaseT::parseCommaSeparatedList; + //===--------------------------------------------------------------------===// // Keyword Parsing //===--------------------------------------------------------------------===// @@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT { return parser.parseTypeListNoParens(result); } + /// Parse an optional colon followed by a type. + ParseResult parseOptionalColonType(Type &result) override { + SmallVector types; + ParseResult parseResult = parseOptionalColonTypeList(types); + if (llvm::succeeded(parseResult) && types.size() > 1) + return emitError(getCurrentLocation(), "expected single type"); + if (!types.empty()) + result = types[0]; + return parseResult; + } + ParseResult parseDimensionList(SmallVectorImpl &dimensions, bool allowDynamic, bool withTrailingX) override { diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 8f19487d80fa3..6476910f71eb7 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default; /// Parse a list of comma-separated items with an optional delimiter. If a /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. -ParseResult -Parser::parseCommaSeparatedList(Delimiter delimiter, - function_ref parseElementFn, - StringRef contextMessage) { +ParseResult Parser::parseCommaSeparatedList( + Delimiter delimiter, function_ref parseElementFn, + std::optional> parseSuffixFn, + StringRef contextMessage) { switch (delimiter) { case Delimiter::None: break; @@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter, return failure(); } + if (parseSuffixFn && (*parseSuffixFn)()) + return failure(); + switch (delimiter) { case Delimiter::None: return success(); diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index bf91831798056..1ebca05bbcb2e 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -46,10 +46,17 @@ class Parser { /// Parse a list of comma-separated items with an optional delimiter. If a /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. + ParseResult parseCommaSeparatedList( + Delimiter delimiter, function_ref parseElementFn, + std::optional> parseSuffixFn = std::nullopt, + StringRef contextMessage = StringRef()); ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref parseElementFn, - StringRef contextMessage = StringRef()); + StringRef contextMessage) { + return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt, + contextMessage); + } /// Parse a comma separated list of elements that must have at least one entry /// in it. diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 55965d9c2a531..c5c3353bf0477 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering /// /// Example: /// ``` -/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32> +/// %el = vector.extract %tile[%row, %col : index] : i32 from +/// vector<[4]x[4]xi32> /// ``` /// Becomes: /// ``` /// %slice = arm_sme.extract_tile_slice %tile[%row] /// : vector<[4]xi32> from vector<[4]x[4]xi32> -/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32> +/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32> /// ``` struct VectorExtractToArmSMELowering : public OpRewritePattern { @@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering /// ``` /// %slice = arm_sme.extract_tile_slice %tile[%row] /// : vector<[4]xi32> from vector<[4]x[4]xi32> -/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32> -/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row] +/// %new_slice = vector.insert %el, %slice[%col : index] : i32 into +/// vector<[4]xi32> %new_tile = arm_sme.insert_tile_slice %new_slice, +/// %tile[%row] /// : vector<[4]xi32> into vector<[4]x[4]xi32> /// ``` struct VectorInsertToArmSMELowering diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 3a4dc806efe97..b623a86c53ee7 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) { /// %vscale = vector.vscale /// %c4_vscale = arith.muli %vscale, %c4 : index /// scf.for %idx = %c0 to %c4_vscale step %c1 { -/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32> -/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32> -/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32> -/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32> +/// %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32> +/// %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32> +/// %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32> +/// %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32> /// %slice_i = affine.apply #map(%idx)[%i] /// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32> /// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]} diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp index ca33636336bf0..8e44ff60eec87 100644 --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, ArrayRef scalables, TypeRange valueTypes, - AsmParser::Delimiter delimiter) { + AsmParser::Delimiter delimiter, + bool hasSameTypeDynamicValues) { char leftDelimiter = getLeftDelimiter(delimiter); char rightDelimiter = getRightDelimiter(delimiter); printer << leftDelimiter; @@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, printer << "["; if (ShapedType::isDynamic(integer)) { printer << values[dynamicValIdx]; - if (!valueTypes.empty()) + if (!hasSameTypeDynamicValues && !valueTypes.empty()) printer << " : " << valueTypes[dynamicValIdx]; ++dynamicValIdx; } else { @@ -142,6 +143,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, scalableIndexIdx++; }); + if (hasSameTypeDynamicValues && !valueTypes.empty()) { + assert(std::all_of(valueTypes.begin(), valueTypes.end(), + [&](Type type) { return type == valueTypes[0]; }) && + "Expected the same value types"); + printer << " : " << valueTypes[0]; + } + printer << rightDelimiter; } @@ -149,7 +157,8 @@ ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, - SmallVectorImpl *valueTypes, AsmParser::Delimiter delimiter) { + SmallVectorImpl *valueTypes, AsmParser::Delimiter delimiter, + bool hasSameTypeDynamicValues) { SmallVector integerVals; SmallVector scalableVals; @@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList( if (res.has_value() && succeeded(res.value())) { values.push_back(operand); integerVals.push_back(ShapedType::kDynamic); - if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) + if (!hasSameTypeDynamicValues && valueTypes && + parser.parseColonType(valueTypes->emplace_back())) return failure(); } else { int64_t integer; @@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList( return failure(); return success(); }; + auto parseColonType = [&]() -> ParseResult { + if (hasSameTypeDynamicValues) { + assert(valueTypes && "Expected non-null value types"); + assert(valueTypes->empty() && "Expected no parsed value types"); + + Type dynValType; + if (parser.parseOptionalColonType(dynValType)) + return failure(); + + if (!dynValType && !values.empty()) + return parser.emitError(parser.getNameLoc()) + << "expected a type for dynamic indices"; + if (dynValType) { + if (values.empty()) + return parser.emitError(parser.getNameLoc()) + << "expected no type for constant indices"; + + // Broadcast the single type to all the dynamic values. + valueTypes->append(values.size(), dynValType); + } + } + return success(); + }; if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue, - " in dynamic index list")) + parseColonType, " in dynamic index list")) return parser.emitError(parser.getNameLoc()) - << "expected SSA value or integer"; + << "expected a valid list of SSA values or integers"; + integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); return success(); diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir index ff7b4bcb5f65a..c93dbf8836f6c 100644 --- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir +++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir @@ -151,7 +151,7 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest // CHECK-NOT: arm_sme.store_tile_slice func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref, %slice_index: index) { %c0 = arith.constant 0 : index - %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32> vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref return } @@ -202,7 +202,7 @@ func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index { // CHECK-NOT: arm_sve.psel %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1> - %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1> + %slice = vector.extract %mask[%index : index] : vector<[32]xi1> from vector<[4]x[32]xi1> return %slice : vector<[32]xi1> } @@ -215,7 +215,7 @@ func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index { // CHECK-NOT: arm_sve.psel %mask = vector.create_mask %a, %b : vector<4x[8]xi1> - %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1> + %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<4x[8]xi1> return %slice : vector<[8]xi1> } @@ -227,7 +227,7 @@ func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1> { // CHECK-NOT: arm_sve.psel - %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> + %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<[4]x[8]xi1> return %slice : vector<[8]xi1> } @@ -240,6 +240,6 @@ func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index { // CHECK-NOT: arm_sve.psel %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> - %el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1> + %el = vector.extract %mask[2, %index : index] : i1 from vector<[4]x[8]xi1> return %el : i1 } diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir index 0f973af799634..6ca19c5746ea1 100644 --- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir +++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir @@ -345,7 +345,7 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb // CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref, vector<[4]xi1>, vector<[4]x[4]xf32> func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref, %slice_index: index) { %c0 = arith.constant 0 : index - %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32> vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref return } @@ -361,7 +361,7 @@ func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref, vector<[4]xi1>, vector<[4]x[4]xf32> func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref, %mask: vector<[4]xi1>, %slice_index: index) { %c0 = arith.constant 0 : index - %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32> vector.transfer_write %slice, %dest[%slice_index, %c0], %mask { in_bounds = [true] }: vector<[4]xf32>, memref return } @@ -927,7 +927,7 @@ func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vect // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK-NEXT: arm_sme.insert_tile_slice %[[SLICE]], %[[TILE]][%[[INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32> %tile = arm_sme.get_tile : vector<[4]x[4]xi32> - %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[4]xi32> into vector<[4]x[4]xi32> return %new_tile : vector<[4]x[4]xi32> } @@ -937,7 +937,7 @@ func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vect func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8> %tile = arm_sme.get_tile : vector<[16]x[16]xi8> - %new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[16]xi8> into vector<[16]x[16]xi8> return %new_tile : vector<[16]x[16]xi8> } @@ -947,7 +947,7 @@ func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vecto func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xi16> into vector<[8]x[8]xi16> %tile = arm_sme.get_tile : vector<[8]x[8]xi16> - %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xi16> into vector<[8]x[8]xi16> return %new_tile : vector<[8]x[8]xi16> } @@ -957,7 +957,7 @@ func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vect func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[2]xi64> into vector<[2]x[2]xi64> %tile = arm_sme.get_tile : vector<[2]x[2]xi64> - %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xi64> into vector<[2]x[2]xi64> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[2]xi64> into vector<[2]x[2]xi64> return %new_tile : vector<[2]x[2]xi64> } @@ -967,7 +967,7 @@ func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vect func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[1]xi128> into vector<[1]x[1]xi128> %tile = arm_sme.get_tile : vector<[1]x[1]xi128> - %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[1]xi128> into vector<[1]x[1]xi128> return %new_tile : vector<[1]x[1]xi128> } @@ -977,7 +977,7 @@ func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> ve func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16> %tile = arm_sme.get_tile : vector<[8]x[8]xf16> - %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xf16> into vector<[8]x[8]xf16> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xf16> into vector<[8]x[8]xf16> return %new_tile : vector<[8]x[8]xf16> } @@ -987,7 +987,7 @@ func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vect func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xbf16> into vector<[8]x[8]xbf16> %tile = arm_sme.get_tile : vector<[8]x[8]xbf16> - %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xbf16> into vector<[8]x[8]xbf16> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xbf16> into vector<[8]x[8]xbf16> return %new_tile : vector<[8]x[8]xbf16> } @@ -997,7 +997,7 @@ func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> ve func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[4]xf32> into vector<[4]x[4]xf32> %tile = arm_sme.get_tile : vector<[4]x[4]xf32> - %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[4]xf32> into vector<[4]x[4]xf32> return %new_tile : vector<[4]x[4]xf32> } @@ -1007,7 +1007,7 @@ func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vect func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[2]xf64> into vector<[2]x[2]xf64> %tile = arm_sme.get_tile : vector<[2]x[2]xf64> - %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[2]xf64> into vector<[2]x[2]xf64> return %new_tile : vector<[2]x[2]xf64> } @@ -1020,10 +1020,10 @@ func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vect func.func @vector_insert_element_i32(%el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> { // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK-NEXT: %[[SLICE:.*]] = arm_sme.extract_tile_slice %[[TILE]][%[[ROW]]] : vector<[4]xi32> from vector<[4]x[4]xi32> - // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32> + // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]] : index] : i32 into vector<[4]xi32> // CHECK-NEXT: arm_sme.insert_tile_slice %[[NEW_SLICE]], %[[TILE]][%[[ROW]]] : vector<[4]xi32> into vector<[4]x[4]xi32> %tile = arm_sme.get_tile : vector<[4]x[4]xi32> - %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i32 into vector<[4]x[4]xi32> return %new_tile : vector<[4]x[4]xi32> } @@ -1035,7 +1035,7 @@ func.func @vector_insert_element_i8(%el: i8, %row: index, %col: index) -> vector // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[16]xi8> into vector<[16]x[16]xi8> %tile = arm_sme.get_tile : vector<[16]x[16]xi8> - %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i8 into vector<[16]x[16]xi8> return %new_tile : vector<[16]x[16]xi8> } @@ -1047,7 +1047,7 @@ func.func @vector_insert_element_i16(%el: i16, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xi16> into vector<[8]x[8]xi16> %tile = arm_sme.get_tile : vector<[8]x[8]xi16> - %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i16 into vector<[8]x[8]xi16> return %new_tile : vector<[8]x[8]xi16> } @@ -1059,7 +1059,7 @@ func.func @vector_insert_element_i64(%el: i64, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[2]xi64> into vector<[2]x[2]xi64> %tile = arm_sme.get_tile : vector<[2]x[2]xi64> - %new_tile = vector.insert %el, %tile[%row, %col] : i64 into vector<[2]x[2]xi64> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i64 into vector<[2]x[2]xi64> return %new_tile : vector<[2]x[2]xi64> } @@ -1071,7 +1071,7 @@ func.func @vector_insert_element_i128(%el: i128, %row: index, %col: index) -> ve // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[1]xi128> into vector<[1]x[1]xi128> %tile = arm_sme.get_tile : vector<[1]x[1]xi128> - %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i128 into vector<[1]x[1]xi128> return %new_tile : vector<[1]x[1]xi128> } @@ -1083,7 +1083,7 @@ func.func @vector_insert_element_f16(%el: f16, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xf16> into vector<[8]x[8]xf16> %tile = arm_sme.get_tile : vector<[8]x[8]xf16> - %new_tile = vector.insert %el, %tile[%row, %col] : f16 into vector<[8]x[8]xf16> + %new_tile = vector.insert %el, %tile[%row, %col : index] : f16 into vector<[8]x[8]xf16> return %new_tile : vector<[8]x[8]xf16> } @@ -1095,7 +1095,7 @@ func.func @vector_insert_element_bf16(%el: bf16, %row: index, %col: index) -> ve // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xbf16> into vector<[8]x[8]xbf16> %tile = arm_sme.get_tile : vector<[8]x[8]xbf16> - %new_tile = vector.insert %el, %tile[%row, %col] : bf16 into vector<[8]x[8]xbf16> + %new_tile = vector.insert %el, %tile[%row, %col : index] : bf16 into vector<[8]x[8]xbf16> return %new_tile : vector<[8]x[8]xbf16> } @@ -1107,7 +1107,7 @@ func.func @vector_insert_element_f32(%el: f32, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[4]xf32> into vector<[4]x[4]xf32> %tile = arm_sme.get_tile : vector<[4]x[4]xf32> - %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32> + %new_tile = vector.insert %el, %tile[%row, %col : index] : f32 into vector<[4]x[4]xf32> return %new_tile : vector<[4]x[4]xf32> } @@ -1119,7 +1119,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[2]xf64> into vector<[2]x[2]xf64> %tile = arm_sme.get_tile : vector<[2]x[2]xf64> - %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64> + %new_tile = vector.insert %el, %tile[%row, %col : index] : f64 into vector<[2]x[2]xf64> return %new_tile : vector<[2]x[2]xf64> } @@ -1135,7 +1135,7 @@ func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> { // CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK: arm_sme.extract_tile_slice %[[TILE]][%[[INDEX]]] : vector<[4]xi32> from vector<[4]x[4]xi32> %tile = arm_sme.get_tile : vector<[4]x[4]xi32> - %slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32> + %slice = vector.extract %tile[%row : index] : vector<[4]xi32> from vector<[4]x[4]xi32> return %slice : vector<[4]xi32> } @@ -1145,7 +1145,7 @@ func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> { func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8> %tile = arm_sme.get_tile : vector<[16]x[16]xi8> - %slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8> + %slice = vector.extract %tile[%row : index] : vector<[16]xi8> from vector<[16]x[16]xi8> return %slice : vector<[16]xi8> } @@ -1155,7 +1155,7 @@ func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> { func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16> %tile = arm_sme.get_tile : vector<[8]x[8]xi16> - %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16> + %slice = vector.extract %tile[%row : index] : vector<[8]xi16> from vector<[8]x[8]xi16> return %slice : vector<[8]xi16> } @@ -1165,7 +1165,7 @@ func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> { func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64> %tile = arm_sme.get_tile : vector<[2]x[2]xi64> - %slice = vector.extract %tile[%row] : vector<[2]xi64> from vector<[2]x[2]xi64> + %slice = vector.extract %tile[%row : index] : vector<[2]xi64> from vector<[2]x[2]xi64> return %slice : vector<[2]xi64> } @@ -1175,7 +1175,7 @@ func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> { func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128> %tile = arm_sme.get_tile : vector<[1]x[1]xi128> - %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128> + %slice = vector.extract %tile[%row : index] : vector<[1]xi128> from vector<[1]x[1]xi128> return %slice : vector<[1]xi128> } @@ -1185,7 +1185,7 @@ func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> { func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16> %tile = arm_sme.get_tile : vector<[8]x[8]xf16> - %slice = vector.extract %tile[%row] : vector<[8]xf16> from vector<[8]x[8]xf16> + %slice = vector.extract %tile[%row : index] : vector<[8]xf16> from vector<[8]x[8]xf16> return %slice : vector<[8]xf16> } @@ -1195,7 +1195,7 @@ func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> { func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16> %tile = arm_sme.get_tile : vector<[8]x[8]xbf16> - %slice = vector.extract %tile[%row] : vector<[8]xbf16> from vector<[8]x[8]xbf16> + %slice = vector.extract %tile[%row : index] : vector<[8]xbf16> from vector<[8]x[8]xbf16> return %slice : vector<[8]xbf16> } @@ -1205,7 +1205,7 @@ func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> { func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32> %tile = arm_sme.get_tile : vector<[4]x[4]xf32> - %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32> + %slice = vector.extract %tile[%row : index] : vector<[4]xf32> from vector<[4]x[4]xf32> return %slice : vector<[4]xf32> } @@ -1215,7 +1215,7 @@ func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> { func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64> %tile = arm_sme.get_tile : vector<[2]x[2]xf64> - %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64> + %slice = vector.extract %tile[%row : index] : vector<[2]xf64> from vector<[2]x[2]xf64> return %slice : vector<[2]xf64> } @@ -1227,9 +1227,9 @@ func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> { func.func @vector_extract_element(%row: index, %col: index) -> i32 { // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK-NEXT: %[[SLICE:.*]] = arm_sme.extract_tile_slice %[[TILE]][%[[ROW]]] : vector<[4]xi32> from vector<[4]x[4]xi32> - // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32> + // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]] : index] : i32 from vector<[4]xi32> %tile = arm_sme.get_tile : vector<[4]x[4]xi32> - %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32> + %el = vector.extract %tile[%row, %col : index] : i32 from vector<[4]x[4]xi32> return %el : i32 } @@ -1238,9 +1238,9 @@ func.func @vector_extract_element(%row: index, %col: index) -> i32 { // CHECK-LABEL: @vector_extract_element_i8 func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i8 from vector<[16]xi8> %tile = arm_sme.get_tile : vector<[16]x[16]xi8> - %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8> + %el = vector.extract %tile[%row, %col : index] : i8 from vector<[16]x[16]xi8> return %el : i8 } @@ -1249,9 +1249,9 @@ func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 { // CHECK-LABEL: @vector_extract_element_i16 func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i16 from vector<[8]xi16> %tile = arm_sme.get_tile : vector<[8]x[8]xi16> - %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16> + %el = vector.extract %tile[%row, %col : index] : i16 from vector<[8]x[8]xi16> return %el : i16 } @@ -1260,9 +1260,9 @@ func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 { // CHECK-LABEL: @vector_extract_element_i64 func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i64 from vector<[2]xi64> %tile = arm_sme.get_tile : vector<[2]x[2]xi64> - %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64> + %el = vector.extract %tile[%row, %col : index] : i64 from vector<[2]x[2]xi64> return %el : i64 } @@ -1271,9 +1271,9 @@ func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 { // CHECK-LABEL: @vector_extract_element_i128 func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i128 from vector<[1]xi128> %tile = arm_sme.get_tile : vector<[1]x[1]xi128> - %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128> + %el = vector.extract %tile[%row, %col : index] : i128 from vector<[1]x[1]xi128> return %el : i128 } @@ -1282,9 +1282,9 @@ func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 { // CHECK-LABEL: @vector_extract_element_f16 func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f16 from vector<[8]xf16> %tile = arm_sme.get_tile : vector<[8]x[8]xf16> - %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16> + %el = vector.extract %tile[%row, %col : index] : f16 from vector<[8]x[8]xf16> return %el : f16 } @@ -1293,9 +1293,9 @@ func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 { // CHECK-LABEL: @vector_extract_element_bf16 func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : bf16 from vector<[8]xbf16> %tile = arm_sme.get_tile : vector<[8]x[8]xbf16> - %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16> + %el = vector.extract %tile[%row, %col : index] : bf16 from vector<[8]x[8]xbf16> return %el : bf16 } @@ -1304,9 +1304,9 @@ func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 { // CHECK-LABEL: @vector_extract_element_f32 func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f32 from vector<[4]xf32> %tile = arm_sme.get_tile : vector<[4]x[4]xf32> - %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32> + %el = vector.extract %tile[%row, %col : index] : f32 from vector<[4]x[4]xf32> return %el : f32 } @@ -1315,9 +1315,9 @@ func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 { // CHECK-LABEL: @vector_extract_element_f64 func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f64 from vector<[2]xf64> %tile = arm_sme.get_tile : vector<[2]x[2]xf64> - %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64> + %el = vector.extract %tile[%row, %col : index] : f64 from vector<[2]x[2]xf64> return %el : f64 } @@ -1335,7 +1335,7 @@ func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: ind // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1> // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1> %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> - %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> + %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<[4]x[8]xi1> return %slice : vector<[8]xi1> } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 03bcb341efea2..953d846dceb69 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1239,7 +1239,7 @@ func.func @extract_scalar_from_vec_3d_f32_scalable(%arg0: vector<4x3x[16]xf32>) // ----- func.func @extract_scalar_from_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg1: index) -> f32 { - %0 = vector.extract %arg0[%arg1]: f32 from vector<16xf32> + %0 = vector.extract %arg0[%arg1 : index] : f32 from vector<16xf32> return %0 : f32 } // CHECK-LABEL: @extract_scalar_from_vec_1d_f32_dynamic_idx @@ -1248,7 +1248,7 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %ar // CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32> func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 { - %0 = vector.extract %arg0[%arg1]: f32 from vector<[16]xf32> + %0 = vector.extract %arg0[%arg1 : index] : f32 from vector<[16]xf32> return %0 : f32 } // CHECK-LABEL: @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable @@ -1259,7 +1259,7 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16 // ----- func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 { - %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32> + %0 = vector.extract %arg0[0, %arg1 : index] : f32 from vector<1x16xf32> return %0 : f32 } @@ -1269,7 +1269,7 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, % // CHECK: vector.extract func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 { - %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32> + %0 = vector.extract %arg0[0, %arg1 : index] : f32 from vector<1x[16]xf32> return %0 : f32 } @@ -1460,7 +1460,7 @@ func.func @insert_scalar_into_vec_3d_f32_scalable(%arg0: f32, %arg1: vector<4x8x func.func @insert_scalar_into_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg1: f32, %arg2: index) -> vector<16xf32> { - %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<16xf32> + %0 = vector.insert %arg1, %arg0[%arg2 : index] : f32 into vector<16xf32> return %0 : vector<16xf32> } @@ -1471,7 +1471,7 @@ func.func @insert_scalar_into_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg func.func @insert_scalar_into_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]xf32>, %arg1: f32, %arg2: index) -> vector<[16]xf32> { - %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<[16]xf32> + %0 = vector.insert %arg1, %arg0[%arg2 : index] : f32 into vector<[16]xf32> return %0 : vector<[16]xf32> } @@ -1484,7 +1484,7 @@ func.func @insert_scalar_into_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16] func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: f32, %idx: index) -> vector<1x16xf32> { - %0 = vector.insert %arg1, %arg0[0, %idx]: f32 into vector<1x16xf32> + %0 = vector.insert %arg1, %arg0[0, %idx : index] : f32 into vector<1x16xf32> return %0 : vector<1x16xf32> } @@ -1495,7 +1495,7 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: f32, %idx: index) -> vector<1x[16]xf32> { - %0 = vector.insert %arg1, %arg0[0, %idx]: f32 into vector<1x[16]xf32> + %0 = vector.insert %arg1, %arg0[0, %idx : index] : f32 into vector<1x[16]xf32> return %0 : vector<1x[16]xf32> } diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir index 5a6da3a06387a..7d25d2b1c1e99 100644 --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -828,10 +828,10 @@ func.func @scalable_transpose_store_unmasked(%vec: vector<4x[4]xf32>, %dest: mem // FULL-UNROLL: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index // FULL-UNROLL: scf.for %[[VAL_13:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] { // FULL-UNROLL: %[[SLICE_I:.*]] = affine.apply #[[$SLICE_MAP]](%[[VAL_13]]){{\[}}%[[I]]] -// FULL-UNROLL: %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32> -// FULL-UNROLL: %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32> -// FULL-UNROLL: %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32> -// FULL-UNROLL: %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32> +// FULL-UNROLL: %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32> +// FULL-UNROLL: %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32> +// FULL-UNROLL: %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32> +// FULL-UNROLL: %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32> // FULL-UNROLL: %[[TRANSPOSE_SLICE:.*]] = vector.from_elements %[[ELEM_0]], %[[ELEM_1]], %[[ELEM_2]], %[[ELEM_3]] : vector<4xf32> // FULL-UNROLL: vector.transfer_write %[[TRANSPOSE_SLICE]], %[[DEST]]{{\[}}%[[SLICE_I]], %[[J]]] {in_bounds = [true]} : vector<4xf32>, memref diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 8796f153c4911..dc8272c7c82a7 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -191,7 +191,7 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 { // CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: return %[[R]] func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 { - %0 = vector.extract %arg0[%id] : f32 from vector<1xf32> + %0 = vector.extract %arg0[%id : index] : f32 from vector<1xf32> return %0: f32 } @@ -202,7 +202,7 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f // CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 // CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 { - %0 = vector.extract %arg0[%id] : f32 from vector<4xf32> + %0 = vector.extract %arg0[%id : index] : f32 from vector<4xf32> return %0: f32 } @@ -211,7 +211,7 @@ func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 { // CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 { %idx = arith.constant 1 : index - %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32> + %0 = vector.extract %arg0[%idx : index] : f32 from vector<4xf32> return %0: f32 } @@ -252,7 +252,7 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]] // CHECK: return %[[R]] func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> { - %1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32> + %1 = vector.insert %arg1, %arg0[%id : index] : f32 into vector<1xf32> return %1 : vector<1xf32> } @@ -263,7 +263,7 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : // CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 // CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32> + %0 = vector.insert %val, %arg0[%id : index] : f32 into vector<4xf32> return %0: vector<4xf32> } @@ -274,7 +274,7 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect // CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { %idx = arith.constant 2 : index - %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32> + %0 = vector.insert %val, %arg0[%idx : index] : f32 into vector<4xf32> return %0: vector<4xf32> } diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir index 9000551783576..bac1c1cb5615e 100644 --- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir +++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir @@ -814,12 +814,12 @@ func.func @extract_from_arith_ext(%src: vector<4x[8]xi8>) -> vector<[8]xi32> { // CHECK-LABEL: @non_constant_extract_from_arith_ext( // CHECK-SAME: %[[SRC:[a-z0-9]+]]: vector<4x[8]xi8>, // CHECK-SAME: %[[DIM:[a-z0-9]+]]: index -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]]] : vector<[8]xi8> from vector<4x[8]xi8> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]] : index] : vector<[8]xi8> from vector<4x[8]xi8> // CHECK: %[[EXTEND:.*]] = arith.extui %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32> // CHECK: return %[[EXTEND]] func.func @non_constant_extract_from_arith_ext(%src: vector<4x[8]xi8>, %dim: index) -> vector<[8]xi32> { %0 = arith.extui %src : vector<4x[8]xi8> to vector<4x[8]xi32> - %1 = vector.extract %0[%dim] : vector<[8]xi32> from vector<4x[8]xi32> + %1 = vector.extract %0[%dim : index] : vector<[8]xi32> from vector<4x[8]xi32> return %1 : vector<[8]xi32> } diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index 458906a187982..61b6981b194a6 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -179,10 +179,10 @@ func.func @transfer_write_f16_scalable_16x8(%dest: memref, %vec: vector // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] { - // CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16> + // CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]] : index] : vector<[8]xf16> from vector<[8]x[8]xf16> // CHECK-NEXT: vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref // CHECK-NEXT: %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16> + // CHECK-NEXT: %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]] : index] : vector<[8]xf16> from vector<[8]x[8]xf16> // CHECK-NEXT: vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref // CHECK-NEXT: } // CHECK-NEXT: return @@ -224,20 +224,20 @@ func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref, %dim0: // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1> // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] { - // CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1> + // CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]] : index] : vector<[8]xi1> from vector<[8]x[8]xi1> // CHECK-NEXT: %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1> - // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1> - // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1> + // CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]] : index] : vector<[8]xi1> from vector<[8]x[8]xi1> // CHECK-NEXT: %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1> - // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1> - // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE:.*]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: } %c0 = arith.constant 0 : index @@ -313,16 +313,16 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref, %dest: me // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref, vector<[4]x[4]xf32> // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref, vector<[4]x[4]xf32> // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] { - // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: } // CHECK-NEXT: return @@ -399,7 +399,7 @@ func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: ind // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1> // CHECK-NEXT: return %[[EXTRACT]] %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1> - %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> + %extract = vector.extract %mask[%index : index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> return %extract : vector<[4]x[4]xi1> } diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 4e1035e038ca5..1f077409a6c66 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -734,7 +734,7 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func.func @hoist_vector_broadcasts // CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> { -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]] : index] : vector<4xf32> from vector<3x4xf32> // CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} { // CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32> // CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32> @@ -744,7 +744,7 @@ module attributes {transform.with_named_sequence} { func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> { %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> { - %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32> + %extract = vector.extract %iarg[%pos : index] : vector<4xf32> from vector<3x4xf32> %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32> %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32> scf.yield %broadcast : vector<3x4xf32> diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir index fbebb97a11983..fe108e47d5dd3 100644 --- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir @@ -88,7 +88,7 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): %0 = transform.param.constant 2 : i64 -> !transform.param // expected-error@below {{expected ']' in dynamic index list}} - // expected-error@below {{custom op 'transform.structured.vectorize' expected SSA value or integer}} + // expected-error@below {{custom op 'transform.structured.vectorize' expected a valid list of SSA values or integers}} transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param, 2] : !transform.any_op, !transform.param } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 5ae769090dac6..db15a0562ef4e 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -126,7 +126,7 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index %mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1> // CHECK: vector.create_mask %[[DIM0]] : vector<6xi1> // CHECK-NOT: vector.extract - %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1> + %extract = vector.extract %mask[2, %index : index] : vector<6xi1> from vector<4x4x6xi1> return %extract : vector<6xi1> } @@ -140,7 +140,7 @@ func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %in %mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1> // CHECK: arith.constant dense : vector<6xi1> // CHECK-NOT: vector.extract - %extract = vector.extract %mask[0, %index] : vector<6xi1> from vector<1x4x6xi1> + %extract = vector.extract %mask[0, %index : index] : vector<6xi1> from vector<1x4x6xi1> return %extract : vector<6xi1> } @@ -153,8 +153,8 @@ func.func @extract_from_create_mask_dynamic_position_unknown(%dim0: index, %inde %mask = vector.create_mask %c2, %dim0 : vector<4x6xi1> // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[DIM0]] : vector<4x6xi1> - // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]]] : vector<6xi1> from vector<4x6xi1> - %extract = vector.extract %mask[%index] : vector<6xi1> from vector<4x6xi1> + // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]] : index] : vector<6xi1> from vector<4x6xi1> + %extract = vector.extract %mask[%index : index] : vector<6xi1> from vector<4x6xi1> return %extract : vector<6xi1> } @@ -167,8 +167,8 @@ func.func @extract_from_create_mask_mixed_position_unknown(%dim0: index, %index0 %mask = vector.create_mask %c2, %c2, %dim0 : vector<2x4x4xi1> // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[C2]], %[[DIM0]] : vector<2x4x4xi1> - // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]]] : vector<4xi1> from vector<2x4x4xi1> - %extract = vector.extract %mask[1, %index0] : vector<4xi1> from vector<2x4x4xi1> + // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]] : index] : vector<4xi1> from vector<2x4x4xi1> + %extract = vector.extract %mask[1, %index0 : index] : vector<4xi1> from vector<2x4x4xi1> return %extract : vector<4xi1> } @@ -1918,10 +1918,10 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, % // CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts // CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index) -// CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32> +// CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]] : index] : vector<4xf32> from vector<2x4xf32> // CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32> func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 { - %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32> + %0 = vector.extract %v[%index : index] : vector<4xf32> from vector<2x4xf32> %1 = vector.extract %0[1] : f32 from vector<4xf32> return %1 : f32 } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index d591c60acb64e..ae520c33dcb50 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -148,6 +148,39 @@ func.func @extract_vector_type(%arg0: index) { %1 = vector.extract %arg0[] : index from index } +// ----- +func.func @extract_vector_mixed_index_types(%arg0 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { + // expected-error@+2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}} + // expected-note@-2 {{prior use here}} + %1 = vector.extract %arg0[%i32_idx, %i8_idx : i8] : f32 from vector<8x16xf32> +} + +// ----- +func.func @extract_vector_index_vals_no_type(%arg0 : vector<8xf32>, + %i32_idx: i32) { + // expected-error@+2 {{expected a type for dynamic indices}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.extract %arg0[%i32_idx] : f32 from vector<8x16xf32> +} + +// ----- +func.func @extract_vector_index_vals_multiple_types(%arg0 : vector<8xf32>, + %i8_idx : i8, + %i32_idx : i32) { + // expected-error@+2 {{expected single type}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.extract %arg0[%i8_idx, %i32_idx : i8, i32] : f32 from vector<8x16xf32> +} + +// ----- +func.func @extract_vector_index_consts_type(%arg0 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { + // expected-error@+2 {{'vector.extract' expected no type for constant indices}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.extract %arg0[5, 3 : index] : f32 from vector<8x16xf32> +} + // ----- func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) { @@ -271,6 +304,38 @@ func.func @insert_0d(%a: f32, %b: vector) { %1 = vector.insert %a, %b[0] : f32 into vector } +// ----- +func.func @extract_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { + // expected-error@+2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}} + // expected-note@-2 {{prior use here}} + %1 = vector.insert %arg0, %arg1[%i32_idx, %i8_idx : i8] : f32 into vector<8x16xf32> +} + +// ----- +func.func @extract_vector_index_vals_no_type(%arg0 : f32, %arg1 : vector<8xf32>, + %i32_idx: i32) { + // expected-error@+2 {{expected a type for dynamic indices}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.insert %arg0, %arg1[%i32_idx] : f32 into vector<8x16xf32> +} + +// ----- +func.func @extract_vector_index_vals_multiple_types(%arg0 : f32, %arg1 : vector<8xf32>, + %i8_idx : i8, %i32_idx : i32) { + // expected-error@+2 {{expected single type}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.insert %arg0, %arg1[%i8_idx, %i32_idx : i8, i32] : f32 into vector<8x16xf32> +} + +// ----- +func.func @extract_vector_index_consts_type(%arg0 : f32, %arg1 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { + // expected-error@+2 {{'vector.insert' expected no type for constant indices}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.insert %arg0, %arg1[5, 3 : index] : f32 into vector<8x16xf32> +} + // ----- func.func @outerproduct_num_operands(%arg0: f32) { diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 3baacba9b6124..fb5769e7a61e7 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -224,12 +224,26 @@ func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) // CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index) -> (vector<8x16xf32>, vector<16xf32>, f32) { - // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<8x16xf32> from vector<4x8x16xf32> - %0 = vector.extract %arg0[%idx] : vector<8x16xf32> from vector<4x8x16xf32> - // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]]] : vector<16xf32> from vector<4x8x16xf32> - %1 = vector.extract %arg0[%idx, %idx] : vector<16xf32> from vector<4x8x16xf32> - // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]]] : f32 from vector<4x8x16xf32> - %2 = vector.extract %arg0[%idx, 5, %idx] : f32 from vector<4x8x16xf32> + // CHECK: vector.extract %[[VEC]][%[[IDX]] : index] : vector<8x16xf32> from vector<4x8x16xf32> + %0 = vector.extract %arg0[%idx : index] : vector<8x16xf32> from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]] : index] : vector<16xf32> from vector<4x8x16xf32> + %1 = vector.extract %arg0[%idx, %idx : index] : vector<16xf32> from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]] : index] : f32 from vector<4x8x16xf32> + %2 = vector.extract %arg0[%idx, 5, %idx : index] : f32 from vector<4x8x16xf32> + return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32 +} + +// CHECK-LABEL: @extract_val_int +// CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8 +func.func @extract_val_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32, + %i8_idx: i8) + -> (vector<8x16xf32>, vector<16xf32>, f32) { + // CHECK: vector.extract %[[VEC]][%[[I32_IDX]] : i32] : vector<8x16xf32> from vector<4x8x16xf32> + %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> from vector<4x8x16xf32> + %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 from vector<4x8x16xf32> + %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32> return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32 } @@ -274,12 +288,25 @@ func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, // CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[IDX:.+]]: index func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { - // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]]] : vector<8x16xf32> into vector<4x8x16xf32> - %0 = vector.insert %c, %res[%idx] : vector<8x16xf32> into vector<4x8x16xf32> - // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]]] : vector<16xf32> into vector<4x8x16xf32> - %1 = vector.insert %b, %res[%idx, %idx] : vector<16xf32> into vector<4x8x16xf32> - // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]]] : f32 into vector<4x8x16xf32> - %2 = vector.insert %a, %res[%idx, 5, %idx] : f32 into vector<4x8x16xf32> + // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]] : index] : vector<8x16xf32> into vector<4x8x16xf32> + %0 = vector.insert %c, %res[%idx : index] : vector<8x16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]] : index] : vector<16xf32> into vector<4x8x16xf32> + %1 = vector.insert %b, %res[%idx, %idx : index] : vector<16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]] : index] : f32 into vector<4x8x16xf32> + %2 = vector.insert %a, %res[%idx, 5, %idx : index] : f32 into vector<4x8x16xf32> + return %2 : vector<4x8x16xf32> +} + +// CHECK-LABEL: @insert_val_int +// CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8 +func.func @insert_val_int(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { + // CHECK: vector.insert %[[C]], %{{.*}}[%[[I32_IDX]] : i32] : vector<8x16xf32> into vector<4x8x16xf32> + %0 = vector.insert %c, %res[%i32_idx : i32] : vector<8x16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[B]], %{{.*}}[%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> into vector<4x8x16xf32> + %1 = vector.insert %b, %res[%i8_idx, %i8_idx : i8] : vector<16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[A]], %{{.*}}[%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 into vector<4x8x16xf32> + %2 = vector.insert %a, %res[%i8_idx, 5, %i8_idx : i8] : f32 into vector<4x8x16xf32> return %2 : vector<4x8x16xf32> } diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index 0cecaddc5733e..4bc84fcc9c31f 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -91,13 +91,13 @@ func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector // CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2> // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]] : index] : i2 from vector<8xi2> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index -// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]] : index] : i2 from vector<8xi2> // ----- @@ -119,13 +119,13 @@ func.func @vector_load_i2_dynamic_indexing_mixed(%idx: index) -> vector<3xi2> { // CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2> // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]] : index] : i2 from vector<8xi2> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index -// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]] : index] : i2 from vector<8xi2> // ----- @@ -147,13 +147,13 @@ func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index) // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2> // CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]] : index] : i2 from vector<8xi2> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index -// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]] : index] : i2 from vector<8xi2> // ----- @@ -176,10 +176,10 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2> // CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]] : index] : i2 from vector<8xi2> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index -// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]] : index] : i2 from vector<8xi2> From c854059144d214877b19013b8dcb1f1357f210ea Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Sat, 16 Nov 2024 13:59:36 -0800 Subject: [PATCH 2/2] Feedback --- .../VectorToLLVM/vector-to-llvm.mlir | 68 +++++++++++++++++++ .../VectorToSPIRV/vector-to-spirv.mlir | 44 ++++++++++++ mlir/test/Dialect/Vector/invalid.mlir | 24 +++---- mlir/test/Dialect/Vector/ops.mlir | 31 +++++---- 4 files changed, 142 insertions(+), 25 deletions(-) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 953d846dceb69..acbf0f71b38d2 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1119,6 +1119,38 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 { // CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32> // CHECK: return {{.*}} : f32 +// ----- + +func.func @extract_i32_index(%arg0: vector<16xf32>, %arg1: i32) -> f32 { + %0 = vector.extract %arg0[%arg1 : i32]: f32 from vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_i32_index +// CHECK: llvm.extractelement {{.*}}[{{.*}} : i32] : vector<16xf32> +// CHECK: return {{.*}} : f32 + +// ----- + +func.func @extract_i8_index(%arg0: vector<16xf32>, %arg1: i8) -> f32 { + %0 = vector.extract %arg0[%arg1 : i8]: f32 from vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_i8_index +// CHECK: llvm.extractelement {{.*}}[{{.*}} : i8] : vector<16xf32> +// CHECK: return {{.*}} : f32 + +// ----- + +func.func @extract_i1_index(%arg0: vector<16xf32>, %arg1: i1) -> f32 { + %0 = vector.extract %arg0[%arg1 : i1]: f32 from vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_i1_index +// CHECK: llvm.extractelement {{.*}}[{{.*}} : i1] : vector<16xf32> +// CHECK: return {{.*}} : f32 + +// ----- + func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 { %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32> return %0 : f32 @@ -1247,6 +1279,8 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %ar // CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64 // CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32> +// ----- + func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 { %0 = vector.extract %arg0[%arg1 : index] : f32 from vector<[16]xf32> return %0 : f32 @@ -1268,6 +1302,8 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, % // CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx( // CHECK: vector.extract +// ----- + func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 { %0 = vector.extract %arg0[0, %arg1 : index] : f32 from vector<1x[16]xf32> return %0 : f32 @@ -1356,6 +1392,38 @@ func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> ve // CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32> // CHECK: return {{.*}} : vector<4xf32> +// ----- + +func.func @insert_i32_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i32) -> vector<4xf32> { + %0 = vector.insert %arg0, %arg1[%arg2 : i32] : f32 into vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: @insert_i32_index +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i32] : vector<4xf32> +// CHECK: return {{.*}} : vector<4xf32> + +// ----- + +func.func @insert_i8_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i8) -> vector<4xf32> { + %0 = vector.insert %arg0, %arg1[%arg2 : i8] : f32 into vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: @insert_i8_index +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i8] : vector<4xf32> +// CHECK: return {{.*}} : vector<4xf32> + +// ----- + +func.func @insert_i1_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i1) -> vector<4xf32> { + %0 = vector.insert %arg0, %arg1[%arg2 : i1] : f32 into vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: @insert_i1_index +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i1] : vector<4xf32> +// CHECK: return {{.*}} : vector<4xf32> + +// ----- + func.func @insert_scalar_into_vec_1d_f32_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { %0 = vector.insert %arg0, %arg1[3] : f32 into vector<[4]xf32> return %0 : vector<[4]xf32> diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index dc8272c7c82a7..7b7f128c1180b 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -206,6 +206,28 @@ func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 { return %0: f32 } +// ----- + +// CHECK-LABEL: @extract_i32_index +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @extract_i32_index(%arg0 : vector<4xf32>, %id : i32) -> f32 { + %0 = vector.extract %arg0[%id : i32] : f32 from vector<4xf32> + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @extract_i8_index +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i8 +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i8 +func.func @extract_i8_index(%arg0 : vector<4xf32>, %id : i8) -> f32 { + %0 = vector.extract %arg0[%id : i8] : f32 from vector<4xf32> + return %0: f32 +} + +// ----- + // CHECK-LABEL: @extract_dynamic_cst // CHECK-SAME: %[[V:.*]]: vector<4xf32> // CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> @@ -269,6 +291,28 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect // ----- +// CHECK-LABEL: @insert_i32_index +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: i32 +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { + %0 = vector.insert %val, %arg0[%id : i32] : f32 into vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_i8_index +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: i8 +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i8 +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i8 +func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : i8) -> vector<4xf32> { + %0 = vector.insert %val, %arg0[%id : i8] : f32 into vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + // CHECK-LABEL: @insert_dynamic_cst // CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32> // CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ae520c33dcb50..90a71b8e52425 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -149,7 +149,7 @@ func.func @extract_vector_type(%arg0: index) { } // ----- -func.func @extract_vector_mixed_index_types(%arg0 : vector<8x16xf32>, +func.func @extract_mixed_index_types(%arg0 : vector<8x16xf32>, %i32_idx: i32, %i8_idx: i8) { // expected-error@+2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}} // expected-note@-2 {{prior use here}} @@ -157,7 +157,7 @@ func.func @extract_vector_mixed_index_types(%arg0 : vector<8x16xf32>, } // ----- -func.func @extract_vector_index_vals_no_type(%arg0 : vector<8xf32>, +func.func @extract_index_vals_no_type(%arg0 : vector<8xf32>, %i32_idx: i32) { // expected-error@+2 {{expected a type for dynamic indices}} // expected-error@+1 {{expected a valid list of SSA values or integers}} @@ -165,7 +165,7 @@ func.func @extract_vector_index_vals_no_type(%arg0 : vector<8xf32>, } // ----- -func.func @extract_vector_index_vals_multiple_types(%arg0 : vector<8xf32>, +func.func @extract_index_vals_multiple_types(%arg0 : vector<8xf32>, %i8_idx : i8, %i32_idx : i32) { // expected-error@+2 {{expected single type}} @@ -174,7 +174,7 @@ func.func @extract_vector_index_vals_multiple_types(%arg0 : vector<8xf32>, } // ----- -func.func @extract_vector_index_consts_type(%arg0 : vector<8x16xf32>, +func.func @extract_index_consts_type(%arg0 : vector<8x16xf32>, %i32_idx: i32, %i8_idx: i8) { // expected-error@+2 {{'vector.extract' expected no type for constant indices}} // expected-error@+1 {{expected a valid list of SSA values or integers}} @@ -305,32 +305,32 @@ func.func @insert_0d(%a: f32, %b: vector) { } // ----- -func.func @extract_vector_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>, - %i32_idx: i32, %i8_idx: i8) { +func.func @insert_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { // expected-error@+2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}} // expected-note@-2 {{prior use here}} %1 = vector.insert %arg0, %arg1[%i32_idx, %i8_idx : i8] : f32 into vector<8x16xf32> } // ----- -func.func @extract_vector_index_vals_no_type(%arg0 : f32, %arg1 : vector<8xf32>, - %i32_idx: i32) { +func.func @insert_index_vals_no_type(%arg0 : f32, %arg1 : vector<8xf32>, + %i32_idx: i32) { // expected-error@+2 {{expected a type for dynamic indices}} // expected-error@+1 {{expected a valid list of SSA values or integers}} %1 = vector.insert %arg0, %arg1[%i32_idx] : f32 into vector<8x16xf32> } // ----- -func.func @extract_vector_index_vals_multiple_types(%arg0 : f32, %arg1 : vector<8xf32>, - %i8_idx : i8, %i32_idx : i32) { +func.func @insert_index_vals_multiple_types(%arg0 : f32, %arg1 : vector<8xf32>, + %i8_idx : i8, %i32_idx : i32) { // expected-error@+2 {{expected single type}} // expected-error@+1 {{expected a valid list of SSA values or integers}} %1 = vector.insert %arg0, %arg1[%i8_idx, %i32_idx : i8, i32] : f32 into vector<8x16xf32> } // ----- -func.func @extract_vector_index_consts_type(%arg0 : f32, %arg1 : vector<8x16xf32>, - %i32_idx: i32, %i8_idx: i8) { +func.func @insert_index_consts_type(%arg0 : f32, %arg1 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { // expected-error@+2 {{'vector.insert' expected no type for constant indices}} // expected-error@+1 {{expected a valid list of SSA values or integers}} %1 = vector.insert %arg0, %arg1[5, 3 : index] : f32 into vector<8x16xf32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index fb5769e7a61e7..5cc2ba366febc 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -222,8 +222,8 @@ func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) // CHECK-LABEL: @extract_val_idx // CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index -func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index) - -> (vector<8x16xf32>, vector<16xf32>, f32) { +func.func @extract_index_as_index(%arg0: vector<4x8x16xf32>, %idx: index) + -> (vector<8x16xf32>, vector<16xf32>, f32) { // CHECK: vector.extract %[[VEC]][%[[IDX]] : index] : vector<8x16xf32> from vector<4x8x16xf32> %0 = vector.extract %arg0[%idx : index] : vector<8x16xf32> from vector<4x8x16xf32> // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]] : index] : vector<16xf32> from vector<4x8x16xf32> @@ -234,17 +234,19 @@ func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index) } // CHECK-LABEL: @extract_val_int -// CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8 -func.func @extract_val_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32, - %i8_idx: i8) - -> (vector<8x16xf32>, vector<16xf32>, f32) { +// CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8, %[[I1_IDX:.+]]: i1 +func.func @extract_index_as_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32, + %i8_idx: i8, %i1_idx: i1) + -> (vector<8x16xf32>, vector<16xf32>, f32, vector<16xf32>) { // CHECK: vector.extract %[[VEC]][%[[I32_IDX]] : i32] : vector<8x16xf32> from vector<4x8x16xf32> %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32> // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> from vector<4x8x16xf32> %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32> // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 from vector<4x8x16xf32> %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32> - return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32 + // CHECK-NEXT: vector.extract %[[VEC]][%[[I1_IDX]], 2 : i1] : vector<16xf32> from vector<4x8x16xf32> + %3 = vector.extract %arg0[%i1_idx, 2 : i1] : vector<16xf32> from vector<4x8x16xf32> + return %0, %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32, vector<16xf32> } // CHECK-LABEL: @extract_0d @@ -286,8 +288,8 @@ func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, // CHECK-LABEL: @insert_val_idx // CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[IDX:.+]]: index -func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, - %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { +func.func @insert_index_as_index(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, + %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]] : index] : vector<8x16xf32> into vector<4x8x16xf32> %0 = vector.insert %c, %res[%idx : index] : vector<8x16xf32> into vector<4x8x16xf32> // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]] : index] : vector<16xf32> into vector<4x8x16xf32> @@ -298,16 +300,19 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, } // CHECK-LABEL: @insert_val_int -// CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8 -func.func @insert_val_int(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, - %i32_idx: i32, %i8_idx: i8, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { +// CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8, %[[I1_IDX:.+]]: i1 +func.func @insert_index_as_int(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8, %i1_idx: i1, %res: vector<4x8x16xf32>) + -> (vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>) { // CHECK: vector.insert %[[C]], %{{.*}}[%[[I32_IDX]] : i32] : vector<8x16xf32> into vector<4x8x16xf32> %0 = vector.insert %c, %res[%i32_idx : i32] : vector<8x16xf32> into vector<4x8x16xf32> // CHECK: vector.insert %[[B]], %{{.*}}[%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> into vector<4x8x16xf32> %1 = vector.insert %b, %res[%i8_idx, %i8_idx : i8] : vector<16xf32> into vector<4x8x16xf32> // CHECK: vector.insert %[[A]], %{{.*}}[%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 into vector<4x8x16xf32> %2 = vector.insert %a, %res[%i8_idx, 5, %i8_idx : i8] : f32 into vector<4x8x16xf32> - return %2 : vector<4x8x16xf32> + // CHECK-NEXT: vector.insert %[[B]], %{{.*}}[%[[I1_IDX]], 2 : i1] : vector<16xf32> into vector<4x8x16xf32> + %3 = vector.insert %b, %res[%i1_idx, 2 : i1] : vector<16xf32> into vector<4x8x16xf32> + return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32> } // CHECK-LABEL: @insert_0d