Skip to content

Commit 1ab8216

Browse files
authored
Lower as_strided_copy use fast path with slice (#8734)
1 parent b4ba17b commit 1ab8216

File tree

6 files changed: +499 -1 lines changed

test/cpp/test_aten_xla_tensor_3.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,9 @@ TEST_F(AtenXlaTensorTest, TestAsStrided) {
922922
torch::as_strided(xla_input, /*size=*/size, /*stride=*/stride);
923923
AllClose(output, xla_output);
924924
});
925+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
926+
ExpectCounterNotChanged("xla::take", cpp_test::GetIgnoredCounters());
927+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
925928
}
926929

927930
TEST_F(AtenXlaTensorTest, TestAsStridedInPlace) {
@@ -938,6 +941,9 @@ TEST_F(AtenXlaTensorTest, TestAsStridedInPlace) {
938941
AllClose(output, xla_output);
939942
AllClose(input, xla_input);
940943
});
944+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
945+
ExpectCounterNotChanged("xla::take", cpp_test::GetIgnoredCounters());
946+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
941947
}
942948

943949
TEST_F(AtenXlaTensorTest, TestAsStridedWithOffset) {
@@ -956,6 +962,9 @@ TEST_F(AtenXlaTensorTest, TestAsStridedWithOffset) {
956962
/*storage_offset=*/storage_offset);
957963
AllClose(output, xla_output);
958964
});
965+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
966+
ExpectCounterNotChanged("xla::take", cpp_test::GetIgnoredCounters());
967+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
959968
}
960969

961970
TEST_F(AtenXlaTensorTest, TestAsStridedWithInplaceCopy) {
@@ -970,6 +979,9 @@ TEST_F(AtenXlaTensorTest, TestAsStridedWithInplaceCopy) {
970979
xla_output.as_strided(size, stride).copy_(xla_grad);
971980
AllClose(output, xla_output);
972981
});
982+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
983+
ExpectCounterNotChanged("xla::take", cpp_test::GetIgnoredCounters());
984+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
973985
}
974986

975987
TEST_F(AtenXlaTensorTest, TestEmptyStrided) {
@@ -984,6 +996,140 @@ TEST_F(AtenXlaTensorTest, TestEmptyStrided) {
984996
});
985997
}
986998

