@@ -590,9 +590,11 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
590
590
SmallVector<xla::Array<Value>> in_vreg_arrays;
591
591
in_vreg_arrays.reserve (num_operands);
592
592
for (unsigned i = 0 ; i < num_operands; ++i) {
593
- FAILUREOR_ASSIGN_OR_RETURN (xla::Array<Value> tile_array,
594
- disassemble (builder, *layouts_in[i],
595
- op.getOperand (i), ctx.target_shape ));
593
+ FAILUREOR_ASSIGN_OR_RETURN (
594
+ xla::Array<Value> tile_array,
595
+ disassemble (builder, *layouts_in[i],
596
+ cast<TypedValue<VectorType>>(op.getOperand (i)),
597
+ ctx.target_shape ));
596
598
in_vreg_arrays.emplace_back (std::move (tile_array));
597
599
}
598
600
@@ -653,15 +655,16 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
653
655
const VectorLayout &layout_in,
654
656
const VectorLayout &layout_out) {
655
657
ImplicitLocOpBuilder builder (op.getLoc (), op.getOperation ());
656
- auto result_ty = cast<VectorType>(op.getResult ().getType ());
657
- auto source_ty = cast<VectorType>(op.getIn ().getType ());
658
+ const auto result_ty = cast<VectorType>(op.getResult ().getType ());
659
+ auto source = cast<TypedValue<VectorType>>(op.getIn ());
660
+ const auto source_ty = source.getType ();
658
661
if (layout_out.bitwidth () != 32 ) {
659
662
return op.emitOpError (
660
663
" Not implemented: Only extensions to 32-bit supported" );
661
664
}
662
665
FAILUREOR_ASSIGN_OR_RETURN (
663
666
const xla::Array<Value> input_vregs,
664
- disassemble (builder, layout_in, op. getIn () , ctx.target_shape ));
667
+ disassemble (builder, layout_in, source , ctx.target_shape ));
665
668
xla::Array<Value> output_vregs (
666
669
layout_out.tileArrayShape (result_ty.getShape (), ctx.target_shape ));
667
670
FAILUREOR_ASSIGN_OR_RETURN (
@@ -762,7 +765,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
762
765
auto result_ty = cast<VectorType>(op.getResult ().getType ());
763
766
FAILUREOR_ASSIGN_OR_RETURN (
764
767
const xla::Array<Value> input_vregs,
765
- disassemble (builder, layout_in, op.getIn (), ctx.target_shape ));
768
+ disassemble (builder, layout_in, cast<TypedValue<VectorType>>(op.getIn ()),
769
+ ctx.target_shape ));
766
770
xla::Array<Value> output_vregs (
767
771
layout_out.tileArrayShape (result_ty.getShape (), ctx.target_shape ));
768
772
if (layout_in.bitwidth () != 32 ) {
@@ -905,13 +909,13 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
905
909
}
906
910
continue ;
907
911
}
908
- if (auto vty = dyn_cast<VectorType>(operand. getType () )) {
912
+ if (auto vector_operand = dyn_cast<TypedValue< VectorType>> (operand)) {
909
913
if (!layout.has_value ()) {
910
914
return op.emitOpError (" Expected layout for vector operand" );
911
915
}
912
916
FAILUREOR_ASSIGN_OR_RETURN (
913
917
const xla::Array<Value> tiles,
914
- disassemble (builder, *layout, operand , ctx.target_shape ));
918
+ disassemble (builder, *layout, vector_operand , ctx.target_shape ));
915
919
unrolled_args.append (tiles.begin (), tiles.end ());
916
920
} else {
917
921
if (layout.has_value ()) {
@@ -1098,12 +1102,12 @@ LogicalResult scf_yield_rule(RewriteContext &ctx, Operation &op,
1098
1102
SmallVector<Value> unrolled;
1099
1103
for (auto [operand, layout] :
1100
1104
llvm::zip_equal (yield_op.getOperands (), layouts_in)) {
1101
- if (auto vty = dyn_cast<VectorType>(operand. getType () )) {
1105
+ if (auto vector_operand = dyn_cast<TypedValue< VectorType>> (operand)) {
1102
1106
// When the operand has vector type, disassemble the operand.
1103
1107
TPU_ASSERT_OP (layout.has_value ());
1104
1108
FAILUREOR_ASSIGN_OR_RETURN (
1105
1109
const xla::Array<Value> tiles,
1106
- disassemble (builder, *layout, operand , ctx.target_shape ));
1110
+ disassemble (builder, *layout, vector_operand , ctx.target_shape ));
1107
1111
unrolled.append (tiles.begin (), tiles.end ());
1108
1112
} else {
1109
1113
TPU_ASSERT_OP (!layout.has_value ());
@@ -1745,7 +1749,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
1745
1749
for (Value operand : concatenate_op.getOperands ()) {
1746
1750
FAILUREOR_ASSIGN_OR_RETURN (
1747
1751
xla::Array<Value> t,
1748
- disassemble (builder, layout, operand, ctx.target_shape ));
1752
+ disassemble (builder, layout, cast<TypedValue<VectorType>>(operand),
1753
+ ctx.target_shape ));
1749
1754
tiles.emplace_back (std::move (t));
1750
1755
}
1751
1756
const xla::Array<Value> res_tiles = concatenate (tiles, dimension);
@@ -2227,7 +2232,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
2227
2232
const VectorType dst_ty = broadcast_op.getResult ().getType ();
2228
2233
const SmallVector<int64_t > dst_tiles_shape =
2229
2234
layout_out.tileArrayShape (dst_ty.getShape (), ctx.target_shape );
2230
- if (auto src_ty = dyn_cast<VectorType>(broadcast_op.getSourceType ())) {
2235
+ if (auto src = dyn_cast<TypedValue<VectorType>>(broadcast_op.getSource ())) {
2236
+ VectorType src_ty = src.getType ();
2231
2237
TPU_ASSERT_OP (maybe_layout_in.has_value ());
2232
2238
const VectorLayout &layout_in = *maybe_layout_in;
2233
2239
if (layout_in.implicit_dim () != layout_out.implicit_dim ()) {
@@ -2301,8 +2307,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
2301
2307
2302
2308
FAILUREOR_ASSIGN_OR_RETURN (
2303
2309
xla::Array<Value> src_tiles,
2304
- disassemble (builder, layout_in, broadcast_op.getSource (),
2305
- ctx.target_shape ));
2310
+ disassemble (builder, layout_in, src, ctx.target_shape ));
2306
2311
xla::Array<Value> dst_tiles (dst_tiles_shape);
2307
2312
if (no_op) {
2308
2313
SmallVector<int64_t > reshape_dims (expand_rank, 1 );
@@ -2666,10 +2671,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
2666
2671
return multi_reduction_op.emitOpError (
2667
2672
" Not implemented: Can only reduce into vectors" );
2668
2673
}
2669
- if (!layouts_out.front ().has_value ()) {
2670
- // Shouldn't be empty since result is a vector
2671
- return op.emitOpError (" Expected non-null output layout" );
2672
- }
2674
+ // Op definition enforces that accumulator type must match result type
2675
+ auto acc = cast<TypedValue<VectorType>>(multi_reduction_op.getAcc ());
2676
+ TPU_ASSERT_OP (layouts_out.front ().has_value ());
2673
2677
2674
2678
const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims ();
2675
2679
SmallVector<int64_t > dims;
@@ -2686,11 +2690,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
2686
2690
}
2687
2691
FAILUREOR_ASSIGN_OR_RETURN (
2688
2692
const xla::Array<Value> acc_vregs,
2689
- disassemble (builder, acc_layout, multi_reduction_op.getAcc (),
2690
- ctx.target_shape ));
2691
- const Value acc_vreg = *acc_vregs.begin ();
2692
- auto acc_def =
2693
- dyn_cast_if_present<arith::ConstantOp>(acc_vreg.getDefiningOp ());
2693
+ disassemble (builder, acc_layout, acc, ctx.target_shape ));
2694
+ auto acc_def = dyn_cast_if_present<arith::ConstantOp>(
2695
+ acc_vregs.begin ()->getDefiningOp ());
2694
2696
if (acc_def == nullptr ) {
2695
2697
return multi_reduction_op.emitOpError (
2696
2698
" Not implemented: Only constant accumulator supported" );
@@ -2838,7 +2840,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
2838
2840
}
2839
2841
xla::Array<Value> reduced_vregs =
2840
2842
src_vregs.Slice (src_slice_start, src_slice_end);
2841
- std::optional<Value> acc ;
2843
+ std::optional<Value> acc_vreg ;
2842
2844
auto reduction_status = reduced_vregs.EachStatus (
2843
2845
[&](const absl::Span<const int64_t > red_idx,
2844
2846
Value *const src_vreg) {
@@ -2860,17 +2862,17 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
2860
2862
return absl::UnknownError (" " );
2861
2863
}
2862
2864
Value vreg = failure_or_vreg.value ();
2863
- if (!acc .has_value ()) {
2864
- acc = vreg;
2865
+ if (!acc_vreg .has_value ()) {
2866
+ acc_vreg = vreg;
2865
2867
} else {
2866
2868
switch (tpu_kind) {
2867
2869
case tpu::ReductionKind::SUM:
2868
- acc = builder.create <arith::AddFOp>(vreg.getLoc (), *acc ,
2869
- vreg);
2870
+ acc_vreg = builder.create <arith::AddFOp>(vreg.getLoc (),
2871
+ *acc_vreg, vreg);
2870
2872
break ;
2871
2873
case tpu::ReductionKind::MAX:
2872
- acc = builder.create <arith::MaximumFOp>(vreg. getLoc (), *acc,
2873
- vreg);
2874
+ acc_vreg = builder.create <arith::MaximumFOp>(
2875
+ vreg. getLoc (), *acc_vreg, vreg);
2874
2876
break ;
2875
2877
}
2876
2878
}
@@ -2879,16 +2881,16 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
2879
2881
if (!reduction_status.ok ()) {
2880
2882
return reduction_status;
2881
2883
}
2882
- TPU_ASSERT_OP (acc .has_value ());
2884
+ TPU_ASSERT_OP (acc_vreg .has_value ());
2883
2885
if (reduces[1 ]) {
2884
- acc = builder.create <tpu::AllReduceOp>(multi_reduction_op-> getLoc (),
2885
- *acc , 1 , tpu_kind);
2886
+ acc_vreg = builder.create <tpu::AllReduceOp>(
2887
+ multi_reduction_op-> getLoc (), *acc_vreg , 1 , tpu_kind);
2886
2888
}
2887
2889
if (reduces[0 ]) {
2888
- acc = builder.create <tpu::AllReduceOp>(multi_reduction_op-> getLoc (),
2889
- *acc , 0 , tpu_kind);
2890
+ acc_vreg = builder.create <tpu::AllReduceOp>(
2891
+ multi_reduction_op-> getLoc (), *acc_vreg , 0 , tpu_kind);
2890
2892
}
2891
- *dst_vreg = *acc ;
2893
+ *dst_vreg = *acc_vreg ;
2892
2894
return absl::OkStatus ();
2893
2895
});
2894
2896
if (!all_results_ok.ok ()) {
@@ -3478,9 +3480,10 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
3478
3480
// Returns:
3479
3481
// An ndarray of MLIR values representing the tiling of val given by layout.
3480
3482
FailureOr<xla::Array<Value>> disassemble (
3481
- OpBuilder &builder, const VectorLayout &layout, const Value val,
3483
+ OpBuilder &builder, const VectorLayout &layout,
3484
+ const TypedValue<VectorType> val,
3482
3485
const std::array<int64_t , 2 > target_shape) {
3483
- const auto vty = cast<VectorType>( val.getType () );
3486
+ const auto vty = val.getType ();
3484
3487
const auto op_result = dyn_cast<OpResult>(val);
3485
3488
if (op_result == nullptr ) {
3486
3489
return failure ();
@@ -3869,15 +3872,15 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx,
3869
3872
}
3870
3873
3871
3874
// TODO(apaszke): Test this function properly
3872
- FailureOr<Value> relayout (OpBuilder &builder, Value v, VectorLayout src,
3873
- const VectorLayout &dst ,
3874
- const std::array<int64_t , 2 > target_shape) {
3875
+ FailureOr<TypedValue<VectorType>> relayout (
3876
+ OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src ,
3877
+ const VectorLayout &dst, const std::array<int64_t , 2 > target_shape) {
3875
3878
const int8_t bitwidth = src.bitwidth ();
3876
3879
if (bitwidth != dst.bitwidth ()) {
3877
3880
return emitError (v.getLoc (), " Can't change bitwidth during a relayout" );
3878
3881
}
3879
3882
const int packing = src.packing ();
3880
- VectorType vty = cast<VectorType>( v.getType () );
3883
+ VectorType vty = v.getType ();
3881
3884
FAILUREOR_ASSIGN_OR_RETURN (xla::Array<Value> src_tiles,
3882
3885
disassemble (builder, src, v, target_shape));
3883
3886
SmallVector<int64_t > dst_tiles_shape =
@@ -4202,16 +4205,17 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
4202
4205
for (auto [idx, tup] :
4203
4206
llvm::enumerate (llvm::zip (op.getOperands (), layouts_in))) {
4204
4207
auto [operand, li] = tup;
4205
- auto vty = dyn_cast<VectorType>(operand. getType () );
4206
- TPU_ASSERT_EQ_OP (vty != nullptr , li.has_value ());
4207
- if (vty == nullptr ) {
4208
+ auto vector_operand = dyn_cast<TypedValue< VectorType>> (operand);
4209
+ TPU_ASSERT_EQ_OP (vector_operand != nullptr , li.has_value ());
4210
+ if (vector_operand == nullptr ) {
4208
4211
continue ;
4209
4212
}
4213
+ auto vty = vector_operand.getType ();
4210
4214
4211
4215
// The operand should always be an operation result (not a BlockArgument)
4212
4216
// since we expect the FuncOp to have only memrefs and semaphores as
4213
4217
// arguments.
4214
- auto op_result = dyn_cast<OpResult>(operand );
4218
+ auto op_result = dyn_cast<OpResult>(vector_operand );
4215
4219
if (op_result == nullptr ) {
4216
4220
return op.emitError (" Expected operand to be an operation result" );
4217
4221
}
@@ -4227,7 +4231,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
4227
4231
}
4228
4232
OpBuilder builder (&op);
4229
4233
FAILUREOR_ASSIGN_OR_RETURN (Value new_v,
4230
- relayout (builder, operand , /* src=*/ *lo,
4234
+ relayout (builder, vector_operand , /* src=*/ *lo,
4231
4235
/* dst=*/ *li, ctx.target_shape ));
4232
4236
op.setOperand (idx, new_v);
4233
4237
}
0 commit comments