Commit 1671617

bythew3i authored and jax authors committed
[XLA:Mosaic] Support trunc/ext op for 1D vector with any implicit dim.
PiperOrigin-RevId: 626466602
1 parent 6e23c14 commit 1671617

File tree

3 files changed: +123 -145 lines changed


jaxlib/mosaic/dialect/tpu/layout.h

Lines changed: 0 additions & 1 deletion
@@ -270,7 +270,6 @@ class VectorLayout {
 
   SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;
 
- private:
   SmallVector<int64_t> tileArrayImplicitShape(
       ArrayRef<int64_t> shape, std::array<int64_t, 2> target_shape) const;
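
Note: the header change drops the private: specifier ahead of tileArrayImplicitShape, so the rewrite rules in apply_vector_layout.cc below can call it directly. As a rough standalone sketch of the idea (hypothetical helper names, offsets and padding ignored; this is not the Mosaic API), a rank-1 vector with an implicit dim is first given a 2-D implicit shape, and that implicit shape is what gets divided into vregs:

// Standalone sketch (illustration only; simplified helpers, offsets and
// padding ignored). A rank-1 vector with an implicit dim gets a 2-D
// "implicit shape", which is then divided by the vreg tiling to get the
// shape of the vreg array, mirroring what implicitShape and
// tileArrayImplicitShape compute.
#include <array>
#include <cstdint>
#include <iostream>

enum class ImplicitDim { kMinor, kSecondMinor };

// For rank 1: kMinor appends a unit lane dim, kSecondMinor prepends a unit
// sublane dim.
std::array<int64_t, 2> ImplicitShape1D(int64_t n, ImplicitDim dim) {
  return dim == ImplicitDim::kMinor ? std::array<int64_t, 2>{n, 1}
                                    : std::array<int64_t, 2>{1, n};
}

// How many (sublane x lane) vregs cover that implicit shape.
std::array<int64_t, 2> TileArrayShape2D(std::array<int64_t, 2> shape,
                                        std::array<int64_t, 2> tiling) {
  return {(shape[0] + tiling[0] - 1) / tiling[0],
          (shape[1] + tiling[1] - 1) / tiling[1]};
}

int main() {
  // A vector<1024xf32> laid out with an implicit second-minor dim and the
  // default (8, 128) tiling occupies a 1x8 array of vregs.
  const auto implicit = ImplicitShape1D(1024, ImplicitDim::kSecondMinor);
  const auto vregs = TileArrayShape2D(implicit, {8, 128});
  std::cout << implicit[0] << "x" << implicit[1] << " elements -> "
            << vregs[0] << "x" << vregs[1] << " vregs\n";
  return 0;
}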

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 96 additions & 103 deletions
@@ -658,15 +658,23 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
   const auto result_ty = cast<VectorType>(op.getResult().getType());
   auto source = cast<TypedValue<VectorType>>(op.getIn());
   const auto source_ty = source.getType();
+  auto output_vregs_shape =
+      layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape);
   if (layout_out.bitwidth() != 32) {
     return op.emitOpError(
         "Not implemented: Only extensions to 32-bit supported");
   }
   FAILUREOR_ASSIGN_OR_RETURN(
-      const xla::Array<Value> input_vregs,
+      xla::Array<Value> input_vregs,
       disassemble(builder, layout_in, source, ctx.target_shape));
-  xla::Array<Value> output_vregs(
-      layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
+  xla::Array<Value> output_vregs(output_vregs_shape);
+  // TODO(jevinjiang): maybe just use tileArrayImplicitShape in disassemble?
+  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+    input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(),
+                                                         ctx.target_shape));
+    output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(),
+                                                           ctx.target_shape));
+  }
   FAILUREOR_ASSIGN_OR_RETURN(
       const VectorType res_vreg_ty,
       getNativeVregType(result_ty.getElementType(), ctx.target_shape));
@@ -676,51 +684,24 @@
   if (layout_in.offsets() != layout_out.offsets()) {
     return op.emitOpError("Not implemented: Change of offsets during the cast");
   }
-  switch (layout_in.implicit_dim()) {
-    case VectorLayout::ImplicitDim::kNone: {
-      if (layout_in.tiling() != layout_out.tiling()) {
-        return op.emitOpError(
-            "Not implemented: Changing tiling during the cast");
-      }
-      auto tiling = layout_in.tiling();
-      if (ctx.target_shape[0] % tiling[0] != 0 ||
-          ctx.target_shape[1] != tiling[1]) {
-        return op.emitOpError("Not implemented: tiling not supported");
-      }
-      const int packing = layout_in.packing();
-      output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
-        SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
-        input_vreg_idxs.back() /= packing;
-        const int64_t vreg_part = idxs.back() % packing;
-        *v = builder.create<UnpackSubelementsOp>(
-            res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
-      });
-    } break;
-    case VectorLayout::ImplicitDim::kMinor:
-      return op.emitOpError(
-          "Not implemented: Only casts of lane-oriented values supported");
-    case VectorLayout::ImplicitDim::kSecondMinor: {
-      auto is_one_tile = [](VectorType vty, VectorLayout layout) {
-        auto implicit_shape = layout.implicitShape(vty.getShape());
-        auto tiled_shape = ArrayRef<int64_t>(implicit_shape).take_back(2);
-        return (layout.offsets()[0].value_or(0) + tiled_shape[0] <=
-                layout.tiling()[0]) &&
-               (layout.offsets()[1].value_or(0) + tiled_shape[1] <=
-                layout.tiling()[1]);
-      };
-      if (input_vregs.dimensions() != absl::Span<const int64_t>{1} ||
-          output_vregs.dimensions() != absl::Span<const int64_t>{1} ||
-          !is_one_tile(source_ty, layout_in) ||
-          !is_one_tile(result_ty, layout_out)) {
-        return op.emitOpError("Not implemented");
-      }
-      if (layout_in.offsets()[0] >= ctx.target_shape[0]) {
-        return op.emitOpError("Not implemented");
-      }
-      auto unpack_subelements_op = builder.create<UnpackSubelementsOp>(
-          res_vreg_ty, *input_vregs.begin(), 0);
-      output_vregs.Fill(unpack_subelements_op.getResult());
-    }
+  if (layout_in.tiling() != layout_out.tiling()) {
+    return op.emitOpError("Not implemented: Changing tiling during the cast");
+  }
+  auto tiling = layout_in.tiling();
+  if (ctx.target_shape[0] % tiling[0] != 0 ||
+      ctx.target_shape[1] != tiling[1]) {
+    return op.emitOpError("Not implemented: tiling not supported");
+  }
+  const int packing = layout_in.packing();
+  output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
+    SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
+    input_vreg_idxs.back() /= packing;
+    const int64_t vreg_part = idxs.back() % packing;
+    *v = builder.create<UnpackSubelementsOp>(
+        res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
+  });
+  if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+    output_vregs.Reshape(output_vregs_shape);
   }
   op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
                                  std::move(output_vregs), ctx.target_shape)
@@ -762,73 +743,85 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
                                  const VectorLayout &layout_in,
                                  const VectorLayout &layout_out) {
   ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
+  auto source = cast<TypedValue<VectorType>>(op.getIn());
+  const auto source_ty = source.getType();
   auto result_ty = cast<VectorType>(op.getResult().getType());
+  auto output_vregs_shape =
+      layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape);
   FAILUREOR_ASSIGN_OR_RETURN(
-      const xla::Array<Value> input_vregs,
-      disassemble(builder, layout_in, cast<TypedValue<VectorType>>(op.getIn()),
-                  ctx.target_shape));
-  xla::Array<Value> output_vregs(
-      layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
+      xla::Array<Value> input_vregs,
+      disassemble(builder, layout_in, source, ctx.target_shape));
+  xla::Array<Value> output_vregs(output_vregs_shape);
   if (layout_in.bitwidth() != 32) {
     return op.emitOpError("Not implemented: Only 32-bit truncation supported");
   }
+  if (layout_in.offsets() != layout_out.offsets()) {
+    return op.emitOpError(
+        "Not implemented: Change of offsets during the truncation");
+  }
+  if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
+    return op.emitOpError("Not implemented: Change of layout during the cast");
+  }
+  if (layout_in.tiling() != ctx.target_shape) {
+    return op.emitOpError("Not implemented: Only (8,128) tiling supported");
+  }
+  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+    input_vregs.Reshape(layout_in.tileArrayImplicitShape(source_ty.getShape(),
+                                                         ctx.target_shape));
+    output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(),
+                                                           ctx.target_shape));
+  }
   FAILUREOR_ASSIGN_OR_RETURN(
       VectorType res_vreg_ty,
       getNativeVregType(result_ty.getElementType(), ctx.target_shape));
-  if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
-      layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone) {
-    if (layout_in.tiling() != ctx.target_shape) {
-      return op.emitOpError("Not implemented: Only (8,128) tiling supported");
-    }
-    if (layout_out.tiling() == ctx.target_shape) {
-      const int packing = layout_out.packing();
-      output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
-        SmallVector<Value> parts;
-        SmallVector<int64_t> idxs_local(toArrayRef(idxs));
-        idxs_local.back() *= packing;
-        for (int64_t i = 0; i < packing; ++i) {
-          parts.push_back(input_vregs(idxs_local));
-          // Pack any data lying around if OOB
-          if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
-            ++idxs_local.back();
-          }
-        }
-        *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
-      });
-
-    } else if (layout_out.hasNativeTiling(ctx.target_shape)) {
-      int packing = layout_out.packing();
+  if (layout_out.tiling() == ctx.target_shape) {
+    const int packing = layout_out.packing();
+    output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
      SmallVector<Value> parts;
-      parts.reserve(packing);
-      output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
-        CHECK_GE(idxs.size(), 2);
-        SmallVector<int64_t> idxs_local(toArrayRef(idxs));
-        idxs_local[idxs.size() - 2] *= packing;
+      SmallVector<int64_t> idxs_local(toArrayRef(idxs));
+      idxs_local.back() *= packing;
+      for (int64_t i = 0; i < packing; ++i) {
        parts.push_back(input_vregs(idxs_local));
-        idxs_local[idxs.size() - 2]++;
-        while (parts.size() < packing) {
-          if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
-            parts.push_back(input_vregs(idxs_local));
-            idxs_local[idxs.size() - 2]++;
-          } else {
-            // Once we run out of tiles, we can pick any one we like.
-            parts.push_back(parts.back());
-          }
+        // Pack any data lying around if OOB
+        if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
+          ++idxs_local.back();
        }
-        *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
-        parts.clear();
-      });
-    } else {
-      return op.emitOpError("Not implemented: unsupported output tiling");
-    }
-    op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
-                                   std::move(output_vregs), ctx.target_shape)
-                              .getResult());
-    op.erase();
-    return success();
+      }
+      *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
+    });
+  } else if (layout_out.hasNativeTiling(ctx.target_shape)) {
+    int packing = layout_out.packing();
+    SmallVector<Value> parts;
+    parts.reserve(packing);
+    output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
+      CHECK_GE(idxs.size(), 2);
+      SmallVector<int64_t> idxs_local(toArrayRef(idxs));
+      idxs_local[idxs.size() - 2] *= packing;
+      parts.push_back(input_vregs(idxs_local));
+      idxs_local[idxs.size() - 2]++;
+      while (parts.size() < packing) {
+        if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
+          parts.push_back(input_vregs(idxs_local));
+          idxs_local[idxs.size() - 2]++;
+        } else {
+          // Once we run out of tiles, we can pick any one we like.
+          parts.push_back(parts.back());
+        }
+      }
+      *v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
+      parts.clear();
+    });
+  } else {
+    return op.emitOpError("Not implemented: unsupported output tiling");
   }
-  // TODO(tlongeri): why wasn't this part of the original code?
-  return op.emitOpError("Not implemented");
+  if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
+    output_vregs.Reshape(output_vregs_shape);
+  }
+  op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
+                                 std::move(output_vregs), ctx.target_shape)
+                            .getResult());
+  op.erase();
+  return success();
 }
 
 LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,
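
Note: with the implicit-dim switch removed, both cast rules reshape their vreg arrays to the implicit tile-array shape, run the single packing/unpacking loop that used to be the kNone-only path, and reshape the output back at the end. A standalone sketch of that loop's index arithmetic along the minor dimension (illustration only, not the pass itself; a packing of 2 corresponds to a 16-bit/32-bit pair):

// Standalone sketch (illustration only): an extension reads input vreg
// idx/packing and unpacks part idx%packing; a truncation packs `packing`
// consecutive input vregs, reusing the last in-bounds vreg at the array edge.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t packing = 2;  // e.g. 16-bit <-> 32-bit, two elements per slot

  // Extension: 3 packed 16-bit input vregs unpack into 6 32-bit output vregs.
  const int64_t packed_vregs = 3;
  for (int64_t idx = 0; idx < packed_vregs * packing; ++idx) {
    std::cout << "ext:   out[" << idx << "] <- unpack(in[" << idx / packing
              << "], part " << idx % packing << ")\n";
  }

  // Truncation: 3 32-bit input vregs pack into ceil(3/2) = 2 16-bit output
  // vregs; the second one reuses in[2] because there is no in[3].
  const int64_t unpacked_vregs = 3;
  const int64_t out_vregs = (unpacked_vregs + packing - 1) / packing;
  for (int64_t idx = 0; idx < out_vregs; ++idx) {
    std::vector<int64_t> parts;
    int64_t src = idx * packing;
    for (int64_t i = 0; i < packing; ++i) {
      parts.push_back(src);
      if (src < unpacked_vregs - 1) ++src;  // pack any data lying around if OOB
    }
    std::cout << "trunc: out[" << idx << "] <- pack(in[" << parts[0]
              << "], in[" << parts[1] << "])\n";
  }
  return 0;
}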

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 27 additions & 41 deletions
@@ -1473,34 +1473,24 @@ class VectorLayoutInferer {
                    "Only extensions to 32-bit supported");
     }
     auto &layout = *some_layout;
-    if (layout.implicit_dim() == ImplicitDim::kNone) {
-      // TODO(apaszke): Support native packed layouts here.
-      Layout src_layout;
-      Layout dst_layout;
-      // All layouts that subdivide the rows of the default tiling evenly
-      // can be handled uniformly with the default case, by preserving the
-      // tiling through the op.
-      if (default_tiling_[0] % layout.tiling()[0] == 0 &&
-          default_tiling_[1] == layout.tiling()[1]) {
-        src_layout = layout;
-      } else {
-        src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
-                                  default_tiling_, ImplicitDim::kNone);
-      }
-      dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
-                                ImplicitDim::kNone);
-      setLayout(op, src_layout, dst_layout);
-      return success();
-    }
-    if (layout.implicit_dim() == ImplicitDim::kSecondMinor) {
-      TPU_CHECK_OP(layout.tiling() == nativeTiling(16), "unsupported tiling");
-      auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_,
-                                     layout.implicit_dim());
-      setLayout(op, some_layout, dst_layout);
-      return success();
+    // TODO(apaszke): Support native packed layouts here.
+    Layout src_layout;
+    Layout dst_layout;
+    // All layouts that subdivide the rows of the default tiling evenly
+    // can be handled uniformly with the default case, by preserving the
+    // tiling through the op.
+    if (default_tiling_[0] % layout.tiling()[0] == 0 &&
+        default_tiling_[1] == layout.tiling()[1]) {
+      src_layout = layout;
+    } else {
+      // TODO(b/335863273): we should also reduce offsets.
+      src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
+                                default_tiling_, layout.implicit_dim());
     }
-    op->emitOpError("unsupported extension layout");
-    return failure();
+    dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
+                              layout.implicit_dim());
+    setLayout(op, src_layout, dst_layout);
+    return success();
   }
 
   LogicalResult inferTrunc(Operation *op) {
@@ -1523,20 +1513,16 @@ class VectorLayoutInferer {
                    "Only 32-bit truncation supported");
     }
     auto &layout = *some_layout;
-    if (layout.implicit_dim() == ImplicitDim::kNone) {
-      bool select_native = allUsersRequireNativeTiling(op->getResult(0));
-      auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
-                                     ImplicitDim::kNone);
-      auto dst_layout = VectorLayout(
-          dst_ty.getElementTypeBitWidth(), layout.offsets(),
-          select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
-                        : default_tiling_,
-          ImplicitDim::kNone);
-      setLayout(op, src_layout, dst_layout);
-      return success();
-    }
-    op->emitOpError("unsupported truncation layout");
-    return failure();
+    bool select_native = allUsersRequireNativeTiling(op->getResult(0));
+    auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
+                                   layout.implicit_dim());
+    auto dst_layout = VectorLayout(
+        dst_ty.getElementTypeBitWidth(), layout.offsets(),
+        select_native ? nativeTiling(dst_ty.getElementTypeBitWidth())
+                      : default_tiling_,
+        layout.implicit_dim());
+    setLayout(op, src_layout, dst_layout);
+    return success();
   }
 
   LogicalResult inferElementwise(Operation *op, bool check_bitwidth = true) {
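
Note: on the inference side, inferExt and inferTrunc no longer require ImplicitDim::kNone; they keep the source layout's implicit dim and only adjust bitwidth and tiling. A standalone sketch of the extension-side choice (simplified stand-in structs, not the inferer itself):

// Standalone sketch (illustration only): keep the source tiling when it
// evenly subdivides the rows of the default tiling and matches its lane
// count, otherwise fall back to the default tiling; the implicit dim is
// preserved either way, and extensions always produce a 32-bit layout.
#include <array>
#include <cstdint>
#include <iostream>

enum class ImplicitDim { kNone, kMinor, kSecondMinor };

struct LayoutSketch {
  int bitwidth;
  std::array<int64_t, 2> tiling;
  ImplicitDim implicit_dim;
};

LayoutSketch InferExtSrcLayout(const LayoutSketch &layout,
                               std::array<int64_t, 2> default_tiling) {
  LayoutSketch src = layout;
  if (default_tiling[0] % layout.tiling[0] != 0 ||
      default_tiling[1] != layout.tiling[1]) {
    src.tiling = default_tiling;  // fall back to the default tiling
  }
  return src;
}

int main() {
  constexpr std::array<int64_t, 2> kDefaultTiling = {8, 128};
  // A 16-bit source layout with native (16, 128) tiling and an implicit
  // second-minor dim, as a 1-D vector would have.
  const LayoutSketch in{16, {16, 128}, ImplicitDim::kSecondMinor};
  const LayoutSketch src = InferExtSrcLayout(in, kDefaultTiling);
  const LayoutSketch dst{32, src.tiling, src.implicit_dim};
  std::cout << "src tiling: (" << src.tiling[0] << ", " << src.tiling[1]
            << "), dst bitwidth: " << dst.bitwidth << "\n";
  return 0;
}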
