Skip to content

Commit 57e34e1

Browse files
tlongeri, jax authors
authored and committed
[Mosaic][NFC] Use TypedValue<VectorType> instead of Value for applicable arguments/return values in disassemble and relayout
Ideally we would prefer `TypedValue<VectorType>` everywhere possible for static type checking. However, I tried changing the type for arrays of vregs from `xla::Array<Value>` to `xla::Array<TypedValue<VectorType>>` and ran into issues because MLIR support for arrays/ranges of `TypedValue`s seems lacking. For example, I can't find a good way to get a `ValueRange` (which many op constructors take) from an array of `TypedValue`s without creating an intermediate vector of `Value`s. Perhaps an unsafe cast would work if we make the (probably not guaranteed) assumption that `sizeof(TypedValue)` equals `sizeof(Value)`. Also note that MLIR itself uses untyped `Value`s for ranges of op results and operands even when the op definition declares them to be of a specific type. PiperOrigin-RevId: 610509743
1 parent ca1844d commit 57e34e1

File tree

4 files changed

+71
-59
lines changed

4 files changed

+71
-59
lines changed

jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,10 @@ MlirTpuValueArray mlirTpuDisassemble(MlirTpuInsertionPoint insertion_point,
349349
MlirTpuVectorLayout layout, MlirValue val,
350350
MlirTpuI64TargetTuple target_shape) {
351351
mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point);
352+
// This cast will fail and assert if the caller passed a non-vector
353+
auto vector_val = mlir::cast<mlir::TypedValue<mlir::VectorType>>(unwrap(val));
352354
mlir::FailureOr<xla::Array<mlir::Value>> failure_or_vals =
353-
mlir::tpu::disassemble(builder, *unwrap(layout), unwrap(val),
355+
mlir::tpu::disassemble(builder, *unwrap(layout), vector_val,
354356
unwrap(target_shape));
355357
if (failed(failure_or_vals)) {
356358
return {{nullptr, 0}, nullptr};
@@ -371,8 +373,11 @@ MlirValue mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val,
371373
MlirTpuVectorLayout src, MlirTpuVectorLayout dst,
372374
MlirTpuI64TargetTuple target_shape) {
373375
mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point);
374-
mlir::FailureOr<mlir::Value> failure_or_new_val = mlir::tpu::relayout(
375-
builder, unwrap(val), *unwrap(src), *unwrap(dst), unwrap(target_shape));
376+
// This cast will fail and assert if the caller passed a non-vector
377+
auto vector_val = mlir::cast<mlir::TypedValue<mlir::VectorType>>(unwrap(val));
378+
mlir::FailureOr<mlir::TypedValue<mlir::VectorType>> failure_or_new_val =
379+
mlir::tpu::relayout(builder, vector_val, *unwrap(src), *unwrap(dst),
380+
unwrap(target_shape));
376381
if (failed(failure_or_new_val)) {
377382
return {nullptr};
378383
}

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,16 +325,16 @@ def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> {
325325
}
326326

327327
def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
328-
let arguments = (ins Variadic<AnyType>:$input);
329-
let results = (outs AnyType:$output);
328+
let arguments = (ins Variadic<AnyVector>:$input);
329+
let results = (outs AnyVector:$output);
330330
let assemblyFormat = [{
331331
$input attr-dict `:` type($input) `->` type($output)
332332
}];
333333
}
334334

