Skip to content

Commit bf5de15

Browse files
WindQAQ authored and Google-ML-Automation committed
[Mosaic] Support multiple non-contracting dims if they are collapsable.
With reshape, we can support the following two cases. 1. [batch_dims, non_contracting_dims, contracting_dims] -> [batch_dims, prod(non_contracting_dims), contracting_dims] or 2. [batch_dims, contracting_dims, non_contracting_dims] -> [batch_dims, contracting_dims, prod(non_contracting_dims)]. I'm reluctant to change apply vector layout and want to keep it to only handle 2D matrix. PiperOrigin-RevId: 780737523
1 parent 344b1bc commit bf5de15

File tree

4 files changed

+291
-37
lines changed

4 files changed

+291
-37
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2159,7 +2159,13 @@ def _dot_general_lowering_rule(
21592159
raise NotImplementedError(ctx.avals_out[0].dtype)
21602160
lhs_aval, rhs_aval = ctx.avals_in
21612161
# This is really a matrix-vector product. It only looks like matrix-matrix.
2162-
if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1:
2162+
if (
2163+
lhs_dims == (1,)
2164+
and rhs_dims == (1,)
2165+
and ctx.avals_in[1].shape[0] == 1
2166+
and len(ctx.avals_in[0].shape) == 2
2167+
and len(ctx.avals_in[1].shape) == 2
2168+
):
21632169
if ctx.avals_in[0].shape != ctx.avals_in[1].shape:
21642170
bcast_shape = jnp.broadcast_shapes(
21652171
ctx.avals_in[0].shape, ctx.avals_out[0].shape

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1018,16 +1018,6 @@ LogicalResult MatmulOp::verify() {
10181018
// position 0. Future extensions to this will be to:
10191019
// 1. Support multiple batch dims
10201020
// 2. Support batch dims in any position in the output dim order
1021-
if (lhs_non_contracting_dims.size() != 1) {
1022-
emitOpError(
1023-
"Not implemented: lhs non contracting dims must be of size 1");
1024-
return failure();
1025-
}
1026-
if (rhs_non_contracting_dims.size() != 1) {
1027-
emitOpError(
1028-
"Not implemented: rhs non contracting dims must be of size 1");
1029-
return failure();
1030-
}
10311021

10321022
// A bit long winded, but the invariants we enforce below are:
10331023
// 1. The output order idx is 0 (lhs) or 1 (rhs)

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 204 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -20,12 +20,14 @@ limitations under the License.
2020
#include <memory>
2121
#include <numeric>
2222
#include <optional>
23+
#include <tuple>
2324
#include <utility>
2425
#include <vector>
2526

2627
#include "absl/log/check.h"
2728
#include "llvm/ADT/ArrayRef.h"
2829
#include "llvm/ADT/STLExtras.h"
30+
#include "llvm/ADT/Sequence.h"
2931
#include "llvm/ADT/SmallVector.h"
3032
#include "llvm/ADT/StringMap.h"
3133
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -110,6 +112,160 @@ class CanonicalBuilder : public ImplicitLocOpBuilder {
110112
bool need_elementwise_canonicalization(const CanonicalizeContext &ctx,
111113
Operation &op);
112114

115+
// Returns the collapsed lhs, rhs, acc and the new dimension numbers if the
116+
// non-contracting dims can be collapsed, otherwise returns std::nullopt.
117+
std::optional<std::tuple<TypedValue<VectorType>, TypedValue<VectorType>,
118+
TypedValue<VectorType>, tpu::DotDimensionNumbersAttr>>
119+
collapse_matmul_non_contracting_dims(
120+
CanonicalBuilder &builder, TypedValue<VectorType> lhs,
121+
TypedValue<VectorType> rhs, TypedValue<VectorType> acc,
122+
const tpu::DotDimensionNumbersAttr &dimension_numbers) {
123+
// Collapse
124+
//
125+
// 1. [batch_dims, non_contracting_dims, contracting_dims] into
126+
// [batch_dims, prod(non_contracting_dims), contracting_dims] or
127+
// 2. [batch_dims, contracting_dims, non_contracting_dims] into
128+
// [batch_dims, contracting_dims, prod(non_contracting_dims)].
129+
//
130+
// Returns a tuple of [new_operand, new_non_contracting_dims,
131+
// new_contracting_dims]. new_operand is nullptr if the operand does not need
132+
// to be collapsed.
133+
// TODO(b/413194126): Some shapes will trigger unsupported
134+
// vector::ShapeCastOp.
135+
auto maybe_collapse_non_contracting_dims =
136+
[&](TypedValue<VectorType> operand,
137+
ArrayRef<int64_t> non_contracting_dims,
138+
ArrayRef<int64_t> contracting_dims, ArrayRef<int64_t> batch_dims)
139+
-> std::tuple<TypedValue<VectorType>, SmallVector<int64_t, 2>,
140+
SmallVector<int64_t, 2>> {
141+
VectorType vty = operand.getType();
142+
auto shape = vty.getShape();
143+
bool batch_dims_are_front =
144+
batch_dims == ArrayRef<int64_t>(llvm::to_vector(
145+
llvm::seq<int64_t>(0, batch_dims.size())));
146+
// Case 1.
147+
bool trailing_contracting_dims =
148+
contracting_dims ==
149+
ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
150+
shape.size() - contracting_dims.size(), shape.size()))) &&
151+
non_contracting_dims ==
152+
ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
153+
batch_dims.size(),
154+
batch_dims.size() + non_contracting_dims.size())));
155+
// Case 2.
156+
bool trailing_non_contracting_dims =
157+
non_contracting_dims ==
158+
ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
159+
shape.size() - non_contracting_dims.size(), shape.size()))) &&
160+
contracting_dims ==
161+
ArrayRef<int64_t>(llvm::to_vector(llvm::seq<int64_t>(
162+
batch_dims.size(),
163+
batch_dims.size() + contracting_dims.size())));
164+
bool should_collapse_non_contracting_dims =
165+
batch_dims_are_front &&
166+
(trailing_contracting_dims || trailing_non_contracting_dims) &&
167+
non_contracting_dims.size() > 1;
168+
if (!should_collapse_non_contracting_dims) {
169+
return {nullptr, llvm::to_vector(non_contracting_dims),
170+
llvm::to_vector(contracting_dims)};
171+
}
172+
SmallVector<int64_t, 2> new_shape;
173+
auto batch_shape = shape.take_front(batch_dims.size());
174+
new_shape.append(batch_shape.begin(), batch_shape.end());
175+
SmallVector<int64_t, 2> contracting_sizes;
176+
for (int64_t contracting_dim : contracting_dims) {
177+
contracting_sizes.push_back(shape[contracting_dim]);
178+
}
179+
int64_t collapsed_dim_size = 1;
180+
for (int64_t non_contracting_dim : non_contracting_dims) {
181+
collapsed_dim_size *= shape[non_contracting_dim];
182+
}
183+
if (trailing_contracting_dims) {
184+
new_shape.push_back(collapsed_dim_size);
185+
new_shape.append(contracting_sizes.begin(), contracting_sizes.end());
186+
} else {
187+
new_shape.append(contracting_sizes.begin(), contracting_sizes.end());
188+
new_shape.push_back(collapsed_dim_size);
189+
}
190+
auto new_operand =
191+
cast<TypedValue<VectorType>>(builder.create<vector::ShapeCastOp>(
192+
VectorType::get(new_shape, vty.getElementType()), operand));
193+
SmallVector<int64_t, 2> new_non_contracting_dims, new_contracting_dims;
194+
if (trailing_non_contracting_dims) {
195+
// Case 2 - contracting dims are not changed and non contracting dims are
196+
// changed to the last dim.
197+
new_contracting_dims = llvm::to_vector(contracting_dims);
198+
new_non_contracting_dims.push_back(new_shape.size() - 1);
199+
} else {
200+
// Case 1 - non contracting dims are collapsed in the middle so all
201+
// contracting dims are moved forward by (non_contracting_dims.size() -
202+
// 1).
203+
new_non_contracting_dims.push_back(batch_dims.size());
204+
for (int64_t contracting_dim : contracting_dims) {
205+
new_contracting_dims.push_back(contracting_dim -
206+
(non_contracting_dims.size() - 1));
207+
}
208+
}
209+
return {new_operand, new_non_contracting_dims, new_contracting_dims};
210+
};
211+
212+
auto [new_lhs, new_lhs_non_contracting_dims, new_lhs_contracting_dims] =
213+
maybe_collapse_non_contracting_dims(
214+
lhs, dimension_numbers.getLhsNonContractingDims(),
215+
dimension_numbers.getLhsContractingDims(),
216+
dimension_numbers.getLhsBatchDims());
217+
218+
auto [new_rhs, new_rhs_non_contracting_dims, new_rhs_contracting_dims] =
219+
maybe_collapse_non_contracting_dims(
220+
rhs, dimension_numbers.getRhsNonContractingDims(),
221+
dimension_numbers.getRhsContractingDims(),
222+
dimension_numbers.getRhsBatchDims());
223+
224+
// Nothing to collapse.
225+
if (!new_lhs && !new_rhs) {
226+
return std::nullopt;
227+
}
228+
229+
// Overwrite the operands if they were collapsed. We're going to access the
230+
// new shapes below.
231+
lhs = new_lhs ? new_lhs : lhs;
232+
rhs = new_rhs ? new_rhs : rhs;
233+
234+
SmallVector<int64_t, 2> new_output_dim_order;
235+
SmallVector<int64_t, 2> new_acc_shape;
236+
for (int64_t batch_dim : dimension_numbers.getLhsBatchDims()) {
237+
new_output_dim_order.push_back(0);
238+
new_output_dim_order.push_back(batch_dim);
239+
new_acc_shape.push_back(lhs.getType().getDimSize(batch_dim));
240+
}
241+
for (int64_t non_contracting_dim : new_lhs_non_contracting_dims) {
242+
new_output_dim_order.push_back(0);
243+
new_output_dim_order.push_back(non_contracting_dim);
244+
new_acc_shape.push_back(lhs.getType().getDimSize(non_contracting_dim));
245+
}
246+
for (int64_t non_contracting_dim : new_rhs_non_contracting_dims) {
247+
new_output_dim_order.push_back(1);
248+
new_output_dim_order.push_back(non_contracting_dim);
249+
new_acc_shape.push_back(rhs.getType().getDimSize(non_contracting_dim));
250+
}
251+
252+
// Batch dims are always at the front of the lhs and rhs.
253+
tpu::DotDimensionNumbersAttr new_dimension_numbers =
254+
tpu::DotDimensionNumbersAttr::get(
255+
builder.getContext(), new_lhs_contracting_dims,
256+
new_rhs_contracting_dims, new_lhs_non_contracting_dims,
257+
new_rhs_non_contracting_dims, new_output_dim_order,
258+
dimension_numbers.getLhsBatchDims(),
259+
dimension_numbers.getRhsBatchDims());
260+
261+
// Reshape acc too.
262+
auto new_acc =
263+
cast<TypedValue<VectorType>>(builder.create<vector::ShapeCastOp>(
264+
VectorType::get(new_acc_shape, acc.getType().getElementType()), acc));
265+
266+
return std::make_tuple(lhs, rhs, new_acc, new_dimension_numbers);
267+
}
268+
113269
FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
114270
Operation &raw_op) {
115271
auto op = cast<tpu::MatmulOp>(raw_op);
@@ -122,13 +278,13 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
122278
auto rhs = op.getRhs();
123279
auto acc = op.getAcc();
124280

125-
const VectorType lhs_ty = lhs.getType();
126-
const VectorType rhs_ty = rhs.getType();
127-
const VectorType acc_ty = acc.getType();
281+
const VectorType old_lhs_ty = lhs.getType();
282+
const VectorType old_rhs_ty = rhs.getType();
283+
const VectorType old_acc_ty = acc.getType();
128284

129-
auto lhs_element_type = lhs_ty.getElementType();
130-
auto rhs_element_type = rhs_ty.getElementType();
131-
auto acc_element_type = acc_ty.getElementType();
285+
auto lhs_element_type = old_lhs_ty.getElementType();
286+
auto rhs_element_type = old_rhs_ty.getElementType();
287+
auto acc_element_type = old_acc_ty.getElementType();
132288

133289
// there are a few primary paths for dimension_numbers in matmul
134290
// 1) No dimension numbers provided -> set to default
@@ -146,6 +302,14 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
146302
// Dot dim API - dimensions are provided and are not default
147303
(op.getDimensionNumbers().value() !=
148304
defaultDimensionNumbers(builder, false, false))) {
305+
if (auto collapsed_operands_and_ddn = collapse_matmul_non_contracting_dims(
306+
builder, lhs, rhs, acc, *op.getDimensionNumbers())) {
307+
tpu::DotDimensionNumbersAttr new_dimension_numbers;
308+
std::tie(lhs, rhs, acc, new_dimension_numbers) =
309+
*collapsed_operands_and_ddn;
310+
op.setDimensionNumbersAttr(new_dimension_numbers);
311+
}
312+
149313
auto dimension_numbers = op.getDimensionNumbers();
150314
auto lhs_contracting_dims = dimension_numbers->getLhsContractingDims();
151315
auto rhs_contracting_dims = dimension_numbers->getRhsContractingDims();
@@ -156,11 +320,11 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
156320
// Invariant in matmul verifier: <= 1 batch dim atm, and that lhs and rhs
157321
// are the same
158322
// Invariant in matmul verifier: Exactly one contracting and non contracting
159-
// dim in each of lhs and rhs for now.
160-
batch_size =
161-
lhs_batch_dims.empty()
162-
? std::nullopt
163-
: std::optional<int64_t>(lhs_ty.getShape()[lhs_batch_dims[0]]);
323+
// dim in each of lhs and rhs at the moment.
324+
batch_size = lhs_batch_dims.empty()
325+
? std::nullopt
326+
: std::optional<int64_t>(
327+
lhs.getType().getShape()[lhs_batch_dims[0]]);
164328
// Lower each dim in contracting dims by size(batch_dims)
165329
auto batch_adjusted_lhs_contracting_dim =
166330
lhs_contracting_dims[0] - lhs_batch_dims.size();
@@ -175,6 +339,20 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
175339
}
176340
}
177341

