Skip to content

Commit e9e25f0

Browse files
authored
[mlir][vector] Restrict vector.insert/vector.extract to disallow 0-d vectors (#121458)
This patch enforces a restriction in the Vector dialect: the non-indexed operands of `vector.insert` and `vector.extract` must no longer be 0-D vectors. In other words, rank-0 vector types like `vector<f32>` are disallowed as the source or result. EXAMPLES -------- The following are now **illegal** (note the use of `vector<f32>`): ```mlir %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32> %1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32> ``` Instead, use scalars as the source and result types: ```mlir %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32> %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32> ``` Note, this change serves three goals. These are summarised below. ## 1. REDUCED AMBIGUITY By enforcing scalar-only semantics when the result (`vector.extract`) or source (`vector.insert`) are rank-0, we eliminate ambiguity in interpretation. Prior to this patch, both `f32` and `vector<f32>` were accepted. ## 2. MATCH IMPLEMENTATION TO DOCUMENTATION The current behaviour contradicts the documented intent. For example, `vector.extract` states: > Degenerates to an element type if n-k is zero. This patch enforces that intent in code. ## 3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT With the stricter semantics in place, it’s natural and consistent to make `vector.insert` behave symmetrically to `vector.extract`, i.e., degenerate the source type to a scalar when n = 0. NOTES FOR REVIEWERS ------------------- 1. Main change is in "VectorOps.cpp", where stricter type checks are implemented. 2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to remove now-illegal examples. 2. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now require an additional `vector.extract` when a preceding `vector.transfer_read` generates a rank-0 vector. RELATED RFC ----------- * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
1 parent d2ca10a commit e9e25f0

File tree

5 files changed

+37
-26
lines changed

5 files changed

+37
-26
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,9 @@ def Vector_ExtractOp :
703703
InferTypeOpAdaptorWithIsCompatible]> {
704704
let summary = "extract operation";
705705
let description = [{
706-
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
707-
the proper position. Degenerates to an element type if n-k is zero.
706+
Extracts an (n − k)-D result sub-vector from an n-D source vector at a
707+
specified k-D position. When n = k, the result degenerates to a scalar
708+
element.
708709

709710
Static and dynamic indices must be greater or equal to zero and less than
710711
the size of the corresponding dimension. The result is undefined if any
@@ -716,7 +717,6 @@ def Vector_ExtractOp :
716717
```mlir
717718
%1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32>
718719
%2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32>
719-
%3 = vector.extract %1[]: vector<f32> from vector<f32>
720720
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
721721
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
722722
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
@@ -949,9 +949,9 @@ def Vector_InsertOp :
949949
AllTypesMatch<["dest", "result"]>]> {
950950
let summary = "insert operation";
951951
let description = [{
952-
Takes an n-D source vector, an (n+k)-D destination vector and a k-D position
953-
and inserts the n-D source into the (n+k)-D destination at the proper
954-
position. Degenerates to a scalar or a 0-d vector source type when n = 0.
952+
Inserts an (n - k)-D sub-vector (value-to-store) into an n-D destination
953+
vector at a specified k-D position. When n = 0, value-to-store degenerates
954+
to a scalar element inserted into the n-D destination vector.
955955

956956
Static and dynamic indices must be greater or equal to zero and less than
957957
the size of the corresponding dimension. The result is undefined if any
@@ -963,8 +963,7 @@ def Vector_InsertOp :
963963
```mlir
964964
%2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32>
965965
%5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32>
966-
%8 = vector.insert %6, %7[] : f32 into vector<f32>
967-
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
966+
%11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32>
968967
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
969968
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
970969
```

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,8 @@ struct UnrollTransferReadConversion
13241324
for (int64_t i = 0; i < dimSize; ++i) {
13251325
Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
13261326

1327+
// FIXME: Rename this lambda - it does much more than just
1328+
// in-bounds-check generation.
13271329
vec = generateInBoundsCheck(
13281330
rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
13291331
/*inBoundsCase=*/
@@ -1338,12 +1340,21 @@ struct UnrollTransferReadConversion
13381340
insertionIndices.push_back(rewriter.getIndexAttr(i));
13391341

13401342
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1343+
13411344
auto newXferOp = b.create<vector::TransferReadOp>(
13421345
loc, newXferVecType, xferOp.getBase(), xferIndices,
13431346
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
13441347
xferOp.getPadding(), Value(), inBoundsAttr);
13451348
maybeAssignMask(b, xferOp, newXferOp, i);
1346-
return b.create<vector::InsertOp>(loc, newXferOp, vec,
1349+
1350+
Value valToInser = newXferOp.getResult();
1351+
if (newXferVecType.getRank() == 0) {
1352+
// vector.insert does not accept rank-0 as the non-indexed
1353+
// argument. Extract the scalar before inserting.
1354+
valToInser = b.create<vector::ExtractOp>(loc, valToInser,
1355+
SmallVector<int64_t>());
1356+
}
1357+
return b.create<vector::InsertOp>(loc, valToInser, vec,
13471358
insertionIndices);
13481359
},
13491360
/*outOfBoundsCase=*/

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
13841384
}
13851385

13861386
LogicalResult vector::ExtractOp::verify() {
1387+
if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1388+
if (resTy.getRank() == 0)
1389+
return emitError(
1390+
"expected a scalar instead of a 0-d vector as the result type");
1391+
13871392
// Note: This check must come before getMixedPosition() to prevent a crash.
13881393
auto dynamicMarkersCount =
13891394
llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -3211,6 +3216,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
32113216
}
32123217

32133218
LogicalResult InsertOp::verify() {
3219+
if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3220+
if (srcTy.getRank() == 0)
3221+
return emitError(
3222+
"expected a scalar instead of a 0-d vector as the source operand");
3223+
32143224
SmallVector<OpFoldResult> position = getMixedPosition();
32153225
auto destVectorType = getDestVectorType();
32163226
if (position.size() > static_cast<unsigned>(destVectorType.getRank()))

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) {
178178

179179
// -----
180180

181-
func.func @extract_0d(%arg0: vector<f32>) {
182-
// expected-error@+1 {{expected position attribute of rank no greater than vector rank}}
183-
%1 = vector.extract %arg0[0] : f32 from vector<f32>
181+
func.func @extract_0d_result(%arg0: vector<f32>) {
182+
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the result type}}
183+
%1 = vector.extract %arg0[] : vector<f32> from vector<f32>
184184
}
185185

186186
// -----
@@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
259259

260260
// -----
261261

262-
func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
263-
// expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
264-
%1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
265-
}
266-
267-
// -----
268-
269-
func.func @insert_0d(%a: f32, %b: vector<f32>) {
270-
// expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}}
271-
%1 = vector.insert %a, %b[0] : f32 into vector<f32>
262+
func.func @insert_0d_value_to_store(%a: vector<f32>, %b: vector<4x8x16xf32>) {
263+
// expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
264+
%1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
272265
}
273266

274267
// -----

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,12 +300,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
300300
}
301301

302302
// CHECK-LABEL: @insert_0d
303-
func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
303+
func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
304304
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
305305
%1 = vector.insert %a, %b[] : f32 into vector<f32>
306-
// CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
307-
%2 = vector.insert %b, %c[0, 1] : vector<f32> into vector<2x3xf32>
308-
return %1, %2 : vector<f32>, vector<2x3xf32>
306+
return %1 : vector<f32>
309307
}
310308

311309
// CHECK-LABEL: @insert_poison_idx

0 commit comments

Comments
 (0)