@@ -20,12 +20,14 @@ limitations under the License.
20
20
#include < memory>
21
21
#include < numeric>
22
22
#include < optional>
23
+ #include < tuple>
23
24
#include < utility>
24
25
#include < vector>
25
26
26
27
#include " absl/log/check.h"
27
28
#include " llvm/ADT/ArrayRef.h"
28
29
#include " llvm/ADT/STLExtras.h"
30
+ #include " llvm/ADT/Sequence.h"
29
31
#include " llvm/ADT/SmallVector.h"
30
32
#include " llvm/ADT/StringMap.h"
31
33
#include " mlir/Dialect/Arith/IR/Arith.h"
@@ -110,6 +112,152 @@ class CanonicalBuilder : public ImplicitLocOpBuilder {
110
112
bool need_elementwise_canonicalization (const CanonicalizeContext &ctx,
111
113
Operation &op);
112
114
115
+ // Returns the collapsed lhs, rhs, acc and the new dimension numbers if the
116
+ // non-contracting dims can be collapsed, otherwise returns std::nullopt.
117
+ std::optional<std::tuple<TypedValue<VectorType>, TypedValue<VectorType>,
118
+ TypedValue<VectorType>, tpu::DotDimensionNumbersAttr>>
119
+ collapse_matmul_non_contracting_dims (
120
+ CanonicalBuilder &builder, TypedValue<VectorType> lhs,
121
+ TypedValue<VectorType> rhs, TypedValue<VectorType> acc,
122
+ const tpu::DotDimensionNumbersAttr &dimension_numbers) {
123
+ // Collapse
124
+ //
125
+ // 1. [batch_dims, non_contracting_dims, contracting_dims] into
126
+ // [batch_dims, prod(non_contracting_dims), contracting_dims] or
127
+ // 2. [batch_dims, contracting_dims, non_contracting_dims] into
128
+ // [batch_dims, contracting_dims, prod(non_contracting_dims)].
129
+ //
130
+ // Returns a tuple of [new_operand, new_non_contracting_dims,
131
+ // new_contracting_dims]. new_operand is nullptr if the operand does not need
132
+ // to be collapsed.
133
+ // TODO(b/413194126): Some shapes will trigger unsupported
134
+ // vector::ShapeCastOp.
135
+ auto maybe_collapse_non_contracting_dims =
136
+ [&](TypedValue<VectorType> operand,
137
+ ArrayRef<int64_t > non_contracting_dims,
138
+ ArrayRef<int64_t > contracting_dims, ArrayRef<int64_t > batch_dims)
139
+ -> std::tuple<TypedValue<VectorType>, SmallVector<int64_t , 2 >,
140
+ SmallVector<int64_t , 2 >> {
141
+ VectorType vty = operand.getType ();
142
+ auto shape = vty.getShape ();
143
+ bool batch_dims_are_front =
144
+ batch_dims == ArrayRef<int64_t >(llvm::to_vector (
145
+ llvm::seq<int64_t >(0 , batch_dims.size ())));
146
+ // Case 1.
147
+ bool trailing_contracting_dims =
148
+ contracting_dims ==
149
+ ArrayRef<int64_t >(llvm::to_vector (llvm::seq<int64_t >(
150
+ shape.size () - contracting_dims.size (), shape.size ())));
151
+ // Case 2.
152
+ bool trailing_non_contracting_dims =
153
+ non_contracting_dims ==
154
+ ArrayRef<int64_t >(llvm::to_vector (llvm::seq<int64_t >(
155
+ shape.size () - non_contracting_dims.size (), shape.size ())));
156
+ bool should_collapse_non_contracting_dims =
157
+ batch_dims_are_front &&
158
+ (trailing_contracting_dims || trailing_non_contracting_dims) &&
159
+ non_contracting_dims.size () > 1 ;
160
+ if (!should_collapse_non_contracting_dims) {
161
+ return {nullptr , llvm::to_vector (non_contracting_dims),
162
+ llvm::to_vector (contracting_dims)};
163
+ }
164
+ SmallVector<int64_t , 2 > new_shape;
165
+ auto batch_sizes = shape.take_front (batch_dims.size ());
166
+ new_shape.append (batch_sizes.begin (), batch_sizes.end ());
167
+ SmallVector<int64_t , 2 > contracting_sizes;
168
+ for (int64_t contracting_dim : contracting_dims) {
169
+ contracting_sizes.push_back (shape[contracting_dim]);
170
+ }
171
+ int64_t collapsed_dim_size = std::accumulate (
172
+ non_contracting_dims.begin (), non_contracting_dims.end (), 1 ,
173
+ [&](int64_t a, int64_t b) { return a * shape[b]; });
174
+ ;
175
+ if (trailing_contracting_dims) {
176
+ new_shape.push_back (collapsed_dim_size);
177
+ new_shape.append (contracting_sizes.begin (), contracting_sizes.end ());
178
+ } else {
179
+ new_shape.append (contracting_sizes.begin (), contracting_sizes.end ());
180
+ new_shape.push_back (collapsed_dim_size);
181
+ }
182
+ auto new_operand =
183
+ cast<TypedValue<VectorType>>(builder.create <vector::ShapeCastOp>(
184
+ VectorType::get (new_shape, vty.getElementType ()), operand));
185
+ SmallVector<int64_t , 2 > new_non_contracting_dims, new_contracting_dims;
186
+ if (trailing_non_contracting_dims) {
187
+ // Case 2 - contracting dims are not changed and non contracting dims are
188
+ // changed to the last dim.
189
+ new_contracting_dims = llvm::to_vector (contracting_dims);
190
+ new_non_contracting_dims.push_back (new_shape.size () - 1 );
191
+ } else {
192
+ // Case 1 - non contracting dims are collapsed in the middle so all
193
+ // contracting dims are moved forward by (non_contracting_dims.size() -
194
+ // 1).
195
+ new_non_contracting_dims.push_back (batch_dims.size ());
196
+ for (int64_t contracting_dim : contracting_dims) {
197
+ new_contracting_dims.push_back (contracting_dim -
198
+ (non_contracting_dims.size () - 1 ));
199
+ }
200
+ }
201
+ return {new_operand, new_non_contracting_dims, new_contracting_dims};
202
+ };
203
+
204
+ auto [new_lhs, new_lhs_non_contracting_dims, new_lhs_contracting_dims] =
205
+ maybe_collapse_non_contracting_dims (
206
+ lhs, dimension_numbers.getLhsNonContractingDims (),
207
+ dimension_numbers.getLhsContractingDims (),
208
+ dimension_numbers.getLhsBatchDims ());
209
+
210
+ auto [new_rhs, new_rhs_non_contracting_dims, new_rhs_contracting_dims] =
211
+ maybe_collapse_non_contracting_dims (
212
+ rhs, dimension_numbers.getRhsNonContractingDims (),
213
+ dimension_numbers.getRhsContractingDims (),
214
+ dimension_numbers.getRhsBatchDims ());
215
+
216
+ // Nothing to collapse.
217
+ if (!new_lhs && !new_rhs) {
218
+ return std::nullopt;
219
+ }
220
+
221
+ // Overwrite the operands if they were collapsed. We're going to access the
222
+ // new shapes below.
223
+ lhs = new_lhs ? new_lhs : lhs;
224
+ rhs = new_rhs ? new_rhs : rhs;
225
+
226
+ SmallVector<int64_t , 2 > new_output_dim_order;
227
+ SmallVector<int64_t , 2 > new_acc_shape;
228
+ for (int64_t batch_dim : dimension_numbers.getLhsBatchDims ()) {
229
+ new_output_dim_order.push_back (0 );
230
+ new_output_dim_order.push_back (batch_dim);
231
+ new_acc_shape.push_back (lhs.getType ().getDimSize (batch_dim));
232
+ }
233
+ for (int64_t non_contracting_dim : new_lhs_non_contracting_dims) {
234
+ new_output_dim_order.push_back (0 );
235
+ new_output_dim_order.push_back (non_contracting_dim);
236
+ new_acc_shape.push_back (lhs.getType ().getDimSize (non_contracting_dim));
237
+ }
238
+ for (int64_t non_contracting_dim : new_rhs_non_contracting_dims) {
239
+ new_output_dim_order.push_back (1 );
240
+ new_output_dim_order.push_back (non_contracting_dim);
241
+ new_acc_shape.push_back (rhs.getType ().getDimSize (non_contracting_dim));
242
+ }
243
+
244
+ // Batch dims are always at the front of the lhs and rhs.
245
+ tpu::DotDimensionNumbersAttr new_dimension_numbers =
246
+ tpu::DotDimensionNumbersAttr::get (
247
+ builder.getContext (), new_lhs_contracting_dims,
248
+ new_rhs_contracting_dims, new_lhs_non_contracting_dims,
249
+ new_rhs_non_contracting_dims, new_output_dim_order,
250
+ dimension_numbers.getLhsBatchDims (),
251
+ dimension_numbers.getRhsBatchDims ());
252
+
253
+ // Reshape acc too.
254
+ auto new_acc =
255
+ cast<TypedValue<VectorType>>(builder.create <vector::ShapeCastOp>(
256
+ VectorType::get (new_acc_shape, acc.getType ().getElementType ()), acc));
257
+
258
+ return std::make_tuple (lhs, rhs, new_acc, new_dimension_numbers);
259
+ }
260
+
113
261
FailureOr<Value> canonicalize_matmul (const CanonicalizeContext &ctx,
114
262
Operation &raw_op) {
115
263
auto op = cast<tpu::MatmulOp>(raw_op);
@@ -122,13 +270,13 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
122
270
auto rhs = op.getRhs ();
123
271
auto acc = op.getAcc ();
124
272
125
- const VectorType lhs_ty = lhs.getType ();
126
- const VectorType rhs_ty = rhs.getType ();
127
- const VectorType acc_ty = acc.getType ();
273
+ const VectorType old_lhs_ty = lhs.getType ();
274
+ const VectorType old_rhs_ty = rhs.getType ();
275
+ const VectorType old_acc_ty = acc.getType ();
128
276
129
- auto lhs_element_type = lhs_ty .getElementType ();
130
- auto rhs_element_type = rhs_ty .getElementType ();
131
- auto acc_element_type = acc_ty .getElementType ();
277
+ auto lhs_element_type = old_lhs_ty .getElementType ();
278
+ auto rhs_element_type = old_rhs_ty .getElementType ();
279
+ auto acc_element_type = old_acc_ty .getElementType ();
132
280
133
281
// there are a few primary paths for dimension_numbers in matmul
134
282
// 1) No dimension numbers provided -> set to default
@@ -146,6 +294,14 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
146
294
// Dot dim API - dimensions are provided and are not default
147
295
(op.getDimensionNumbers ().value () !=
148
296
defaultDimensionNumbers (builder, false , false ))) {
297
+ if (auto collapsed_operands_and_ddn = collapse_matmul_non_contracting_dims (
298
+ builder, lhs, rhs, acc, *op.getDimensionNumbers ())) {
299
+ tpu::DotDimensionNumbersAttr new_dimension_numbers;
300
+ std::tie (lhs, rhs, acc, new_dimension_numbers) =
301
+ *collapsed_operands_and_ddn;
302
+ op.setDimensionNumbersAttr (new_dimension_numbers);
303
+ }
304
+
149
305
auto dimension_numbers = op.getDimensionNumbers ();
150
306
auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims ();
151
307
auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims ();
@@ -156,11 +312,11 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
156
312
// Invariant in matmul verifier: <= 1 batch dim atm, and that lhs and rhs
157
313
// are the same
158
314
// Invariant in matmul verifier: Exactly one contracting and non contracting
159
- // dim in each of lhs and rhs for now .
160
- batch_size =
161
- lhs_batch_dims. empty ()
162
- ? std::nullopt
163
- : std::optional< int64_t >(lhs_ty .getShape ()[lhs_batch_dims[0 ]]);
315
+ // dim in each of lhs and rhs at the moment .
316
+ batch_size = lhs_batch_dims. empty ()
317
+ ? std::nullopt
318
+ : std::optional< int64_t >(
319
+ lhs. getType () .getShape ()[lhs_batch_dims[0 ]]);
164
320
// Lower each dim in contracting dims by size(batch_dims)
165
321
auto batch_adjusted_lhs_contracting_dim =
166
322
lhs_contracting_dims[0 ] - lhs_batch_dims.size ();
@@ -175,6 +331,18 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
175
331
}
176
332
}
177
333
334
+ // Make sure there is only one non-contracting dim in each of lhs and rhs
335
+ // after collapsing.
336
+ auto dimension_numbers = op.getDimensionNumbers ();
337
+ if (dimension_numbers->getLhsNonContractingDims ().size () != 1 ) {
338
+ return op->emitOpError (
339
+ " Not implemented: lhs non contracting dims must be of size 1" );
340
+ }
341
+ if (dimension_numbers->getRhsNonContractingDims ().size () != 1 ) {
342
+ return op->emitOpError (
343
+ " Not implemented: rhs non contracting dims must be of size 1" );
344
+ }
345
+
178
346
auto extsi_sitofp = [&builder, &op](TypedValue<VectorType> element) {
179
347
const VectorType ty = element.getType ();
180
348
auto shape = ty.getShape ();
@@ -247,7 +415,6 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
247
415
// operation that fuses the transpose into the matmul.
248
416
auto transpose_op =
249
417
dyn_cast_if_present<tpu::TransposeOp>(rhs.getDefiningOp ());
250
- auto dimension_numbers = op.getDimensionNumbers ();
251
418
if (transpose_op && transpose_op->hasOneUse () &&
252
419
dimension_numbers->getRhsContractingDims ().size () == 1 &&
253
420
dimension_numbers->getRhsNonContractingDims ().size () == 1 ) {
@@ -259,7 +426,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
259
426
permutation[rhs_non_contracting_dim] == rhs_contracting_dim &&
260
427
std::all_of (dimension_numbers->getRhsBatchDims ().begin (),
261
428
dimension_numbers->getRhsBatchDims ().end (),
262
- [&](long batch_dim) {
429
+ [&](int64_t batch_dim) {
263
430
return permutation[batch_dim] == batch_dim;
264
431
})) {
265
432
if (auto transpose_op_vector_operand =
@@ -312,6 +479,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
312
479
// If we have a batch_size, we want to slice rhs and lhs [:batch_size],
313
480
// and then do O[i] = A[i] @ B[i]
314
481
// Produce an output shape of [batch_size, m, n]
482
+ Value res;
315
483
if (batch_size.has_value ()) {
316
484
std::vector<Value> outputs;
317
485
@@ -336,22 +504,22 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
336
504
// Technically almost identical to the case where batch_size is 1, but
337
505
// we want to avoid the spurious concat here.
338
506
if (batch_size == 1 ) {
339
- op.replaceAllUsesWith (outputs[0 ]);
340
- op.erase ();
341
- return outputs[0 ];
507
+ res = outputs[0 ];
508
+ } else {
509
+ res = builder.create <tpu::ConcatenateOp>(acc.getType (), outputs,
510
+ /* dimension=*/ 0 );
342
511
}
343
- auto output =
344
- builder.create <tpu::ConcatenateOp>(acc_ty, outputs, /* dimension=*/ 0 );
345
- op.replaceAllUsesWith (output);
346
- op.erase ();
347
- return output;
348
512
} else {
349
- auto matmul_res = dot_dim_matmul (lhs, rhs, acc);
350
- op.replaceAllUsesWith (matmul_res);
351
- op.erase ();
352
- return matmul_res;
513
+ res = dot_dim_matmul (lhs, rhs, acc);
353
514
}
354
- return op.getResult ();
515
+
516
+ // Reshape the result to the old one as dims might have been collapsed.
517
+ if (res.getType () != old_acc_ty) {
518
+ res = builder.create <vector::ShapeCastOp>(old_acc_ty, res);
519
+ }
520
+ op.replaceAllUsesWith (res);
521
+ op.erase ();
522
+ return res;
355
523
};
356
524
357
525
FailureOr<Value> canonicalize_elementwise (const CanonicalizeContext &ctx,
0 commit comments