@@ -772,3 +772,77 @@ func.func @torch.aten.stft.center_2D_hop_length_3_window_pad_both(%arg0: !torch.
   %0 = torch.aten.stft.center %arg0, %nfft, %hoplen, %winlen, %arg1, %cstfalse, %padmode, %cstfalse, %cstfalse, %csttrue, %cstfalse : !torch.vtensor<[3,90],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[8],f32>, !torch.bool, !torch.str, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[3,6,27],complex<f32>>
   return %0 : !torch.vtensor<[3,6,27],complex<f32>>
 }
+
+
+// -----
+
+
+// CHECK-LABEL: func.func @native_layer_norm(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,56,56,96],f32>, %[[ARG1:.*]]: !torch.list<int>, %[[ARG2:.*]]: !torch.vtensor<[96],f32>, %[[ARG3:.*]]: !torch.vtensor<[96],f32>, %[[ARG4:.*]]: !torch.float) -> (!torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
+// CHECK-DAG: %[[INT96:.*]] = torch.constant.int 96
+// CHECK-DAG: %[[INT56:.*]] = torch.constant.int 56
+// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
+// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAR1:.*]] = torch.aten.sum.dim_IntList %[[ARG0]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR2:.*]] = torch.aten.numel %[[ARG0]] : !torch.vtensor<[1,56,56,96],f32> -> !torch.int
+// CHECK: %[[VAR3:.*]] = torch.aten.div.Scalar %[[VAR1]], %[[VAR2]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT56]], %[[INT56]], %[[INT96]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR5:.*]] = torch.aten.broadcast_to %[[VAR3]], %[[VAR4]] : !torch.vtensor<[1,56,56,1],f32>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR6:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAR5]], %[[INT1]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32>, !torch.int -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR7:.*]] = torch.aten.mul.Tensor %[[VAR6]], %[[VAR6]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR8:.*]] = torch.aten.sum.dim_IntList %[[VAR7]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR9:.*]] = torch.aten.numel %[[VAR7]] : !torch.vtensor<[1,56,56,96],f32> -> !torch.int
+// CHECK: %[[VAR10:.*]] = torch.aten.div.Scalar %[[VAR8]], %[[VAR9]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR11:.*]] = torch.aten.add.Scalar %[[VAR10]], %[[ARG4]], %[[INT1]] : !torch.vtensor<[1,56,56,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR12:.*]] = torch.aten.rsqrt %[[VAR11]] : !torch.vtensor<[1,56,56,1],f32> -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR13:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT56]], %[[INT56]], %[[INT96]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR14:.*]] = torch.aten.broadcast_to %[[VAR12]], %[[VAR13]] : !torch.vtensor<[1,56,56,1],f32>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR15:.*]] = torch.aten.mul.Tensor %[[VAR6]], %[[VAR14]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,96],f32> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR16:.*]] = torch.aten.mul.Tensor %[[VAR15]], %[[ARG2]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[96],f32> -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: %[[VAR17:.*]] = torch.aten.add.Tensor %[[VAR16]], %[[ARG3]], %[[INT1]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[96],f32>, !torch.int -> !torch.vtensor<[1,56,56,96],f32>
+// CHECK: return %[[VAR17]], %[[VAR3]], %[[VAR12]] : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+func.func @native_layer_norm(%input: !torch.vtensor<[1,56,56,96],f32>, %normalized_shape: !torch.list<int>, %weight: !torch.vtensor<[96],f32>, %bias: !torch.vtensor<[96],f32>, %eps: !torch.float) -> (!torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
+  %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],f32>, !torch.list<int>, !torch.vtensor<[96],f32>, !torch.vtensor<[96],f32>, !torch.float -> !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+  return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],f32>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+}
+
+
+// -----
+
+
+// CHECK-LABEL: func.func @native_layer_norm_mixed_dtypes(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,56,56,96],bf16>, %[[ARG1:.*]]: !torch.list<int>, %[[ARG2:.*]]: !torch.vtensor<[96],bf16>, %[[ARG3:.*]]: !torch.vtensor<[96],bf16>, %[[ARG4:.*]]: !torch.float) -> (!torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
+// CHECK-DAG: %[[INT96:.*]] = torch.constant.int 96
+// CHECK-DAG: %[[INT56:.*]] = torch.constant.int 56
+// CHECK-DAG: %[[INT15:.*]] = torch.constant.int 15
+// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
+// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[VAR0:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[VAR1:.*]] = torch.aten.sum.dim_IntList %[[ARG0]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR2:.*]] = torch.aten.numel %[[ARG0]] : !torch.vtensor<[1,56,56,96],bf16> -> !torch.int
+// CHECK: %[[VAR3:.*]] = torch.aten.div.Scalar %[[VAR1]], %[[VAR2]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR4:.*]] = torch.aten.to.dtype %[[VAR3]], %[[INT15]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],bf16>
+// CHECK: %[[VAR5:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT56]], %[[INT56]], %[[INT96]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR6:.*]] = torch.aten.broadcast_to %[[VAR4]], %[[VAR5]] : !torch.vtensor<[1,56,56,1],bf16>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR7:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAR6]], %[[INT1]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16>, !torch.int -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR8:.*]] = torch.aten.mul.Tensor %[[VAR7]], %[[VAR7]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR9:.*]] = torch.aten.sum.dim_IntList %[[VAR8]], %[[VAR0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR10:.*]] = torch.aten.numel %[[VAR8]] : !torch.vtensor<[1,56,56,96],bf16> -> !torch.int
+// CHECK: %[[VAR11:.*]] = torch.aten.div.Scalar %[[VAR9]], %[[VAR10]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR12:.*]] = torch.aten.add.Scalar %[[VAR11]], %[[ARG4]], %[[INT1]] : !torch.vtensor<[1,56,56,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR13:.*]] = torch.aten.rsqrt %[[VAR12]] : !torch.vtensor<[1,56,56,1],f32> -> !torch.vtensor<[1,56,56,1],f32>
+// CHECK: %[[VAR14:.*]] = torch.aten.to.dtype %[[VAR13]], %[[INT15]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,56,56,1],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,56,56,1],bf16>
+// CHECK: %[[VAR15:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT56]], %[[INT56]], %[[INT96]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAR16:.*]] = torch.aten.broadcast_to %[[VAR14]], %[[VAR15]] : !torch.vtensor<[1,56,56,1],bf16>, !torch.list<int> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR17:.*]] = torch.aten.mul.Tensor %[[VAR7]], %[[VAR16]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR18:.*]] = torch.aten.mul.Tensor %[[VAR17]], %[[ARG2]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[96],bf16> -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: %[[VAR19:.*]] = torch.aten.add.Tensor %[[VAR18]], %[[ARG3]], %[[INT1]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[96],bf16>, !torch.int -> !torch.vtensor<[1,56,56,96],bf16>
+// CHECK: return %[[VAR19]], %[[VAR3]], %[[VAR13]] : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf16>, %normalized_shape: !torch.list<int>, %weight: !torch.vtensor<[96],bf16>, %bias: !torch.vtensor<[96],bf16>, %eps: !torch.float) -> (!torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>) {
+  %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.vtensor<[96],bf16>, !torch.vtensor<[96],bf16>, !torch.float -> !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+  return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
+}
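
The CHECK sequences above exercise the decompose pattern for torch.aten.native_layer_norm: mean and a biased variance are built from sum/numel reductions, rstd is rsqrt(var + eps), and the normalized tensor is scaled by the weight and shifted by the bias, with the three results returned as (output, mean, rstd). As a reading aid only, here is a minimal NumPy sketch of that arithmetic under PyTorch's documented native_layer_norm semantics; the helper and variable names are hypothetical and are not part of the test file or the pass.

import numpy as np

def native_layer_norm_ref(x, normalized_shape, weight, bias, eps):
    # Reduce over the trailing `normalized_shape` dimensions, keeping them as size 1.
    dims = tuple(range(x.ndim - len(normalized_shape), x.ndim))
    n = float(np.prod([x.shape[d] for d in dims]))
    mean = x.sum(axis=dims, keepdims=True) / n          # sum.dim_IntList / numel
    centered = x - mean                                  # sub.Tensor after broadcast_to
    var = (centered * centered).sum(axis=dims, keepdims=True) / n  # biased variance
    rstd = 1.0 / np.sqrt(var + eps)                      # add.Scalar + rsqrt
    out = centered * rstd * weight + bias                # mul.Tensor, mul.Tensor, add.Tensor
    return out, mean, rstd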