@@ -922,6 +922,9 @@ TEST_F(AtenXlaTensorTest, TestAsStrided) {
922
922
torch::as_strided (xla_input, /* size=*/ size, /* stride=*/ stride);
923
923
AllClose (output, xla_output);
924
924
});
925
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
926
+ ExpectCounterNotChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
927
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
925
928
}
926
929
927
930
TEST_F (AtenXlaTensorTest, TestAsStridedInPlace) {
@@ -938,6 +941,9 @@ TEST_F(AtenXlaTensorTest, TestAsStridedInPlace) {
938
941
AllClose (output, xla_output);
939
942
AllClose (input, xla_input);
940
943
});
944
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
945
+ ExpectCounterNotChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
946
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
941
947
}
942
948
943
949
TEST_F (AtenXlaTensorTest, TestAsStridedWithOffset) {
@@ -956,6 +962,9 @@ TEST_F(AtenXlaTensorTest, TestAsStridedWithOffset) {
956
962
/* storage_offset=*/ storage_offset);
957
963
AllClose (output, xla_output);
958
964
});
965
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
966
+ ExpectCounterNotChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
967
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
959
968
}
960
969
961
970
TEST_F (AtenXlaTensorTest, TestAsStridedWithInplaceCopy) {
@@ -970,6 +979,9 @@ TEST_F(AtenXlaTensorTest, TestAsStridedWithInplaceCopy) {
970
979
xla_output.as_strided (size, stride).copy_ (xla_grad);
971
980
AllClose (output, xla_output);
972
981
});
982
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
983
+ ExpectCounterNotChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
984
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
973
985
}
974
986
975
987
TEST_F (AtenXlaTensorTest, TestEmptyStrided) {
@@ -984,6 +996,140 @@ TEST_F(AtenXlaTensorTest, TestEmptyStrided) {
984
996
});
985
997
}
986
998
999
+ TEST_F (AtenXlaTensorTest, TestAsStridedUseSlice) {
1000
+ torch::Tensor input =
1001
+ torch::rand ({16 , 32 , 24 }, torch::TensorOptions (torch::kFloat ));
1002
+ std::vector<int64_t > size = {16 , 8 , 24 };
1003
+ std::vector<int64_t > stride = {768 , 48 , 1 };
1004
+ torch::Tensor output =
1005
+ torch::as_strided (input, /* size=*/ size, /* stride=*/ stride);
1006
+ ForEachDevice ([&](const torch::Device& device) {
1007
+ torch::Tensor xla_input = CopyToDevice (input, device);
1008
+ torch::Tensor xla_output =
1009
+ torch::as_strided (xla_input, /* size=*/ size, /* stride=*/ stride);
1010
+ AllClose (output, xla_output);
1011
+ });
1012
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
1013
+ ExpectCounterNotChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
1014
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
1015
+ }
1016
+
1017
+ TEST_F (AtenXlaTensorTest, TestAsStridedUseSliceSizeReduce) {
1018
+ torch::Tensor input =
1019
+ torch::rand ({16 , 32 , 24 }, torch::TensorOptions (torch::kFloat ));
1020
+ std::vector<int64_t > size = {16 , 8 , 24 };
1021
+ std::vector<int64_t > stride = {768 , 24 , 1 };
1022
+ torch::Tensor output =
1023
+ torch::as_strided (input, /* size=*/ size, /* stride=*/ stride);
1024
+ ForEachDevice ([&](const torch::Device& device) {
1025
+ torch::Tensor xla_input = CopyToDevice (input, device);
1026
+ torch::Tensor xla_output =
1027
+ torch::as_strided (xla_input, /* size=*/ size, /* stride=*/ stride);
1028
+ AllClose (output, xla_output);
1029
+ });
1030
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
1031
+ ExpectCounterNotChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
1032
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
1033
+ }
1034
+
1035
+ TEST_F (AtenXlaTensorTest, TestAsStridedMismatchLastDimUseSlice) {
1036
+ torch::Tensor input =
1037
+ torch::rand ({16 , 32 , 24 }, torch::TensorOptions (torch::kFloat ));
1038
+ std::vector<int64_t > size = {16 , 32 };
1039
+ std::vector<int64_t > stride = {768 , 24 };
1040
+ torch::Tensor output =
1041
+ torch::as_strided (input, /* size=*/ size, /* stride=*/ stride);
1042
+ ForEachDevice ([&](const torch::Device& device) {
1043
+ torch::Tensor xla_input = CopyToDevice (input, device);
1044
+ torch::Tensor xla_output =
1045
+ torch::as_strided (xla_input, /* size=*/ size, /* stride=*/ stride);
1046
+ AllClose (output, xla_output);
1047
+ });
1048
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
1049
+ ExpectCounterNotChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
1050
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
1051
+ }
1052
+
1053
+ TEST_F (AtenXlaTensorTest, TestAsStridedMismatchMiddleDimUseSlice) {
1054
+ torch::lazy::MetricsArena::Get ()->ResetMetrics ();
1055
+ runtime::metrics::ClearMetrics ();
1056
+ torch::Tensor input =
1057
+ torch::rand ({6 , 4 , 2 , 4 }, torch::TensorOptions (torch::kFloat ));
1058
+ std::vector<int64_t > size = {6 , 2 , 4 };
1059
+ std::vector<int64_t > stride = {32 , 4 , 1 };
1060
+ torch::Tensor output =
1061
+ torch::as_strided (input, /* size=*/ size, /* stride=*/ stride);
1062
+ ForEachDevice ([&](const torch::Device& device) {
1063
+ torch::Tensor xla_input = CopyToDevice (input, device);
1064
+ torch::Tensor xla_output =
1065
+ torch::as_strided (xla_input, /* size=*/ size, /* stride=*/ stride);
1066
+ AllClose (output, xla_output);
1067
+ });
1068
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
1069
+ ExpectCounterNotChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
1070
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
1071
+ }
1072
+
1073
+ TEST_F (AtenXlaTensorTest, TestAsStridedMismatchDimWithOffset) {
1074
+ torch::lazy::MetricsArena::Get ()->ResetMetrics ();
1075
+ runtime::metrics::ClearMetrics ();
1076
+ torch::Tensor input =
1077
+ torch::rand ({6 , 4 , 2 , 4 }, torch::TensorOptions (torch::kFloat ));
1078
+ std::vector<int64_t > size = {6 , 2 , 4 };
1079
+ std::vector<int64_t > stride = {32 , 4 , 1 };
1080
+ torch::Tensor output =
1081
+ torch::as_strided (input, /* size=*/ size, /* stride=*/ stride, 1 );
1082
+ ForEachDevice ([&](const torch::Device& device) {
1083
+ torch::Tensor xla_input = CopyToDevice (input, device);
1084
+ torch::Tensor xla_output =
1085
+ torch::as_strided (xla_input, /* size=*/ size, /* stride=*/ stride, 1 );
1086
+ AllClose (output, xla_output);
1087
+ });
1088
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
1089
+ ExpectCounterChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
1090
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
1091
+ }
1092
+
1093
+ TEST_F (AtenXlaTensorTest, TestAsStridedMultipleMismatchDimWithOffset) {
1094
+ torch::lazy::MetricsArena::Get ()->ResetMetrics ();
1095
+ runtime::metrics::ClearMetrics ();
1096
+ torch::Tensor input =
1097
+ torch::rand ({6 , 4 , 2 , 4 }, torch::TensorOptions (torch::kFloat ));
1098
+ std::vector<int64_t > size = {3 , 2 , 4 };
1099
+ std::vector<int64_t > stride = {16 , 4 , 1 };
1100
+ torch::Tensor output =
1101
+ torch::as_strided (input, /* size=*/ size, /* stride=*/ stride);
1102
+ ForEachDevice ([&](const torch::Device& device) {
1103
+ torch::Tensor xla_input = CopyToDevice (input, device);
1104
+ torch::Tensor xla_output =
1105
+ torch::as_strided (xla_input, /* size=*/ size, /* stride=*/ stride);
1106
+ AllClose (output, xla_output);
1107
+ });
1108
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
1109
+ ExpectCounterChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
1110
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
1111
+ }
1112
+
1113
+ TEST_F (AtenXlaTensorTest, TestAsStridedMultipleDimMismatch) {
1114
+ torch::lazy::MetricsArena::Get ()->ResetMetrics ();
1115
+ runtime::metrics::ClearMetrics ();
1116
+ torch::Tensor input =
1117
+ torch::rand ({6 , 4 , 2 , 4 }, torch::TensorOptions (torch::kFloat ));
1118
+ std::vector<int64_t > size = {6 , 4 , 1 , 2 };
1119
+ std::vector<int64_t > stride = {32 , 8 , 8 , 2 };
1120
+ torch::Tensor output =
1121
+ torch::as_strided (input, /* size=*/ size, /* stride=*/ stride);
1122
+ ForEachDevice ([&](const torch::Device& device) {
1123
+ torch::Tensor xla_input = CopyToDevice (input, device);
1124
+ torch::Tensor xla_output =
1125
+ torch::as_strided (xla_input, /* size=*/ size, /* stride=*/ stride);
1126
+ AllClose (output, xla_output);
1127
+ });
1128
+ ExpectCounterNotChanged (" aten::*" , cpp_test::GetIgnoredCounters ());
1129
+ ExpectCounterChanged (" xla::take" , cpp_test::GetIgnoredCounters ());
1130
+ ExpectCounterChanged (" xla::as_strided_copy" , cpp_test::GetIgnoredCounters ());
1131
+ }
1132
+
987
1133
TEST_F (AtenXlaTensorTest, TestAvgPool2DBackward) {
988
1134
int kernel_size = 2 ;
989
1135
for (int stride = 1 ; stride <= 2 ; ++stride) {
0 commit comments