Skip to content

Commit eeae4ce

Browse files
WindQAQ authored and Google-ML-Automation committed
[Mosaic] Support multiple non-contracting dims if they are collapsible.
With reshape, we can support the following two cases: 1. [batch_dims, non_contracting_dims, contracting_dims] -> [batch_dims, prod(non_contracting_dims), contracting_dims], or 2. [batch_dims, contracting_dims, non_contracting_dims] -> [batch_dims, contracting_dims, prod(non_contracting_dims)]. I'm reluctant to change apply-vector-layout and want to keep it handling only 2D matrices. PiperOrigin-RevId: 780737523
1 parent 10b9fe0 commit eeae4ce

File tree

4 files changed

+292
-37
lines changed

4 files changed

+292
-37
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2131,7 +2131,12 @@ def _dot_general_lowering_rule(
21312131
raise NotImplementedError(ctx.avals_out[0].dtype)
21322132
lhs_aval, rhs_aval = ctx.avals_in
21332133
# This is really a matrix-vector product. It only looks like matrix-matrix.
2134-
if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1:
2134+
if (
2135+
lhs_dims == (1,)
2136+
and rhs_dims == (1,)
2137+
and ctx.avals_in[1].shape[0] == 1
2138+
and len(ctx.avals_in[0].shape) == 2
2139+
):
21352140
if ctx.avals_in[0].shape != ctx.avals_in[1].shape:
21362141
bcast_shape = jnp.broadcast_shapes(
21372142
ctx.avals_in[0].shape, ctx.avals_out[0].shape

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,16 +1018,6 @@ LogicalResult MatmulOp::verify() {
10181018
// position 0. Future extensions to this will be to:
10191019
// 1. Support multiple batch dims
10201020
// 2. Support batch dims in any position in the output dim order
1021-
if (lhs_non_contracting_dims.size() != 1) {
1022-
emitOpError(
1023-
"Not implemented: lhs non contracting dims must be of size 1");
1024-
return failure();
1025-
}
1026-
if (rhs_non_contracting_dims.size() != 1) {
1027-
emitOpError(
1028-
"Not implemented: rhs non contracting dims must be of size 1");
1029-
return failure();
1030-
}
10311021

10321022
// A bit long winded, but the invariants we enforce below are:
10331023
// 1. The output order idx is 0 (lhs) or 1 (rhs)

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 194 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ limitations under the License.
2020
#include <memory>
2121
#include <numeric>
2222
#include <optional>
23+
#include <tuple>
2324
#include <utility>
2425
#include <vector>
2526

2627
#include "absl/log/check.h"
2728
#include "llvm/ADT/ArrayRef.h"
2829
#include "llvm/ADT/STLExtras.h"
30+
#include "llvm/ADT/Sequence.h"
2931
#include "llvm/ADT/SmallVector.h"
3032
#include "llvm/ADT/StringMap.h"
3133
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -110,6 +112,152 @@ class CanonicalBuilder : public ImplicitLocOpBuilder {
110112
bool need_elementwise_canonicalization(const CanonicalizeContext &ctx,
111113
Operation &op);
112114

115+
// Returns the collapsed lhs, rhs, acc and the new dimension numbers if the
116+
// non-contracting dims can be collapsed, otherwise returns std::nullopt.
117+
std::optional<std::tuple<TypedValue<VectorType>, TypedValue<VectorType>,
118+
TypedValue<VectorType>, tpu::DotDimensionNumbersAttr>>
119+
collapse_matmul_non_contracting_dims(
120+
CanonicalBuilder &builder, TypedValue<VectorType> lhs,
121+
TypedValue<VectorType> rhs, TypedValue<VectorType> acc,
122+
const tpu::DotDimensionNumbersAttr &dimension_numbers) {
123+
// Collapse
124+
//
125+
// 1. [batch_dims, non_contracting_dims, contracting_dims] into
126+
// [batch_dims, prod(non_contracting_dims), contracting_dims] or
127+
// 2. [batch_dims, contracting_dims, non_contracting_dims] into
128+
// [batch_dims, contracting_dims, prod(non_contracting_dims)].
129+
//
130+
// Returns a tuple of [new_operand, new_non_contracting_dims,
131+
// new_contracting_dims]. new_operand is nullptr if the operand does not need
132+
// to be collapsed.
133+
// TODO(b/413194126): Some shapes will trigger unsupported
134+
// vector::ShapeCastOp.
135+
auto maybe_collapse_non_contracting_dims =
136+
[&](TypedValue<VectorType> operand,
137+
ArrayRef<int64_t> non_contracting_dims,
138+
ArrayRef<int64_t> contracting_dims, ArrayRef<int64_t> batch_dims)
139+
-> std::tuple<TypedValue<VectorType>, SmallVector<int64_t, 2>,
140+
SmallVector<int64_t, 2>> {
141+
VectorType vty = operand.getType();
142+
auto shape = vty.getShape();
143+
bool batch_dims_are_front =
144+
batch_dims == ArrayRef<int64_t>(llvm::to_vector(
145+
llvm::seq<int64_t>(0, batch_dims.size())));
146+
// Case 1.
147+
bool trailing_contracting_dims =
148+
contracting_dims ==
149+
ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
150+
shape.size() - contracting_dims.size(), shape.size())));
151+
// Case 2.
152+
bool trailing_non_contracting_dims =
153+
non_contracting_dims ==
154+
ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
155+
shape.size() - non_contracting_dims.size(), shape.size())));
156+
bool should_collapse_non_contracting_dims =
157+
batch_dims_are_front &&
158+
(trailing_contracting_dims || trailing_non_contracting_dims) &&
159+
non_contracting_dims.size() > 1;
160+
if (!should_collapse_non_contracting_dims) {
161+
return {nullptr, llvm::to_vector(non_contracting_dims),
162+
llvm::to_vector(contracting_dims)};
163+
}
164+
SmallVector<int64_t, 2> new_shape;
165+
auto batch_sizes = shape.take_front(batch_dims.size());
166+
new_shape.append(batch_sizes.begin(), batch_sizes.end());
167+
SmallVector<int64_t, 2> contracting_sizes;
168+
for (int64_t contracting_dim : contracting_dims) {
169+
contracting_sizes.push_back(shape[contracting_dim]);
170+
}
171+
int64_t collapsed_dim_size = std::accumulate(
172+
non_contracting_dims.begin(), non_contracting_dims.end(), 1,
173+
[&](int64_t a, int64_t b) { return a * shape[b]; });
174+
;
175+
if (trailing_contracting_dims) {
176+
new_shape.push_back(collapsed_dim_size);
177+
new_shape.append(contracting_sizes.begin(), contracting_sizes.end());
178+
} else {
179+
new_shape.append(contracting_sizes.begin(), contracting_sizes.end());
180+
new_shape.push_back(collapsed_dim_size);
181+
}
182+
auto new_operand =
183+
cast<TypedValue<VectorType>>(builder.create<vector::ShapeCastOp>(
184+
VectorType::get(new_shape, vty.getElementType()), operand));
185+
SmallVector<int64_t, 2> new_non_contracting_dims, new_contracting_dims;
186+
if (trailing_non_contracting_dims) {
187+
// Case 2 - contracting dims are not changed and non contracting dims are
188+
// changed to the last dim.
189+
new_contracting_dims = llvm::to_vector(contracting_dims);
190+
new_non_contracting_dims.push_back(new_shape.size() - 1);
191+
} else {
192+
// Case 1 - non contracting dims are collapsed in the middle so all
193+
// contracting dims are moved forward by (non_contracting_dims.size() -
194+
// 1).
195+
new_non_contracting_dims.push_back(batch_dims.size());
196+
for (int64_t contracting_dim : contracting_dims) {
197+
new_contracting_dims.push_back(contracting_dim -
198+
(non_contracting_dims.size() - 1));
199+
}
200+
}
201+
return {new_operand, new_non_contracting_dims, new_contracting_dims};
202+
};
203+
204+
auto [new_lhs, new_lhs_non_contracting_dims, new_lhs_contracting_dims] =
205+
maybe_collapse_non_contracting_dims(
206+
lhs, dimension_numbers.getLhsNonContractingDims(),
207+
dimension_numbers.getLhsContractingDims(),
208+
dimension_numbers.getLhsBatchDims());
209+
210+
auto [new_rhs, new_rhs_non_contracting_dims, new_rhs_contracting_dims] =
211+
maybe_collapse_non_contracting_dims(
212+
rhs, dimension_numbers.getRhsNonContractingDims(),
213+
dimension_numbers.getRhsContractingDims(),
214+
dimension_numbers.getRhsBatchDims());
215+
216+
// Nothing to collapse.
217+
if (!new_lhs && !new_rhs) {
218+
return std::nullopt;
219+
}
220+
221+
// Overwrite the operands if they were collapsed. We're going to access the
222+
// new shapes below.
223+
lhs = new_lhs ? new_lhs : lhs;
224+
rhs = new_rhs ? new_rhs : rhs;
225+
226+
SmallVector<int64_t, 2> new_output_dim_order;
227+
SmallVector<int64_t, 2> new_acc_shape;
228+
for (int64_t batch_dim : dimension_numbers.getLhsBatchDims()) {
229+
new_output_dim_order.push_back(0);
230+
new_output_dim_order.push_back(batch_dim);
231+
new_acc_shape.push_back(lhs.getType().getDimSize(batch_dim));
232+
}
233+
for (int64_t non_contracting_dim : new_lhs_non_contracting_dims) {
234+
new_output_dim_order.push_back(0);
235+
new_output_dim_order.push_back(non_contracting_dim);
236+
new_acc_shape.push_back(lhs.getType().getDimSize(non_contracting_dim));
237+
}
238+
for (int64_t non_contracting_dim : new_rhs_non_contracting_dims) {
239+
new_output_dim_order.push_back(1);
240+
new_output_dim_order.push_back(non_contracting_dim);
241+
new_acc_shape.push_back(rhs.getType().getDimSize(non_contracting_dim));
242+
}
243+
244+
// Batch dims are always at the front of the lhs and rhs.
245+
tpu::DotDimensionNumbersAttr new_dimension_numbers =
246+
tpu::DotDimensionNumbersAttr::get(
247+
builder.getContext(), new_lhs_contracting_dims,
248+
new_rhs_contracting_dims, new_lhs_non_contracting_dims,
249+
new_rhs_non_contracting_dims, new_output_dim_order,
250+
dimension_numbers.getLhsBatchDims(),
251+
dimension_numbers.getRhsBatchDims());
252+
253+
// Reshape acc too.
254+
auto new_acc =
255+
cast<TypedValue<VectorType>>(builder.create<vector::ShapeCastOp>(
256+
VectorType::get(new_acc_shape, acc.getType().getElementType()), acc));
257+
258+
return std::make_tuple(lhs, rhs, new_acc, new_dimension_numbers);
259+
}
260+
113261
FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
114262
Operation &raw_op) {
115263
auto op = cast<tpu::MatmulOp>(raw_op);
@@ -122,13 +270,13 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
122270
auto rhs = op.getRhs();
123271
auto acc = op.getAcc();
124272

125-
const VectorType lhs_ty = lhs.getType();
126-
const VectorType rhs_ty = rhs.getType();
127-
const VectorType acc_ty = acc.getType();
273+
const VectorType old_lhs_ty = lhs.getType();
274+
const VectorType old_rhs_ty = rhs.getType();
275+
const VectorType old_acc_ty = acc.getType();
128276

129-
auto lhs_element_type = lhs_ty.getElementType();
130-
auto rhs_element_type = rhs_ty.getElementType();
131-
auto acc_element_type = acc_ty.getElementType();
277+
auto lhs_element_type = old_lhs_ty.getElementType();
278+
auto rhs_element_type = old_rhs_ty.getElementType();
279+
auto acc_element_type = old_acc_ty.getElementType();
132280

133281
// there are a few primary paths for dimension_numbers in matmul
134282
// 1) No dimension numbers provided -> set to default
@@ -146,6 +294,14 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
146294
// Dot dim API - dimensions are provided and are not default
147295
(op.getDimensionNumbers().value() !=
148296
defaultDimensionNumbers(builder, false, false))) {
297+
if (auto collapsed_operands_and_ddn = collapse_matmul_non_contracting_dims(
298+
builder, lhs, rhs, acc, *op.getDimensionNumbers())) {
299+
tpu::DotDimensionNumbersAttr new_dimension_numbers;
300+
std::tie(lhs, rhs, acc, new_dimension_numbers) =
301+
*collapsed_operands_and_ddn;
302+
op.setDimensionNumbersAttr(new_dimension_numbers);
303+
}
304+
149305
auto dimension_numbers = op.getDimensionNumbers();
150306
auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims();
151307
auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims();
@@ -156,11 +312,11 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
156312
// Invariant in matmul verifier: <= 1 batch dim atm, and that lhs and rhs
157313
// are the same
158314
// Invariant in matmul verifier: Exactly one contracting and non contracting
159-
// dim in each of lhs and rhs for now.
160-
batch_size =
161-
lhs_batch_dims.empty()
162-
? std::nullopt
163-
: std::optional<int64_t>(lhs_ty.getShape()[lhs_batch_dims[0]]);
315+
// dim in each of lhs and rhs at the moment.
316+
batch_size = lhs_batch_dims.empty()
317+
? std::nullopt
318+
: std::optional<int64_t>(
319+
lhs.getType().getShape()[lhs_batch_dims[0]]);
164320
// Lower each dim in contracting dims by size(batch_dims)
165321
auto batch_adjusted_lhs_contracting_dim =
166322
lhs_contracting_dims[0] - lhs_batch_dims.size();
@@ -175,6 +331,18 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
175331
}
176332
}
177333

