@@ -20,12 +20,14 @@ limitations under the License.
20
20
#include < memory>
21
21
#include < numeric>
22
22
#include < optional>
23
+ #include < tuple>
23
24
#include < utility>
24
25
#include < vector>
25
26
26
27
#include " absl/log/check.h"
27
28
#include " llvm/ADT/ArrayRef.h"
28
29
#include " llvm/ADT/STLExtras.h"
30
+ #include " llvm/ADT/Sequence.h"
29
31
#include " llvm/ADT/SmallVector.h"
30
32
#include " llvm/ADT/StringMap.h"
31
33
#include " mlir/Dialect/Arith/IR/Arith.h"
@@ -110,6 +112,160 @@ class CanonicalBuilder : public ImplicitLocOpBuilder {
110
112
bool need_elementwise_canonicalization (const CanonicalizeContext &ctx,
111
113
Operation &op);
112
114
115
// Attempts to collapse multiple non-contracting dims of a matmul into a single
// dim on each operand, so the op can be canonicalized as an ordinary 2D/batched
// matmul downstream.
//
// Returns the collapsed lhs, rhs, acc and the new dimension numbers if the
// non-contracting dims can be collapsed, otherwise returns std::nullopt.
//
// Parameters:
//   builder           - builder used to create the vector.shape_cast ops.
//   lhs, rhs, acc     - matmul operands; lhs/rhs may be rewritten, acc is
//                       always reshaped to the collapsed output shape when a
//                       collapse happens.
//   dimension_numbers - the op's current dot dimension numbers.
//
// Note: the caller is expected to shape_cast the matmul result back to the
// original acc shape afterwards (this function only handles the inputs).
std::optional<std::tuple<TypedValue<VectorType>, TypedValue<VectorType>,
                         TypedValue<VectorType>, tpu::DotDimensionNumbersAttr>>
collapse_matmul_non_contracting_dims(
    CanonicalBuilder &builder, TypedValue<VectorType> lhs,
    TypedValue<VectorType> rhs, TypedValue<VectorType> acc,
    const tpu::DotDimensionNumbersAttr &dimension_numbers) {
  // Collapse
  //
  // 1. [batch_dims, non_contracting_dims, contracting_dims] into
  //    [batch_dims, prod(non_contracting_dims), contracting_dims] or
  // 2. [batch_dims, contracting_dims, non_contracting_dims] into
  //    [batch_dims, contracting_dims, prod(non_contracting_dims)].
  //
  // Returns a tuple of [new_operand, new_non_contracting_dims,
  // new_contracting_dims]. new_operand is nullptr if the operand does not need
  // to be collapsed.
  // TODO(b/413194126): Some shapes will trigger unsupported
  // vector::ShapeCastOp.
  auto maybe_collapse_non_contracting_dims =
      [&](TypedValue<VectorType> operand,
          ArrayRef<int64_t> non_contracting_dims,
          ArrayRef<int64_t> contracting_dims, ArrayRef<int64_t> batch_dims)
      -> std::tuple<TypedValue<VectorType>, SmallVector<int64_t, 2>,
                    SmallVector<int64_t, 2>> {
    VectorType vty = operand.getType();
    auto shape = vty.getShape();
    // Collapsing is only valid when the batch dims are a prefix of the shape,
    // i.e. exactly [0, 1, ..., batch_dims.size() - 1].
    bool batch_dims_are_front =
        batch_dims == ArrayRef<int64_t>(llvm::to_vector(
                          llvm::seq<int64_t>(0, batch_dims.size())));
    // Case 1: contracting dims form the suffix of the shape and the
    // non-contracting dims sit contiguously right after the batch dims.
    bool trailing_contracting_dims =
        contracting_dims ==
            ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
                shape.size() - contracting_dims.size(), shape.size()))) &&
        non_contracting_dims ==
            ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
                batch_dims.size(),
                batch_dims.size() + non_contracting_dims.size())));
    // Case 2: the mirror layout — non-contracting dims form the suffix and
    // the contracting dims sit contiguously right after the batch dims.
    bool trailing_non_contracting_dims =
        non_contracting_dims ==
            ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
                shape.size() - non_contracting_dims.size(), shape.size()))) &&
        contracting_dims ==
            ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
                batch_dims.size(),
                batch_dims.size() + contracting_dims.size())));
    // Only collapse when there is actually more than one non-contracting dim;
    // a single dim is already in canonical form.
    bool should_collapse_non_contracting_dims =
        batch_dims_are_front &&
        (trailing_contracting_dims || trailing_non_contracting_dims) &&
        non_contracting_dims.size() > 1;
    if (!should_collapse_non_contracting_dims) {
      // No collapse: signal with a null operand and echo the dims unchanged.
      return {nullptr, llvm::to_vector(non_contracting_dims),
              llvm::to_vector(contracting_dims)};
    }
    // Build the collapsed shape: batch sizes first, then the product of the
    // non-contracting sizes placed before or after the contracting sizes
    // depending on which layout (case 1 or case 2) we matched.
    SmallVector<int64_t, 2> new_shape;
    auto batch_shape = shape.take_front(batch_dims.size());
    new_shape.append(batch_shape.begin(), batch_shape.end());
    SmallVector<int64_t, 2> contracting_sizes;
    for (int64_t contracting_dim : contracting_dims) {
      contracting_sizes.push_back(shape[contracting_dim]);
    }
    int64_t collapsed_dim_size = 1;
    for (int64_t non_contracting_dim : non_contracting_dims) {
      collapsed_dim_size *= shape[non_contracting_dim];
    }
    if (trailing_contracting_dims) {
      new_shape.push_back(collapsed_dim_size);
      new_shape.append(contracting_sizes.begin(), contracting_sizes.end());
    } else {
      new_shape.append(contracting_sizes.begin(), contracting_sizes.end());
      new_shape.push_back(collapsed_dim_size);
    }
    auto new_operand =
        cast<TypedValue<VectorType>>(builder.create<vector::ShapeCastOp>(
            VectorType::get(new_shape, vty.getElementType()), operand));
    SmallVector<int64_t, 2> new_non_contracting_dims, new_contracting_dims;
    if (trailing_non_contracting_dims) {
      // Case 2 - contracting dims are not changed and non contracting dims are
      // changed to the last dim.
      new_contracting_dims = llvm::to_vector(contracting_dims);
      new_non_contracting_dims.push_back(new_shape.size() - 1);
    } else {
      // Case 1 - non contracting dims are collapsed in the middle so all
      // contracting dims are moved forward by (non_contracting_dims.size() -
      // 1).
      new_non_contracting_dims.push_back(batch_dims.size());
      for (int64_t contracting_dim : contracting_dims) {
        new_contracting_dims.push_back(contracting_dim -
                                       (non_contracting_dims.size() - 1));
      }
    }
    return {new_operand, new_non_contracting_dims, new_contracting_dims};
  };

  auto [new_lhs, new_lhs_non_contracting_dims, new_lhs_contracting_dims] =
      maybe_collapse_non_contracting_dims(
          lhs, dimension_numbers.getLhsNonContractingDims(),
          dimension_numbers.getLhsContractingDims(),
          dimension_numbers.getLhsBatchDims());

  auto [new_rhs, new_rhs_non_contracting_dims, new_rhs_contracting_dims] =
      maybe_collapse_non_contracting_dims(
          rhs, dimension_numbers.getRhsNonContractingDims(),
          dimension_numbers.getRhsContractingDims(),
          dimension_numbers.getRhsBatchDims());

  // Nothing to collapse.
  if (!new_lhs && !new_rhs) {
    return std::nullopt;
  }

  // Overwrite the operands if they were collapsed. We're going to access the
  // new shapes below.
  lhs = new_lhs ? new_lhs : lhs;
  rhs = new_rhs ? new_rhs : rhs;

  // Rebuild the output dim order as (operand index, dim index) pairs — the
  // flat pair encoding matches how DotDimensionNumbersAttr stores it — and
  // accumulate the matching output (acc) shape: batch dims, then lhs
  // non-contracting dims, then rhs non-contracting dims.
  SmallVector<int64_t, 2> new_output_dim_order;
  SmallVector<int64_t, 2> new_acc_shape;
  for (int64_t batch_dim : dimension_numbers.getLhsBatchDims()) {
    new_output_dim_order.push_back(0);
    new_output_dim_order.push_back(batch_dim);
    new_acc_shape.push_back(lhs.getType().getDimSize(batch_dim));
  }
  for (int64_t non_contracting_dim : new_lhs_non_contracting_dims) {
    new_output_dim_order.push_back(0);
    new_output_dim_order.push_back(non_contracting_dim);
    new_acc_shape.push_back(lhs.getType().getDimSize(non_contracting_dim));
  }
  for (int64_t non_contracting_dim : new_rhs_non_contracting_dims) {
    new_output_dim_order.push_back(1);
    new_output_dim_order.push_back(non_contracting_dim);
    new_acc_shape.push_back(rhs.getType().getDimSize(non_contracting_dim));
  }

  // Batch dims are always at the front of the lhs and rhs, so the original
  // batch dims remain valid in the new dimension numbers.
  tpu::DotDimensionNumbersAttr new_dimension_numbers =
      tpu::DotDimensionNumbersAttr::get(
          builder.getContext(), new_lhs_contracting_dims,
          new_rhs_contracting_dims, new_lhs_non_contracting_dims,
          new_rhs_non_contracting_dims, new_output_dim_order,
          dimension_numbers.getLhsBatchDims(),
          dimension_numbers.getRhsBatchDims());

  // Reshape acc too, so it matches the collapsed output shape.
  auto new_acc =
      cast<TypedValue<VectorType>>(builder.create<vector::ShapeCastOp>(
          VectorType::get(new_acc_shape, acc.getType().getElementType()), acc));

  return std::make_tuple(lhs, rhs, new_acc, new_dimension_numbers);
}
268
+
113
269
FailureOr<Value> canonicalize_matmul (const CanonicalizeContext &ctx,
114
270
Operation &raw_op) {
115
271
auto op = cast<tpu::MatmulOp>(raw_op);
@@ -122,13 +278,13 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
122
278
auto rhs = op.getRhs ();
123
279
auto acc = op.getAcc ();
124
280
125
- const VectorType lhs_ty = lhs.getType ();
126
- const VectorType rhs_ty = rhs.getType ();
127
- const VectorType acc_ty = acc.getType ();
281
+ const VectorType old_lhs_ty = lhs.getType ();
282
+ const VectorType old_rhs_ty = rhs.getType ();
283
+ const VectorType old_acc_ty = acc.getType ();
128
284
129
- auto lhs_element_type = lhs_ty .getElementType ();
130
- auto rhs_element_type = rhs_ty .getElementType ();
131
- auto acc_element_type = acc_ty .getElementType ();
285
+ auto lhs_element_type = old_lhs_ty .getElementType ();
286
+ auto rhs_element_type = old_rhs_ty .getElementType ();
287
+ auto acc_element_type = old_acc_ty .getElementType ();
132
288
133
289
// there are a few primary paths for dimension_numbers in matmul
134
290
// 1) No dimension numbers provided -> set to default
@@ -146,6 +302,14 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
146
302
// Dot dim API - dimensions are provided and are not default
147
303
(op.getDimensionNumbers ().value () !=
148
304
defaultDimensionNumbers (builder, false , false ))) {
305
+ if (auto collapsed_operands_and_ddn = collapse_matmul_non_contracting_dims (
306
+ builder, lhs, rhs, acc, *op.getDimensionNumbers ())) {
307
+ tpu::DotDimensionNumbersAttr new_dimension_numbers;
308
+ std::tie (lhs, rhs, acc, new_dimension_numbers) =
309
+ *collapsed_operands_and_ddn;
310
+ op.setDimensionNumbersAttr (new_dimension_numbers);
311
+ }
312
+
149
313
auto dimension_numbers = op.getDimensionNumbers ();
150
314
auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims ();
151
315
auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims ();
@@ -156,11 +320,11 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
156
320
// Invariant in matmul verifier: <= 1 batch dim atm, and that lhs and rhs
157
321
// are the same
158
322
// Invariant in matmul verifier: Exactly one contracting and non contracting
159
- // dim in each of lhs and rhs for now .
160
- batch_size =
161
- lhs_batch_dims. empty ()
162
- ? std::nullopt
163
- : std::optional< int64_t >(lhs_ty .getShape ()[lhs_batch_dims[0 ]]);
323
+ // dim in each of lhs and rhs at the moment .
324
+ batch_size = lhs_batch_dims. empty ()
325
+ ? std::nullopt
326
+ : std::optional< int64_t >(
327
+ lhs. getType () .getShape ()[lhs_batch_dims[0 ]]);
164
328
// Lower each dim in contracting dims by size(batch_dims)
165
329
auto batch_adjusted_lhs_contracting_dim =
166
330
lhs_contracting_dims[0 ] - lhs_batch_dims.size ();
@@ -175,6 +339,20 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
175
339
}
176
340
}
177
341
342
+ // Make sure there is only one non-contracting dim in each of lhs and rhs
343
+ // after collapsing.
344
+ auto dimension_numbers = op.getDimensionNumbers ();
345
+ if (dimension_numbers->getLhsNonContractingDims ().size () != 1 ) {
346
+ return op->emitOpError (
347
+ " Not implemented: lhs non contracting dims must be an infix/suffix of "
348
+ " the shape." );
349
+ }
350
+ if (dimension_numbers->getRhsNonContractingDims ().size () != 1 ) {
351
+ return op->emitOpError (
352
+ " Not implemented: rhs non contracting dims must be an infix/suffix of "
353
+ " the shape." );
354
+ }
355
+
178
356
auto extsi_sitofp = [&builder, &op](TypedValue<VectorType> element) {
179
357
const VectorType ty = element.getType ();
180
358
auto shape = ty.getShape ();
@@ -247,7 +425,6 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
247
425
// operation that fuses the transpose into the matmul.
248
426
auto transpose_op =
249
427
dyn_cast_if_present<tpu::TransposeOp>(rhs.getDefiningOp ());
250
- auto dimension_numbers = op.getDimensionNumbers ();
251
428
if (transpose_op && transpose_op->hasOneUse () &&
252
429
dimension_numbers->getRhsContractingDims ().size () == 1 &&
253
430
dimension_numbers->getRhsNonContractingDims ().size () == 1 ) {
@@ -259,7 +436,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
259
436
permutation[rhs_non_contracting_dim] == rhs_contracting_dim &&
260
437
std::all_of (dimension_numbers->getRhsBatchDims ().begin (),
261
438
dimension_numbers->getRhsBatchDims ().end (),
262
- [&](long batch_dim) {
439
+ [&](int64_t batch_dim) {
263
440
return permutation[batch_dim] == batch_dim;
264
441
})) {
265
442
if (auto transpose_op_vector_operand =
@@ -312,6 +489,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
312
489
// If we have a batch_size, we want to slice rhs and lhs [:batch_size],
313
490
// and then do O[i] = A[i] @ B[i]
314
491
// Produce an output shape of [batch_size, m, n]
492
+ Value res;
315
493
if (batch_size.has_value ()) {
316
494
std::vector<Value> outputs;
317
495
@@ -336,22 +514,22 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
336
514
// Technically almost identical to the case where batch_size is 1, but
337
515
// we want to avoid the spurious concat here.
338
516
if (batch_size == 1 ) {
339
- op.replaceAllUsesWith (outputs[0 ]);
340
- op.erase ();
341
- return outputs[0 ];
517
+ res = outputs[0 ];
518
+ } else {
519
+ res = builder.create <tpu::ConcatenateOp>(acc.getType (), outputs,
520
+ /* dimension=*/ 0 );
342
521
}
343
- auto output =
344
- builder.create <tpu::ConcatenateOp>(acc_ty, outputs, /* dimension=*/ 0 );
345
- op.replaceAllUsesWith (output);
346
- op.erase ();
347
- return output;
348
522
} else {
349
- auto matmul_res = dot_dim_matmul (lhs, rhs, acc);
350
- op.replaceAllUsesWith (matmul_res);
351
- op.erase ();
352
- return matmul_res;
523
+ res = dot_dim_matmul (lhs, rhs, acc);
524
+ }
525
+
526
+ // Reshape the result to the old one as dims might have been collapsed.
527
+ if (res.getType () != old_acc_ty) {
528
+ res = builder.create <vector::ShapeCastOp>(old_acc_ty, res);
353
529
}
354
- return op.getResult ();
530
+ op.replaceAllUsesWith (res);
531
+ op.erase ();
532
+ return res;
355
533
};
356
534
357
535
FailureOr<Value> canonicalize_elementwise (const CanonicalizeContext &ctx,
0 commit comments