@@ -48,6 +48,34 @@ func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
48
48
49
49
// -----
50
50
51
+ // No restriction on vector sizes to allow capturing workgroup-sized operations.
52
+ // The operations can then be progressively resized through distribution down
53
+ // to hardware compatible sizes.
54
+
55
+ #map = affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>
56
+ #map1 = affine_map <(d0 , d1 , d2 ) -> (d2 , d1 )>
57
+ #map2 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
58
+ func.func @dpas_large_dims (%lhs: vector <128 x512 xf16 >, %rhs: vector <512 x256 xf16 >,
59
+ %acc: vector <128 x256 xf32 >) -> vector <128 x256 xf32 > {
60
+ %3 = vector.contract
61
+ {index ing_maps = [#map , #map1 , #map2 ],
62
+ iterator_types = [" parallel" , " parallel" , " reduction" ],
63
+ kind = #vector.kind <add >} %lhs , %rhs , %acc
64
+ : vector <128 x512 xf16 >, vector <512 x256 xf16 > into vector <128 x256 xf32 >
65
+ return %3 : vector <128 x256 xf32 >
66
+ }
67
+
68
+ // CHECK-LABEL: @dpas_large_dims(
69
+ // CHECK-SAME: %[[LHS:.+]]: vector<128x512xf16>,
70
+ // CHECK-SAME: %[[RHS:.+]]: vector<512x256xf16>,
71
+ // CHECK-SAME: %[[ACC:.+]]: vector<128x256xf32>
72
+ // CHECK: %[[DPAS:.+]] = xegpu.dpas
73
+ // CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
74
+ // CHECK-SAME: {{.*}}-> vector<128x256xf32>
75
+ // CHECK: return %[[DPAS]]
76
+
77
+ // -----
78
+
51
79
// For simplicity, only plain data layouts are currently supported.
52
80
// VNNI packing is applied later as a separate lowering step.
53
81
@@ -138,21 +166,3 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
138
166
139
167
// CHECK-LABEL: @negative_gemm_transpose_b(
140
168
// CHECK: vector.contract
141
-
142
- // -----
143
-
144
- #map = affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>
145
- #map1 = affine_map <(d0 , d1 , d2 ) -> (d2 , d1 )>
146
- #map2 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
147
- func.func @negative_n_dim_size (%lhs: vector <8 x16 xf16 >, %rhs: vector <16 x32 xf16 >,
148
- %acc: vector <8 x32 xf32 >) -> vector <8 x32 xf32 > {
149
- %3 = vector.contract
150
- {index ing_maps = [#map , #map1 , #map2 ],
151
- iterator_types = [" parallel" , " parallel" , " reduction" ],
152
- kind = #vector.kind <add >} %lhs , %rhs , %acc
153
- : vector <8 x16 xf16 >, vector <16 x32 xf16 > into vector <8 x32 xf32 >
154
- return %3 : vector <8 x32 xf32 >
155
- }
156
-
157
- // CHECK-LABEL: @negative_n_dim_size(
158
- // CHECK: vector.contract
0 commit comments