334+
// Make sure there is only one non-contracting dim in each of lhs and rhs
335+
// after collapsing.
336+
auto dimension_numbers = op.getDimensionNumbers();
337+
if (dimension_numbers->getLhsNonContractingDims().size() != 1) {
338+
return op->emitOpError(
339+
"Not implemented: lhs non contracting dims must be of size 1");
340+
}
341+
if (dimension_numbers->getRhsNonContractingDims().size() != 1) {
342+
return op->emitOpError(
343+
"Not implemented: rhs non contracting dims must be of size 1");
344+
}
345+
178346
auto extsi_sitofp = [&builder, &op](TypedValue<VectorType> element) {
179347
const VectorType ty = element.getType();
180348
auto shape = ty.getShape();
@@ -247,7 +415,6 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
247415
// operation that fuses the transpose into the matmul.
248416
auto transpose_op =
249417
dyn_cast_if_present<tpu::TransposeOp>(rhs.getDefiningOp());
250-
auto dimension_numbers = op.getDimensionNumbers();
251418
if (transpose_op && transpose_op->hasOneUse() &&
252419
dimension_numbers->getRhsContractingDims().size() == 1 &&
253420
dimension_numbers->getRhsNonContractingDims().size() == 1) {
@@ -259,7 +426,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
259426
permutation[rhs_non_contracting_dim] == rhs_contracting_dim &&
260427
std::all_of(dimension_numbers->getRhsBatchDims().begin(),
261428
dimension_numbers->getRhsBatchDims().end(),
262-
[&](long batch_dim) {
429+
[&](int64_t batch_dim) {
263430
return permutation[batch_dim] == batch_dim;
264431
})) {
265432
if (auto transpose_op_vector_operand =
@@ -312,6 +479,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
312479
// If we have a batch_size, we want to slice rhs and lhs [:batch_size],
313480
// and then do O[i] = A[i] @ B[i]
314481
// Produce an output shape of [batch_size, m, n]
482+
Value res;
315483
if (batch_size.has_value()) {
316484
std::vector<Value> outputs;
317485

@@ -336,22 +504,22 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
336504
// Technically almost identical to the case where batch_size is 1, but
337505
// we want to avoid the spurious concat here.
338506
if (batch_size == 1) {
339-
op.replaceAllUsesWith(outputs[0]);
340-
op.erase();
341-
return outputs[0];
507+
res = outputs[0];
508+
} else {
509+
res = builder.create<tpu::ConcatenateOp>(acc.getType(), outputs,
510+
/*dimension=*/0);
342511
}
343-
auto output =
344-
builder.create<tpu::ConcatenateOp>(acc_ty, outputs, /*dimension=*/0);
345-
op.replaceAllUsesWith(output);
346-
op.erase();
347-
return output;
348512
} else {
349-
auto matmul_res = dot_dim_matmul(lhs, rhs, acc);
350-
op.replaceAllUsesWith(matmul_res);
351-
op.erase();
352-
return matmul_res;
513+
res = dot_dim_matmul(lhs, rhs, acc);
353514
}
354-
return op.getResult();
515+
516+
// Reshape the result to the old one as dims might have been collapsed.
517+
if (res.getType() != old_acc_ty) {
518+
res = builder.create<vector::ShapeCastOp>(old_acc_ty, res);
519+
}
520+
op.replaceAllUsesWith(res);
521+
op.erase();
522+
return res;
355523
};
356524

357525
FailureOr<Value> canonicalize_elementwise(const CanonicalizeContext &ctx,

0 commit comments

Comments
 (0)