999+
TEST_F(AtenXlaTensorTest, TestAsStridedUseSlice) {
1000+
torch::Tensor input =
1001+
torch::rand({16, 32, 24}, torch::TensorOptions(torch::kFloat));
1002+
std::vector<int64_t> size = {16, 8, 24};
1003+
std::vector<int64_t> stride = {768, 48, 1};
1004+
torch::Tensor output =
1005+
torch::as_strided(input, /*size=*/size, /*stride=*/stride);
1006+
ForEachDevice([&](const torch::Device& device) {
1007+
torch::Tensor xla_input = CopyToDevice(input, device);
1008+
torch::Tensor xla_output =
1009+
torch::as_strided(xla_input, /*size=*/size, /*stride=*/stride);
1010+
AllClose(output, xla_output);
1011+
});
1012+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
1013+
ExpectCounterNotChanged("xla::take", cpp_test::GetIgnoredCounters());
1014+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
1015+
}
1016+
1017+
TEST_F(AtenXlaTensorTest, TestAsStridedUseSliceSizeReduce) {
1018+
torch::Tensor input =
1019+
torch::rand({16, 32, 24}, torch::TensorOptions(torch::kFloat));
1020+
std::vector<int64_t> size = {16, 8, 24};
1021+
std::vector<int64_t> stride = {768, 24, 1};
1022+
torch::Tensor output =
1023+
torch::as_strided(input, /*size=*/size, /*stride=*/stride);
1024+
ForEachDevice([&](const torch::Device& device) {
1025+
torch::Tensor xla_input = CopyToDevice(input, device);
1026+
torch::Tensor xla_output =
1027+
torch::as_strided(xla_input, /*size=*/size, /*stride=*/stride);
1028+
AllClose(output, xla_output);
1029+
});
1030+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
1031+
ExpectCounterNotChanged("xla::take", cpp_test::GetIgnoredCounters());
1032+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
1033+
}
1034+
1035+
TEST_F(AtenXlaTensorTest, TestAsStridedMismatchLastDimUseSlice) {
1036+
torch::Tensor input =
1037+
torch::rand({16, 32, 24}, torch::TensorOptions(torch::kFloat));
1038+
std::vector<int64_t> size = {16, 32};
1039+
std::vector<int64_t> stride = {768, 24};
1040+
torch::Tensor output =
1041+
torch::as_strided(input, /*size=*/size, /*stride=*/stride);
1042+
ForEachDevice([&](const torch::Device& device) {
1043+
torch::Tensor xla_input = CopyToDevice(input, device);
1044+
torch::Tensor xla_output =
1045+
torch::as_strided(xla_input, /*size=*/size, /*stride=*/stride);
1046+
AllClose(output, xla_output);
1047+
});
1048+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
1049+
ExpectCounterNotChanged("xla::take", cpp_test::GetIgnoredCounters());
1050+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
1051+
}
1052+
1053+
TEST_F(AtenXlaTensorTest, TestAsStridedMismatchMiddleDimUseSlice) {
1054+
torch::lazy::MetricsArena::Get()->ResetMetrics();
1055+
runtime::metrics::ClearMetrics();
1056+
torch::Tensor input =
1057+
torch::rand({6, 4, 2, 4}, torch::TensorOptions(torch::kFloat));
1058+
std::vector<int64_t> size = {6, 2, 4};
1059+
std::vector<int64_t> stride = {32, 4, 1};
1060+
torch::Tensor output =
1061+
torch::as_strided(input, /*size=*/size, /*stride=*/stride);
1062+
ForEachDevice([&](const torch::Device& device) {
1063+
torch::Tensor xla_input = CopyToDevice(input, device);
1064+
torch::Tensor xla_output =
1065+
torch::as_strided(xla_input, /*size=*/size, /*stride=*/stride);
1066+
AllClose(output, xla_output);
1067+
});
1068+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
1069+
ExpectCounterNotChanged("xla::take", cpp_test::GetIgnoredCounters());
1070+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
1071+
}
1072+
1073+
TEST_F(AtenXlaTensorTest, TestAsStridedMismatchDimWithOffset) {
1074+
torch::lazy::MetricsArena::Get()->ResetMetrics();
1075+
runtime::metrics::ClearMetrics();
1076+
torch::Tensor input =
1077+
torch::rand({6, 4, 2, 4}, torch::TensorOptions(torch::kFloat));
1078+
std::vector<int64_t> size = {6, 2, 4};
1079+
std::vector<int64_t> stride = {32, 4, 1};
1080+
torch::Tensor output =
1081+
torch::as_strided(input, /*size=*/size, /*stride=*/stride, 1);
1082+
ForEachDevice([&](const torch::Device& device) {
1083+
torch::Tensor xla_input = CopyToDevice(input, device);
1084+
torch::Tensor xla_output =
1085+
torch::as_strided(xla_input, /*size=*/size, /*stride=*/stride, 1);
1086+
AllClose(output, xla_output);
1087+
});
1088+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
1089+
ExpectCounterChanged("xla::take", cpp_test::GetIgnoredCounters());
1090+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
1091+
}
1092+
1093+
TEST_F(AtenXlaTensorTest, TestAsStridedMultipleMismatchDimWithOffset) {
1094+
torch::lazy::MetricsArena::Get()->ResetMetrics();
1095+
runtime::metrics::ClearMetrics();
1096+
torch::Tensor input =
1097+
torch::rand({6, 4, 2, 4}, torch::TensorOptions(torch::kFloat));
1098+
std::vector<int64_t> size = {3, 2, 4};
1099+
std::vector<int64_t> stride = {16, 4, 1};
1100+
torch::Tensor output =
1101+
torch::as_strided(input, /*size=*/size, /*stride=*/stride);
1102+
ForEachDevice([&](const torch::Device& device) {
1103+
torch::Tensor xla_input = CopyToDevice(input, device);
1104+
torch::Tensor xla_output =
1105+
torch::as_strided(xla_input, /*size=*/size, /*stride=*/stride);
1106+
AllClose(output, xla_output);
1107+
});
1108+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
1109+
ExpectCounterChanged("xla::take", cpp_test::GetIgnoredCounters());
1110+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
1111+
}
1112+
1113+
TEST_F(AtenXlaTensorTest, TestAsStridedMultipleDimMismatch) {
1114+
torch::lazy::MetricsArena::Get()->ResetMetrics();
1115+
runtime::metrics::ClearMetrics();
1116+
torch::Tensor input =
1117+
torch::rand({6, 4, 2, 4}, torch::TensorOptions(torch::kFloat));
1118+
std::vector<int64_t> size = {6, 4, 1, 2};
1119+
std::vector<int64_t> stride = {32, 8, 8, 2};
1120+
torch::Tensor output =
1121+
torch::as_strided(input, /*size=*/size, /*stride=*/stride);
1122+
ForEachDevice([&](const torch::Device& device) {
1123+
torch::Tensor xla_input = CopyToDevice(input, device);
1124+
torch::Tensor xla_output =
1125+
torch::as_strided(xla_input, /*size=*/size, /*stride=*/stride);
1126+
AllClose(output, xla_output);
1127+
});
1128+
ExpectCounterNotChanged("aten::*", cpp_test::GetIgnoredCounters());
1129+
ExpectCounterChanged("xla::take", cpp_test::GetIgnoredCounters());
1130+
ExpectCounterChanged("xla::as_strided_copy", cpp_test::GetIgnoredCounters());
1131+
}
1132+
9871133
TEST_F(AtenXlaTensorTest, TestAvgPool2DBackward) {
9881134
int kernel_size = 2;
9891135
for (int stride = 1; stride <= 2; ++stride) {

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ function run_xla_op_tests2 {
197197
run_test "$CDIR/scan/test_scan.py"
198198
run_test "$CDIR/scan/test_scan_spmd.py"
199199
run_test "$CDIR/scan/test_scan_layers.py"
200+
run_test "$CDIR/test_as_stride_use_slice.py"
200201
run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
201202
run_test "$CDIR/test_autocast.py"
202203
run_test "$CDIR/eager/test_eager.py"

test/scan/test_scan_spmd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor):
115115
y_xla = model(x)
116116

117117
torch_xla.sync()
118-
torch.testing.assert_close(y_cpu, y_xla.cpu(), atol=1e-3, rtol=1e-3)
118+
torch.testing.assert_close(y_cpu, y_xla.cpu(), atol=1e-3, rtol=1e-2)
119119

120120
def check_dots_in_model(self, model, x, expect_pattern):
121121
# Trace the model to get the HLO.

0 commit comments

Comments (0)