@@ -658,15 +658,23 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
658
658
const auto result_ty = cast<VectorType>(op.getResult ().getType ());
659
659
auto source = cast<TypedValue<VectorType>>(op.getIn ());
660
660
const auto source_ty = source.getType ();
661
+ auto output_vregs_shape =
662
+ layout_out.tileArrayShape (result_ty.getShape (), ctx.target_shape );
661
663
if (layout_out.bitwidth () != 32 ) {
662
664
return op.emitOpError (
663
665
" Not implemented: Only extensions to 32-bit supported" );
664
666
}
665
667
FAILUREOR_ASSIGN_OR_RETURN (
666
- const xla::Array<Value> input_vregs,
668
+ xla::Array<Value> input_vregs,
667
669
disassemble (builder, layout_in, source, ctx.target_shape ));
668
- xla::Array<Value> output_vregs (
669
- layout_out.tileArrayShape (result_ty.getShape (), ctx.target_shape ));
670
+ xla::Array<Value> output_vregs (output_vregs_shape);
671
+ // TODO(jevinjiang): maybe just use tileArrayImplicitShape in disassemble?
672
+ if (layout_in.implicit_dim () != VectorLayout::ImplicitDim::kNone ) {
673
+ input_vregs.Reshape (layout_in.tileArrayImplicitShape (source_ty.getShape (),
674
+ ctx.target_shape ));
675
+ output_vregs.Reshape (layout_out.tileArrayImplicitShape (result_ty.getShape (),
676
+ ctx.target_shape ));
677
+ }
670
678
FAILUREOR_ASSIGN_OR_RETURN (
671
679
const VectorType res_vreg_ty,
672
680
getNativeVregType (result_ty.getElementType (), ctx.target_shape ));
@@ -676,51 +684,24 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
676
684
if (layout_in.offsets () != layout_out.offsets ()) {
677
685
return op.emitOpError (" Not implemented: Change of offsets during the cast" );
678
686
}
679
- switch (layout_in.implicit_dim ()) {
680
- case VectorLayout::ImplicitDim::kNone : {
681
- if (layout_in.tiling () != layout_out.tiling ()) {
682
- return op.emitOpError (
683
- " Not implemented: Changing tiling during the cast" );
684
- }
685
- auto tiling = layout_in.tiling ();
686
- if (ctx.target_shape [0 ] % tiling[0 ] != 0 ||
687
- ctx.target_shape [1 ] != tiling[1 ]) {
688
- return op.emitOpError (" Not implemented: tiling not supported" );
689
- }
690
- const int packing = layout_in.packing ();
691
- output_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
692
- SmallVector<int64_t > input_vreg_idxs (toArrayRef (idxs));
693
- input_vreg_idxs.back () /= packing;
694
- const int64_t vreg_part = idxs.back () % packing;
695
- *v = builder.create <UnpackSubelementsOp>(
696
- res_vreg_ty, input_vregs (input_vreg_idxs), vreg_part);
697
- });
698
- } break ;
699
- case VectorLayout::ImplicitDim::kMinor :
700
- return op.emitOpError (
701
- " Not implemented: Only casts of lane-oriented values supported" );
702
- case VectorLayout::ImplicitDim::kSecondMinor : {
703
- auto is_one_tile = [](VectorType vty, VectorLayout layout) {
704
- auto implicit_shape = layout.implicitShape (vty.getShape ());
705
- auto tiled_shape = ArrayRef<int64_t >(implicit_shape).take_back (2 );
706
- return (layout.offsets ()[0 ].value_or (0 ) + tiled_shape[0 ] <=
707
- layout.tiling ()[0 ]) &&
708
- (layout.offsets ()[1 ].value_or (0 ) + tiled_shape[1 ] <=
709
- layout.tiling ()[1 ]);
710
- };
711
- if (input_vregs.dimensions () != absl::Span<const int64_t >{1 } ||
712
- output_vregs.dimensions () != absl::Span<const int64_t >{1 } ||
713
- !is_one_tile (source_ty, layout_in) ||
714
- !is_one_tile (result_ty, layout_out)) {
715
- return op.emitOpError (" Not implemented" );
716
- }
717
- if (layout_in.offsets ()[0 ] >= ctx.target_shape [0 ]) {
718
- return op.emitOpError (" Not implemented" );
719
- }
720
- auto unpack_subelements_op = builder.create <UnpackSubelementsOp>(
721
- res_vreg_ty, *input_vregs.begin (), 0 );
722
- output_vregs.Fill (unpack_subelements_op.getResult ());
723
- }
687
+ if (layout_in.tiling () != layout_out.tiling ()) {
688
+ return op.emitOpError (" Not implemented: Changing tiling during the cast" );
689
+ }
690
+ auto tiling = layout_in.tiling ();
691
+ if (ctx.target_shape [0 ] % tiling[0 ] != 0 ||
692
+ ctx.target_shape [1 ] != tiling[1 ]) {
693
+ return op.emitOpError (" Not implemented: tiling not supported" );
694
+ }
695
+ const int packing = layout_in.packing ();
696
+ output_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
697
+ SmallVector<int64_t > input_vreg_idxs (toArrayRef (idxs));
698
+ input_vreg_idxs.back () /= packing;
699
+ const int64_t vreg_part = idxs.back () % packing;
700
+ *v = builder.create <UnpackSubelementsOp>(
701
+ res_vreg_ty, input_vregs (input_vreg_idxs), vreg_part);
702
+ });
703
+ if (layout_out.implicit_dim () != VectorLayout::ImplicitDim::kNone ) {
704
+ output_vregs.Reshape (output_vregs_shape);
724
705
}
725
706
op.replaceAllUsesWith (assemble (builder, result_ty, layout_out,
726
707
std::move (output_vregs), ctx.target_shape )
@@ -762,73 +743,85 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
762
743
const VectorLayout &layout_in,
763
744
const VectorLayout &layout_out) {
764
745
ImplicitLocOpBuilder builder (op.getLoc (), op.getOperation ());
746
+ auto source = cast<TypedValue<VectorType>>(op.getIn ());
747
+ const auto source_ty = source.getType ();
765
748
auto result_ty = cast<VectorType>(op.getResult ().getType ());
749
+ auto output_vregs_shape =
750
+ layout_out.tileArrayShape (result_ty.getShape (), ctx.target_shape );
766
751
FAILUREOR_ASSIGN_OR_RETURN (
767
- const xla::Array<Value> input_vregs,
768
- disassemble (builder, layout_in, cast<TypedValue<VectorType>>(op.getIn ()),
769
- ctx.target_shape ));
770
- xla::Array<Value> output_vregs (
771
- layout_out.tileArrayShape (result_ty.getShape (), ctx.target_shape ));
752
+ xla::Array<Value> input_vregs,
753
+ disassemble (builder, layout_in, source, ctx.target_shape ));
754
+ xla::Array<Value> output_vregs (output_vregs_shape);
772
755
if (layout_in.bitwidth () != 32 ) {
773
756
return op.emitOpError (" Not implemented: Only 32-bit truncation supported" );
774
757
}
758
+ if (layout_in.offsets () != layout_out.offsets ()) {
759
+ return op.emitOpError (
760
+ " Not implemented: Change of offsets during the truncation" );
761
+ }
762
+ if (layout_in.implicit_dim () != layout_out.implicit_dim ()) {
763
+ return op.emitOpError (" Not implemented: Change of layout during the cast" );
764
+ }
765
+ if (layout_in.tiling () != ctx.target_shape ) {
766
+ return op.emitOpError (" Not implemented: Only (8,128) tiling supported" );
767
+ }
768
+ if (layout_in.implicit_dim () != VectorLayout::ImplicitDim::kNone ) {
769
+ input_vregs.Reshape (layout_in.tileArrayImplicitShape (source_ty.getShape (),
770
+ ctx.target_shape ));
771
+ output_vregs.Reshape (layout_out.tileArrayImplicitShape (result_ty.getShape (),
772
+ ctx.target_shape ));
773
+ }
775
774
FAILUREOR_ASSIGN_OR_RETURN (
776
775
VectorType res_vreg_ty,
777
776
getNativeVregType (result_ty.getElementType (), ctx.target_shape ));
778
- if (layout_in.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
779
- layout_out.implicit_dim () == VectorLayout::ImplicitDim::kNone ) {
780
- if (layout_in.tiling () != ctx.target_shape ) {
781
- return op.emitOpError (" Not implemented: Only (8,128) tiling supported" );
782
- }
783
- if (layout_out.tiling () == ctx.target_shape ) {
784
- const int packing = layout_out.packing ();
785
- output_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
786
- SmallVector<Value> parts;
787
- SmallVector<int64_t > idxs_local (toArrayRef (idxs));
788
- idxs_local.back () *= packing;
789
- for (int64_t i = 0 ; i < packing; ++i) {
790
- parts.push_back (input_vregs (idxs_local));
791
- // Pack any data lying around if OOB
792
- if (idxs_local.back () < input_vregs.dimensions ().back () - 1 ) {
793
- ++idxs_local.back ();
794
- }
795
- }
796
- *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts);
797
- });
798
-
799
- } else if (layout_out.hasNativeTiling (ctx.target_shape )) {
800
- int packing = layout_out.packing ();
777
+ if (layout_out.tiling () == ctx.target_shape ) {
778
+ const int packing = layout_out.packing ();
779
+ output_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
801
780
SmallVector<Value> parts;
802
- parts.reserve (packing);
803
- output_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
804
- CHECK_GE (idxs.size (), 2 );
805
- SmallVector<int64_t > idxs_local (toArrayRef (idxs));
806
- idxs_local[idxs.size () - 2 ] *= packing;
781
+ SmallVector<int64_t > idxs_local (toArrayRef (idxs));
782
+ idxs_local.back () *= packing;
783
+ for (int64_t i = 0 ; i < packing; ++i) {
807
784
parts.push_back (input_vregs (idxs_local));
808
- idxs_local[idxs.size () - 2 ]++;
809
- while (parts.size () < packing) {
810
- if (*(idxs_local.end () - 2 ) < *(input_vregs.dimensions ().end () - 2 )) {
811
- parts.push_back (input_vregs (idxs_local));
812
- idxs_local[idxs.size () - 2 ]++;
813
- } else {
814
- // Once we run out of tiles, we can pick any one we like.
815
- parts.push_back (parts.back ());
816
- }
785
+ // Pack any data lying around if OOB
786
+ if (idxs_local.back () < input_vregs.dimensions ().back () - 1 ) {
787
+ ++idxs_local.back ();
817
788
}
818
- *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts);
819
- parts.clear ();
820
- });
821
- } else {
822
- return op.emitOpError (" Not implemented: unsupported output tiling" );
823
- }
824
- op.replaceAllUsesWith (assemble (builder, result_ty, layout_out,
825
- std::move (output_vregs), ctx.target_shape )
826
- .getResult ());
827
- op.erase ();
828
- return success ();
789
+ }
790
+ *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts);
791
+ });
792
+ } else if (layout_out.hasNativeTiling (ctx.target_shape )) {
793
+ int packing = layout_out.packing ();
794
+ SmallVector<Value> parts;
795
+ parts.reserve (packing);
796
+ output_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
797
+ CHECK_GE (idxs.size (), 2 );
798
+ SmallVector<int64_t > idxs_local (toArrayRef (idxs));
799
+ idxs_local[idxs.size () - 2 ] *= packing;
800
+ parts.push_back (input_vregs (idxs_local));
801
+ idxs_local[idxs.size () - 2 ]++;
802
+ while (parts.size () < packing) {
803
+ if (*(idxs_local.end () - 2 ) < *(input_vregs.dimensions ().end () - 2 )) {
804
+ parts.push_back (input_vregs (idxs_local));
805
+ idxs_local[idxs.size () - 2 ]++;
806
+ } else {
807
+ // Once we run out of tiles, we can pick any one we like.
808
+ parts.push_back (parts.back ());
809
+ }
810
+ }
811
+ *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts);
812
+ parts.clear ();
813
+ });
814
+ } else {
815
+ return op.emitOpError (" Not implemented: unsupported output tiling" );
829
816
}
830
- // TODO(tlongeri): why wasn't this part of the original code?
831
- return op.emitOpError (" Not implemented" );
817
+ if (layout_out.implicit_dim () != VectorLayout::ImplicitDim::kNone ) {
818
+ output_vregs.Reshape (output_vregs_shape);
819
+ }
820
+ op.replaceAllUsesWith (assemble (builder, result_ty, layout_out,
821
+ std::move (output_vregs), ctx.target_shape )
822
+ .getResult ());
823
+ op.erase ();
824
+ return success ();
832
825
}
833
826
834
827
LogicalResult arith_truncf_rule (RewriteContext &ctx, Operation &op,
0 commit comments