342+
// Make sure there is only one non-contracting dim in each of lhs and rhs
343+
// after collapsing.
344+
auto dimension_numbers = op.getDimensionNumbers();
345+
if (dimension_numbers->getLhsNonContractingDims().size() != 1) {
346+
return op->emitOpError(
347+
"Not implemented: lhs non contracting dims must be an infix/suffix of "
348+
"the shape.");
349+
}
350+
if (dimension_numbers->getRhsNonContractingDims().size() != 1) {
351+
return op->emitOpError(
352+
"Not implemented: rhs non contracting dims must be an infix/suffix of "
353+
"the shape.");
354+
}
355+
178356
auto extsi_sitofp = [&builder, &op](TypedValue<VectorType> element) {
179357
const VectorType ty = element.getType();
180358
auto shape = ty.getShape();
@@ -247,7 +425,6 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
247425
// operation that fuses the transpose into the matmul.
248426
auto transpose_op =
249427
dyn_cast_if_present<tpu::TransposeOp>(rhs.getDefiningOp());
250-
auto dimension_numbers = op.getDimensionNumbers();
251428
if (transpose_op && transpose_op->hasOneUse() &&
252429
dimension_numbers->getRhsContractingDims().size() == 1 &&
253430
dimension_numbers->getRhsNonContractingDims().size() == 1) {
@@ -259,7 +436,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
259436
permutation[rhs_non_contracting_dim] == rhs_contracting_dim &&
260437
std::all_of(dimension_numbers->getRhsBatchDims().begin(),
261438
dimension_numbers->getRhsBatchDims().end(),
262-
[&](long batch_dim) {
439+
[&](int64_t batch_dim) {
263440
return permutation[batch_dim] == batch_dim;
264441
})) {
265442
if (auto transpose_op_vector_operand =
@@ -312,6 +489,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
312489
// If we have a batch_size, we want to slice rhs and lhs [:batch_size],
313490
// and then do O[i] = A[i] @ B[i]
314491
// Produce an output shape of [batch_size, m, n]
492+
Value res;
315493
if (batch_size.has_value()) {
316494
std::vector<Value> outputs;
317495

@@ -336,22 +514,22 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
336514
// Technically almost identical to the case where batch_size is 1, but
337515
// we want to avoid the spurious concat here.
338516
if (batch_size == 1) {
339-
op.replaceAllUsesWith(outputs[0]);
340-
op.erase();
341-
return outputs[0];
517+
res = outputs[0];
518+
} else {
519+
res = builder.create<tpu::ConcatenateOp>(acc.getType(), outputs,
520+
/*dimension=*/0);
342521
}
343-
auto output =
344-
builder.create<tpu::ConcatenateOp>(acc_ty, outputs, /*dimension=*/0);
345-
op.replaceAllUsesWith(output);
346-
op.erase();
347-
return output;
348522
} else {
349-
auto matmul_res = dot_dim_matmul(lhs, rhs, acc);
350-
op.replaceAllUsesWith(matmul_res);
351-
op.erase();
352-
return matmul_res;
523+
res = dot_dim_matmul(lhs, rhs, acc);
524+
}
525+
526+
// Reshape the result to the old one as dims might have been collapsed.
527+
if (res.getType() != old_acc_ty) {
528+
res = builder.create<vector::ShapeCastOp>(old_acc_ty, res);
353529
}
354-
return op.getResult();
530+
op.replaceAllUsesWith(res);
531+
op.erase();
532+
return res;
355533
};
356534

357535
FailureOr<Value> canonicalize_elementwise(const CanonicalizeContext &ctx,

0 commit comments

Comments (0)