@@ -2873,7 +2873,8 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
2873
2873
static LogicalResult
2874
2874
verifyTransferOp (VectorTransferOpInterface op, ShapedType shapedType,
2875
2875
VectorType vectorType, VectorType maskType,
2876
- AffineMap permutationMap, ArrayAttr inBounds) {
2876
+ VectorType inferredMaskType, AffineMap permutationMap,
2877
+ ArrayAttr inBounds) {
2877
2878
if (op->hasAttr (" masked" )) {
2878
2879
return op->emitOpError (" masked attribute has been removed. "
2879
2880
" Use in_bounds instead." );
@@ -2926,13 +2927,6 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2926
2927
if (permutationMap.getNumResults () != vectorType.getRank ())
2927
2928
return op->emitOpError (" requires a permutation_map with result dims of "
2928
2929
" the same rank as the vector type" );
2929
-
2930
- VectorType expectedMaskType =
2931
- vector::detail::transferMaskType (vectorType, permutationMap);
2932
- if (maskType && expectedMaskType != maskType)
2933
- return op->emitOpError (" expects mask type consistent with permutation "
2934
- " map: " )
2935
- << maskType;
2936
2930
}
2937
2931
2938
2932
if (permutationMap.getNumSymbols () != 0 )
@@ -2942,6 +2936,11 @@ verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2942
2936
return op->emitOpError (" requires a permutation_map with input dims of the "
2943
2937
" same rank as the source type" );
2944
2938
2939
+ if (maskType && maskType != inferredMaskType)
2940
+ return op->emitOpError (" inferred mask type (" )
2941
+ << inferredMaskType << " ) and mask operand type (" << maskType
2942
+ << " ) don't match" ;
2943
+
2945
2944
if (inBounds) {
2946
2945
if (permutationMap.getNumResults () != static_cast <int64_t >(inBounds.size ()))
2947
2946
return op->emitOpError (" expects the optional in_bounds attr of same rank "
@@ -2984,6 +2983,19 @@ void TransferReadOp::print(OpAsmPrinter &p) {
2984
2983
p << " : " << getShapedType () << " , " << getVectorType ();
2985
2984
}
2986
2985
2986
+ // / Infers the mask type for a transfer read given its vector type and
2987
+ // / permutation map. The mask in a transfer read operation applies to the
2988
+ // / tensor/buffer reading part of it and its type should match the shape read
2989
+ // / *before* any permutation or broadcasting.
2990
+ static VectorType inferTransferReadMaskType (VectorType vecType,
2991
+ AffineMap permMap) {
2992
+ auto i1Type = IntegerType::get (permMap.getContext (), 1 );
2993
+ AffineMap invPermMap = inversePermutation (compressUnusedDims (permMap));
2994
+ assert (invPermMap && " Inversed permutation map couldn't be computed" );
2995
+ SmallVector<int64_t , 8 > maskShape = invPermMap.compose (vecType.getShape ());
2996
+ return VectorType::get (maskShape, i1Type);
2997
+ }
2998
+
2987
2999
ParseResult TransferReadOp::parse (OpAsmParser &parser, OperationState &result) {
2988
3000
auto &builder = parser.getBuilder ();
2989
3001
SMLoc typesLoc;
@@ -3014,13 +3026,14 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3014
3026
VectorType vectorType = types[1 ].dyn_cast <VectorType>();
3015
3027
if (!vectorType)
3016
3028
return parser.emitError (typesLoc, " requires vector type" );
3017
- auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName ();
3018
- Attribute mapAttr = result.attributes .get (permutationAttrName);
3019
- if (!mapAttr) {
3020
- auto permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3021
- // Update `mapAttr` that is used later to determine mask type.
3022
- mapAttr = AffineMapAttr::get (permMap);
3023
- result.attributes .set (permutationAttrName, mapAttr);
3029
+ auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName ();
3030
+ Attribute permMapAttr = result.attributes .get (permMapAttrName);
3031
+ AffineMap permMap;
3032
+ if (!permMapAttr) {
3033
+ permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3034
+ result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
3035
+ } else {
3036
+ permMap = permMapAttr.cast <AffineMapAttr>().getValue ();
3024
3037
}
3025
3038
if (parser.resolveOperand (sourceInfo, shapedType, result.operands ) ||
3026
3039
parser.resolveOperands (indexInfo, indexType, result.operands ) ||
@@ -3031,10 +3044,9 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3031
3044
if (shapedType.getElementType ().dyn_cast <VectorType>())
3032
3045
return parser.emitError (
3033
3046
maskInfo.location , " does not support masks with vector element type" );
3034
- auto map = mapAttr.dyn_cast <AffineMapAttr>().getValue ();
3035
3047
// Instead of adding the mask type as an op type, compute it based on the
3036
3048
// vector type and the permutation map (to keep the type signature small).
3037
- auto maskType = mlir::vector::detail::transferMaskType (vectorType, map );
3049
+ auto maskType = inferTransferReadMaskType (vectorType, permMap );
3038
3050
if (parser.resolveOperand (maskInfo, maskType, result.operands ))
3039
3051
return failure ();
3040
3052
}
@@ -3052,13 +3064,17 @@ LogicalResult TransferReadOp::verify() {
3052
3064
VectorType maskType = getMaskType ();
3053
3065
auto paddingType = getPadding ().getType ();
3054
3066
auto permutationMap = getPermutationMap ();
3067
+ VectorType inferredMaskType =
3068
+ maskType ? inferTransferReadMaskType (vectorType, permutationMap)
3069
+ : VectorType ();
3055
3070
auto sourceElementType = shapedType.getElementType ();
3056
3071
3057
3072
if (static_cast <int64_t >(getIndices ().size ()) != shapedType.getRank ())
3058
3073
return emitOpError (" requires " ) << shapedType.getRank () << " indices" ;
3059
3074
3060
3075
if (failed (verifyTransferOp (cast<VectorTransferOpInterface>(getOperation ()),
3061
- shapedType, vectorType, maskType, permutationMap,
3076
+ shapedType, vectorType, maskType,
3077
+ inferredMaskType, permutationMap,
3062
3078
getInBounds () ? *getInBounds () : ArrayAttr ())))
3063
3079
return failure ();
3064
3080
@@ -3422,6 +3438,18 @@ void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3422
3438
build (builder, result, vector, dest, indices, permutationMap, inBounds);
3423
3439
}
3424
3440
3441
+ // / Infers the mask type for a transfer write given its vector type and
3442
+ // / permutation map. The mask in a transfer read operation applies to the
3443
+ // / tensor/buffer writing part of it and its type should match the shape written
3444
+ // / *after* any permutation.
3445
+ static VectorType inferTransferWriteMaskType (VectorType vecType,
3446
+ AffineMap permMap) {
3447
+ auto i1Type = IntegerType::get (permMap.getContext (), 1 );
3448
+ SmallVector<int64_t , 8 > maskShape =
3449
+ compressUnusedDims (permMap).compose (vecType.getShape ());
3450
+ return VectorType::get (maskShape, i1Type);
3451
+ }
3452
+
3425
3453
ParseResult TransferWriteOp::parse (OpAsmParser &parser,
3426
3454
OperationState &result) {
3427
3455
auto &builder = parser.getBuilder ();
@@ -3449,11 +3477,14 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3449
3477
ShapedType shapedType = types[1 ].dyn_cast <ShapedType>();
3450
3478
if (!shapedType || !shapedType.isa <MemRefType, RankedTensorType>())
3451
3479
return parser.emitError (typesLoc, " requires memref or ranked tensor type" );
3452
- auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName ();
3453
- auto attr = result.attributes .get (permutationAttrName);
3454
- if (!attr) {
3455
- auto permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3456
- result.attributes .set (permutationAttrName, AffineMapAttr::get (permMap));
3480
+ auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName ();
3481
+ auto permMapAttr = result.attributes .get (permMapAttrName);
3482
+ AffineMap permMap;
3483
+ if (!permMapAttr) {
3484
+ permMap = getTransferMinorIdentityMap (shapedType, vectorType);
3485
+ result.attributes .set (permMapAttrName, AffineMapAttr::get (permMap));
3486
+ } else {
3487
+ permMap = permMapAttr.cast <AffineMapAttr>().getValue ();
3457
3488
}
3458
3489
if (parser.resolveOperand (vectorInfo, vectorType, result.operands ) ||
3459
3490
parser.resolveOperand (sourceInfo, shapedType, result.operands ) ||
@@ -3463,7 +3494,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3463
3494
if (shapedType.getElementType ().dyn_cast <VectorType>())
3464
3495
return parser.emitError (
3465
3496
maskInfo.location , " does not support masks with vector element type" );
3466
- auto maskType = VectorType::get (vectorType. getShape (), builder. getI1Type () );
3497
+ auto maskType = inferTransferWriteMaskType (vectorType, permMap );
3467
3498
if (parser.resolveOperand (maskInfo, maskType, result.operands ))
3468
3499
return failure ();
3469
3500
}
@@ -3489,6 +3520,9 @@ LogicalResult TransferWriteOp::verify() {
3489
3520
VectorType vectorType = getVectorType ();
3490
3521
VectorType maskType = getMaskType ();
3491
3522
auto permutationMap = getPermutationMap ();
3523
+ VectorType inferredMaskType =
3524
+ maskType ? inferTransferWriteMaskType (vectorType, permutationMap)
3525
+ : VectorType ();
3492
3526
3493
3527
if (llvm::size (getIndices ()) != shapedType.getRank ())
3494
3528
return emitOpError (" requires " ) << shapedType.getRank () << " indices" ;
@@ -3499,7 +3533,8 @@ LogicalResult TransferWriteOp::verify() {
3499
3533
return emitOpError (" should not have broadcast dimensions" );
3500
3534
3501
3535
if (failed (verifyTransferOp (cast<VectorTransferOpInterface>(getOperation ()),
3502
- shapedType, vectorType, maskType, permutationMap,
3536
+ shapedType, vectorType, maskType,
3537
+ inferredMaskType, permutationMap,
3503
3538
getInBounds () ? *getInBounds () : ArrayAttr ())))
3504
3539
return failure ();
3505
3540
0 commit comments