Skip to content

Commit bf4d99e

Browse files
authored
[mlir][vector] Add deinterleave operation to vector dialect (#92409)
The deinterleave operation constructs two vectors from a single input vector. The first result vector contains the elements from even indexes of the input, and the second contains elements from odd indexes. This is the inverse of a `vector.interleave` operation. Each output's trailing dimension is half of the size of the input vector's trailing dimension. This operation requires the input vector to have a rank > 0 and an even number of elements in its trailing dimension. The operation supports scalable vectors. Example: ```mlir %0, %1 = vector.deinterleave %a : vector<8xi8> -> vector<4xi8> %2, %3 = vector.deinterleave %b : vector<2x8xi8> -> vector<2x4xi8> %4, %5 = vector.deinterleave %c : vector<2x8x4xi8> -> vector<2x8x2xi8> %6, %7 = vector.deinterleave %d : vector<[8]xf32> -> vector<[4]xf32> %8, %9 = vector.deinterleave %e : vector<2x[6]xf64> -> vector<2x[3]xf64> %10, %11 = vector.deinterleave %f : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64> ```
1 parent ae3f680 commit bf4d99e

File tree

3 files changed

+178
-0
lines changed

3 files changed

+178
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,86 @@ def Vector_InterleaveOp :
543543
}];
544544
}
545545

546+
class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
547+
"the trailing dimension of the results is half the width of source trailing dimension",
548+
"source", result,
549+
[{
550+
[&]() -> ::mlir::VectorType {
551+
auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
552+
::mlir::VectorType::Builder builder(vectorType);
553+
auto lastDim = vectorType.getRank() - 1;
554+
auto newDimSize = vectorType.getDimSize(lastDim) / 2;;
555+
if (newDimSize <= 0)
556+
return vectorType; // (invalid input type)
557+
return builder.setDim(lastDim, newDimSize);
558+
}()
559+
}]
560+
>;
561+
562+
def SourceVectorEvenElementCount : PredOpTrait<
563+
"the trailing dimension of the source vector has an even number of elements",
564+
CPred<[{
565+
[&](){
566+
auto srcVec = getSourceVectorType();
567+
return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
568+
}()
569+
}]>
570+
>;
571+
572+
def Vector_DeinterleaveOp :
573+
Vector_Op<"deinterleave", [Pure,
574+
SourceVectorEvenElementCount,
575+
ResultIsHalfSourceVectorType<"res1">,
576+
AllTypesMatch<["res1", "res2"]>
577+
]> {
578+
let summary = "constructs two vectors by deinterleaving an input vector";
579+
let description = [{
580+
The deinterleave operation constructs two vectors from a single input
581+
vector. The first result vector contains the elements from even indexes
582+
of the input, and the second contains elements from odd indexes. This is
583+
the inverse of a `vector.interleave` operation.
584+
585+
Each output's trailing dimension is half of the size of the input
586+
vector's trailing dimension. This operation requires the input vector
587+
to have a rank > 0 and an even number of elements in its trailing
588+
dimension.
589+
590+
The operation supports scalable vectors.
591+
592+
Example:
593+
```mlir
594+
%0, %1 = vector.deinterleave %a
595+
: vector<8xi8> -> vector<4xi8>
596+
%2, %3 = vector.deinterleave %b
597+
: vector<2x8xi8> -> vector<2x4xi8>
598+
%4, %5 = vector.deinterleave %c
599+
: vector<2x8x4xi8> -> vector<2x8x2xi8>
600+
%6, %7 = vector.deinterleave %d
601+
: vector<[8]xf32> -> vector<[4]xf32>
602+
%8, %9 = vector.deinterleave %e
603+
: vector<2x[6]xf64> -> vector<2x[3]xf64>
604+
%10, %11 = vector.deinterleave %f
605+
: vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
606+
```
607+
}];
608+
609+
let arguments = (ins AnyVector:$source);
610+
let results = (outs AnyVector:$res1, AnyVector:$res2);
611+
612+
let assemblyFormat = [{
613+
$source attr-dict `:` type($source) `->` type($res1)
614+
}];
615+
616+
let extraClassDeclaration = [{
617+
VectorType getSourceVectorType() {
618+
return ::llvm::cast<VectorType>(getSource().getType());
619+
}
620+
VectorType getResultVectorType() {
621+
return ::llvm::cast<VectorType>(getRes1().getType());
622+
}
623+
}];
624+
}
625+
546626
def Vector_ExtractElementOp :
547627
Vector_Op<"extractelement", [Pure,
548628
TypesMatchWith<"result type matches element type of vector operand",

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,3 +1798,59 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
17981798
// expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
17991799
%op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
18001800
}
1801+
1802+
// -----
1803+
1804+
func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
1805+
// expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>}}
1806+
%0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
1807+
return
1808+
}
1809+
1810+
// -----
1811+
1812+
func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) {
1813+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the source vector has an even number of elements}}
1814+
%0, %1 = vector.deinterleave %vec : vector<1xf32> -> vector<1xf32>
1815+
return
1816+
}
1817+
1818+
// -----
1819+
1820+
func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
1821+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
1822+
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
1823+
return
1824+
}
1825+
1826+
// -----
1827+
1828+
func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
1829+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
1830+
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
1831+
return
1832+
}
1833+
1834+
// -----
1835+
1836+
func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) {
1837+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
1838+
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
1839+
return
1840+
}
1841+
1842+
// -----
1843+
1844+
func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
1845+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
1846+
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
1847+
return
1848+
}
1849+
1850+
// -----
1851+
1852+
func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
1853+
// expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
1854+
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
1855+
return
1856+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,3 +1116,45 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)
11161116
%0 = vector.interleave %a, %b : vector<2x[2]xf64>
11171117
return %0 : vector<2x[4]xf64>
11181118
}
1119+
1120+
// CHECK-LABEL: @deinterleave_1d
1121+
func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
1122+
// CHECK: vector.deinterleave %{{.*}} : vector<4xf32> -> vector<2xf32>
1123+
%0, %1 = vector.deinterleave %arg : vector<4xf32> -> vector<2xf32>
1124+
return %0, %1 : vector<2xf32>, vector<2xf32>
1125+
}
1126+
1127+
// CHECK-LABEL: @deinterleave_1d_scalable
1128+
func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
1129+
// CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> -> vector<[2]xf32>
1130+
%0, %1 = vector.deinterleave %arg : vector<[4]xf32> -> vector<[2]xf32>
1131+
return %0, %1 : vector<[2]xf32>, vector<[2]xf32>
1132+
}
1133+
1134+
// CHECK-LABEL: @deinterleave_2d
1135+
func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
1136+
// CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> -> vector<3x2xf32>
1137+
%0, %1 = vector.deinterleave %arg : vector<3x4xf32> -> vector<3x2xf32>
1138+
return %0, %1 : vector<3x2xf32>, vector<3x2xf32>
1139+
}
1140+
1141+
// CHECK-LABEL: @deinterleave_2d_scalable
1142+
func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
1143+
// CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> -> vector<3x[2]xf32>
1144+
%0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> -> vector<3x[2]xf32>
1145+
return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32>
1146+
}
1147+
1148+
// CHECK-LABEL: @deinterleave_nd
1149+
func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
1150+
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
1151+
%0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
1152+
return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32>
1153+
}
1154+
1155+
// CHECK-LABEL: @deinterleave_nd_scalable
1156+
func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
1157+
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
1158+
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
1159+
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
1160+
}

0 commit comments

Comments
 (0)