335335
def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> {
336-
let arguments = (ins AnyType:$input);
337-
let results = (outs Variadic<AnyType>:$output);
336+
let arguments = (ins AnyVector:$input);
337+
let results = (outs Variadic<AnyVector>:$output);
338338
let hasCanonicalizeMethod = 1;
339339
let assemblyFormat = [{
340340
$input attr-dict `:` type($input) `->` type($output)

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -590,9 +590,11 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
590590
SmallVector<xla::Array<Value>> in_vreg_arrays;
591591
in_vreg_arrays.reserve(num_operands);
592592
for (unsigned i = 0; i < num_operands; ++i) {
593-
FAILUREOR_ASSIGN_OR_RETURN(xla::Array<Value> tile_array,
594-
disassemble(builder, *layouts_in[i],
595-
op.getOperand(i), ctx.target_shape));
593+
FAILUREOR_ASSIGN_OR_RETURN(
594+
xla::Array<Value> tile_array,
595+
disassemble(builder, *layouts_in[i],
596+
cast<TypedValue<VectorType>>(op.getOperand(i)),
597+
ctx.target_shape));
596598
in_vreg_arrays.emplace_back(std::move(tile_array));
597599
}
598600

@@ -653,15 +655,16 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
653655
const VectorLayout &layout_in,
654656
const VectorLayout &layout_out) {
655657
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
656-
auto result_ty = cast<VectorType>(op.getResult().getType());
657-
auto source_ty = cast<VectorType>(op.getIn().getType());
658+
const auto result_ty = cast<VectorType>(op.getResult().getType());
659+
auto source = cast<TypedValue<VectorType>>(op.getIn());
660+
const auto source_ty = source.getType();
658661
if (layout_out.bitwidth() != 32) {
659662
return op.emitOpError(
660663
"Not implemented: Only extensions to 32-bit supported");
661664
}
662665
FAILUREOR_ASSIGN_OR_RETURN(
663666
const xla::Array<Value> input_vregs,
664-
disassemble(builder, layout_in, op.getIn(), ctx.target_shape));
667+
disassemble(builder, layout_in, source, ctx.target_shape));
665668
xla::Array<Value> output_vregs(
666669
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
667670
FAILUREOR_ASSIGN_OR_RETURN(
@@ -762,7 +765,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
762765
auto result_ty = cast<VectorType>(op.getResult().getType());
763766
FAILUREOR_ASSIGN_OR_RETURN(
764767
const xla::Array<Value> input_vregs,
765-
disassemble(builder, layout_in, op.getIn(), ctx.target_shape));
768+
disassemble(builder, layout_in, cast<TypedValue<VectorType>>(op.getIn()),
769+
ctx.target_shape));
766770
xla::Array<Value> output_vregs(
767771
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
768772
if (layout_in.bitwidth() != 32) {
@@ -905,13 +909,13 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
905909
}
906910
continue;
907911
}
908-
if (auto vty = dyn_cast<VectorType>(operand.getType())) {
912+
if (auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand)) {
909913
if (!layout.has_value()) {
910914
return op.emitOpError("Expected layout for vector operand");
911915
}
912916
FAILUREOR_ASSIGN_OR_RETURN(
913917
const xla::Array<Value> tiles,
914-
disassemble(builder, *layout, operand, ctx.target_shape));
918+
disassemble(builder, *layout, vector_operand, ctx.target_shape));
915919
unrolled_args.append(tiles.begin(), tiles.end());
916920
} else {
917921
if (layout.has_value()) {
@@ -1098,12 +1102,12 @@ LogicalResult scf_yield_rule(RewriteContext &ctx, Operation &op,
10981102
SmallVector<Value> unrolled;
10991103
for (auto [operand, layout] :
11001104
llvm::zip_equal(yield_op.getOperands(), layouts_in)) {
1101-
if (auto vty = dyn_cast<VectorType>(operand.getType())) {
1105+
if (auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand)) {
11021106
// When the operand has vector type, disassemble the operand.
11031107
TPU_ASSERT_OP(layout.has_value());
11041108
FAILUREOR_ASSIGN_OR_RETURN(
11051109
const xla::Array<Value> tiles,
1106-
disassemble(builder, *layout, operand, ctx.target_shape));
1110+
disassemble(builder, *layout, vector_operand, ctx.target_shape));
11071111
unrolled.append(tiles.begin(), tiles.end());
11081112
} else {
11091113
TPU_ASSERT_OP(!layout.has_value());
@@ -1745,7 +1749,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
17451749
for (Value operand : concatenate_op.getOperands()) {
17461750
FAILUREOR_ASSIGN_OR_RETURN(
17471751
xla::Array<Value> t,
1748-
disassemble(builder, layout, operand, ctx.target_shape));
1752+
disassemble(builder, layout, cast<TypedValue<VectorType>>(operand),
1753+
ctx.target_shape));
17491754
tiles.emplace_back(std::move(t));
17501755
}
17511756
const xla::Array<Value> res_tiles = concatenate(tiles, dimension);
@@ -2227,7 +2232,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
22272232
const VectorType dst_ty = broadcast_op.getResult().getType();
22282233
const SmallVector<int64_t> dst_tiles_shape =
22292234
layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape);
2230-
if (auto src_ty = dyn_cast<VectorType>(broadcast_op.getSourceType())) {
2235+
if (auto src = dyn_cast<TypedValue<VectorType>>(broadcast_op.getSource())) {
2236+
VectorType src_ty = src.getType();
22312237
TPU_ASSERT_OP(maybe_layout_in.has_value());
22322238
const VectorLayout &layout_in = *maybe_layout_in;
22332239
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
@@ -2301,8 +2307,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
23012307

23022308
FAILUREOR_ASSIGN_OR_RETURN(
23032309
xla::Array<Value> src_tiles,
2304-
disassemble(builder, layout_in, broadcast_op.getSource(),
2305-
ctx.target_shape));
2310+
disassemble(builder, layout_in, src, ctx.target_shape));
23062311
xla::Array<Value> dst_tiles(dst_tiles_shape);
23072312
if (no_op) {
23082313
SmallVector<int64_t> reshape_dims(expand_rank, 1);
@@ -2666,10 +2671,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
26662671
return multi_reduction_op.emitOpError(
26672672
"Not implemented: Can only reduce into vectors");
26682673
}
2669-
if (!layouts_out.front().has_value()) {
2670-
// Shouldn't be empty since result is a vector
2671-
return op.emitOpError("Expected non-null output layout");
2672-
}
2674+
// Op definition enforces that accumulator type must match result type
2675+
auto acc = cast<TypedValue<VectorType>>(multi_reduction_op.getAcc());
2676+
TPU_ASSERT_OP(layouts_out.front().has_value());
26732677

26742678
const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims();
26752679
SmallVector<int64_t> dims;
@@ -2686,11 +2690,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
26862690
}
26872691
FAILUREOR_ASSIGN_OR_RETURN(
26882692
const xla::Array<Value> acc_vregs,
2689-
disassemble(builder, acc_layout, multi_reduction_op.getAcc(),
2690-
ctx.target_shape));
2691-
const Value acc_vreg = *acc_vregs.begin();
2692-
auto acc_def =
2693-
dyn_cast_if_present<arith::ConstantOp>(acc_vreg.getDefiningOp());
2693+
disassemble(builder, acc_layout, acc, ctx.target_shape));
2694+
auto acc_def = dyn_cast_if_present<arith::ConstantOp>(
2695+
acc_vregs.begin()->getDefiningOp());
26942696
if (acc_def == nullptr) {
26952697
return multi_reduction_op.emitOpError(
26962698
"Not implemented: Only constant accumulator supported");
@@ -2838,7 +2840,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
28382840
}
28392841
xla::Array<Value> reduced_vregs =
28402842
src_vregs.Slice(src_slice_start, src_slice_end);
2841-
std::optional<Value> acc;
2843+
std::optional<Value> acc_vreg;
28422844
auto reduction_status = reduced_vregs.EachStatus(
28432845
[&](const absl::Span<const int64_t> red_idx,
28442846
Value *const src_vreg) {
@@ -2860,17 +2862,17 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
28602862
return absl::UnknownError("");
28612863
}
28622864
Value vreg = failure_or_vreg.value();
2863-
if (!acc.has_value()) {
2864-
acc = vreg;
2865+
if (!acc_vreg.has_value()) {
2866+
acc_vreg = vreg;
28652867
} else {
28662868
switch (tpu_kind) {
28672869
case tpu::ReductionKind::SUM:
2868-
acc = builder.create<arith::AddFOp>(vreg.getLoc(), *acc,
2869-
vreg);
2870+
acc_vreg = builder.create<arith::AddFOp>(vreg.getLoc(),
2871+
*acc_vreg, vreg);
28702872
break;
28712873
case tpu::ReductionKind::MAX:
2872-
acc = builder.create<arith::MaximumFOp>(vreg.getLoc(), *acc,
2873-
vreg);
2874+
acc_vreg = builder.create<arith::MaximumFOp>(
2875+
vreg.getLoc(), *acc_vreg, vreg);
28742876
break;
28752877
}
28762878
}
@@ -2879,16 +2881,16 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
28792881
if (!reduction_status.ok()) {
28802882
return reduction_status;
28812883
}
2882-
TPU_ASSERT_OP(acc.has_value());
2884+
TPU_ASSERT_OP(acc_vreg.has_value());
28832885
if (reduces[1]) {
2884-
acc = builder.create<tpu::AllReduceOp>(multi_reduction_op->getLoc(),
2885-
*acc, 1, tpu_kind);
2886+
acc_vreg = builder.create<tpu::AllReduceOp>(
2887+
multi_reduction_op->getLoc(), *acc_vreg, 1, tpu_kind);
28862888
}
28872889
if (reduces[0]) {
2888-
acc = builder.create<tpu::AllReduceOp>(multi_reduction_op->getLoc(),
2889-
*acc, 0, tpu_kind);
2890+
acc_vreg = builder.create<tpu::AllReduceOp>(
2891+
multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
28902892
}
2891-
*dst_vreg = *acc;
2893+
*dst_vreg = *acc_vreg;
28922894
return absl::OkStatus();
28932895
});
28942896
if (!all_results_ok.ok()) {
@@ -3478,9 +3480,10 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
34783480
// Returns:
34793481
// An ndarray of MLIR values representing the tiling of val given by layout.
34803482
FailureOr<xla::Array<Value>> disassemble(
3481-
OpBuilder &builder, const VectorLayout &layout, const Value val,
3483+
OpBuilder &builder, const VectorLayout &layout,
3484+
const TypedValue<VectorType> val,
34823485
const std::array<int64_t, 2> target_shape) {
3483-
const auto vty = cast<VectorType>(val.getType());
3486+
const auto vty = val.getType();
34843487
const auto op_result = dyn_cast<OpResult>(val);
34853488
if (op_result == nullptr) {
34863489
return failure();
@@ -3869,15 +3872,15 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx,
38693872
}
38703873

38713874
// TODO(apaszke): Test this function properly
3872-
FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
3873-
const VectorLayout &dst,
3874-
const std::array<int64_t, 2> target_shape) {
3875+
FailureOr<TypedValue<VectorType>> relayout(
3876+
OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src,
3877+
const VectorLayout &dst, const std::array<int64_t, 2> target_shape) {
38753878
const int8_t bitwidth = src.bitwidth();
38763879
if (bitwidth != dst.bitwidth()) {
38773880
return emitError(v.getLoc(), "Can't change bitwidth during a relayout");
38783881
}
38793882
const int packing = src.packing();
3880-
VectorType vty = cast<VectorType>(v.getType());
3883+
VectorType vty = v.getType();
38813884
FAILUREOR_ASSIGN_OR_RETURN(xla::Array<Value> src_tiles,
38823885
disassemble(builder, src, v, target_shape));
38833886
SmallVector<int64_t> dst_tiles_shape =
@@ -4202,16 +4205,17 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
42024205
for (auto [idx, tup] :
42034206
llvm::enumerate(llvm::zip(op.getOperands(), layouts_in))) {
42044207
auto [operand, li] = tup;
4205-
auto vty = dyn_cast<VectorType>(operand.getType());
4206-
TPU_ASSERT_EQ_OP(vty != nullptr, li.has_value());
4207-
if (vty == nullptr) {
4208+
auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand);
4209+
TPU_ASSERT_EQ_OP(vector_operand != nullptr, li.has_value());
4210+
if (vector_operand == nullptr) {
42084211
continue;
42094212
}
4213+
auto vty = vector_operand.getType();
42104214

42114215
// The operand should always be an Operation (and not a BlockArgument)
42124216
// since we expect the FuncOp to have only memrefs and semaphores as
42134217
// arguments.
4214-
auto op_result = dyn_cast<OpResult>(operand);
4218+
auto op_result = dyn_cast<OpResult>(vector_operand);
42154219
if (op_result == nullptr) {
42164220
return op.emitError("Expected operand to be an operation result");
42174221
}
@@ -4227,7 +4231,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
42274231
}
42284232
OpBuilder builder(&op);
42294233
FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
4230-
relayout(builder, operand, /*src=*/*lo,
4234+
relayout(builder, vector_operand, /*src=*/*lo,
42314235
/*dst=*/*li, ctx.target_shape));
42324236
op.setOperand(idx, new_v);
42334237
}

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
2929
const xla::Array<Value> &vals,
3030
std::array<int64_t, 2> target_shape);
3131
FailureOr<xla::Array<Value>> disassemble(OpBuilder &builder,
32-
const VectorLayout &layout, Value val,
32+
const VectorLayout &layout,
33+
TypedValue<VectorType> val,
3334
std::array<int64_t, 2> target_shape);
3435

3536
// Rewrites the operation according to its layout annotations.
@@ -55,9 +56,11 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op);
5556
//
5657
// Returns:
5758
// A new MLIR vector value, laid out as requested by dst.
58-
FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
59-
const VectorLayout &dst,
60-
std::array<int64_t, 2> target_shape);
59+
FailureOr<TypedValue<VectorType>> relayout(OpBuilder &builder,
60+
TypedValue<VectorType> v,
61+
VectorLayout src,
62+
const VectorLayout &dst,
63+
std::array<int64_t, 2> target_shape);
6164

6265
} // namespace mlir::tpu
6366

0 commit comments

Comments
 (0)