Commit 06ae0c2

[mlir][xegpu] Remove vector contract to dpas size restriction (#147470)
Removes the contraction shape check to allow representing large workgroup-level workloads in preparation for distribution.
1 parent 253f8b6 commit 06ae0c2
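
For illustration, a minimal sketch (hypothetical shapes, not taken from this commit's test) of a contraction that the pattern previously rejected because the accumulator's N dimension is neither 8 nor 16; with the check removed it lowers straight to xegpu.dpas and can later be decomposed into hardware-sized tiles:

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

// N = 64 would have failed the removed check (dimN != 8 && dimN != 16).
func.func @gemm_n64(%lhs: vector<32x128xf16>, %rhs: vector<128x64xf16>,
    %acc: vector<32x64xf32>) -> vector<32x64xf32> {
  %0 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<32x128xf16>, vector<128x64xf16> into vector<32x64xf32>
  // After the VectorToXeGPU conversion, the contraction is replaced by an
  // xegpu.dpas taking %lhs, %rhs, %acc and producing vector<32x64xf32>.
  return %0 : vector<32x64xf32>
}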

2 files changed (+28, −25 lines)

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 0 additions & 7 deletions
@@ -339,13 +339,6 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
     if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
       return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
 
-    // TODO: Update shape validation to be target aware.
-    auto accShape = accType.getShape();
-    int64_t dimN = accShape[1];
-    if (dimN != 8 && dimN != 16)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Invalid operand dimensions");
-
     auto dpasOp = rewriter.create<xegpu::DpasOp>(
         loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
     rewriter.replaceOp(contractOp, dpasOp);

mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir

Lines changed: 28 additions & 18 deletions
@@ -48,6 +48,34 @@ func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
 
 // -----
 
+// No restriction on vector sizes to allow capturing workgroup-sized operations.
+// The operations can then be progressively resized through distribution down
+// to hardware compatible sizes.
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
+    %acc: vector<128x256xf32>) -> vector<128x256xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<128x512xf16>, vector<512x256xf16> into vector<128x256xf32>
+  return %3 : vector<128x256xf32>
+}
+
+// CHECK-LABEL: @dpas_large_dims(
+// CHECK-SAME:  %[[LHS:.+]]: vector<128x512xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<512x256xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<128x256xf32>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:  %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:  {{.*}}-> vector<128x256xf32>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
 // For simplicity, only plain data layouts are currently supported.
 // VNNI packing is applied later as a separate lowering step.

@@ -138,21 +166,3 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
 
 // CHECK-LABEL: @negative_gemm_transpose_b(
 // CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
-    %acc: vector<8x32xf32>) -> vector<8x32xf32> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
-  return %3 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: @negative_n_dim_size(
-// CHECK: vector.contract
