Skip to content

Commit 7e603f8

Browse files
authored
Revert D78455813
Differential Revision: D78564475 Pull Request resolved: #12640
1 parent 9f2b8cd commit 7e603f8

File tree

2 files changed

+2
-198
lines changed

2 files changed

+2
-198
lines changed

kernels/portable/cpu/op_index.cpp

Lines changed: 2 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -22,159 +22,21 @@ namespace native {
2222
using Tensor = executorch::aten::Tensor;
2323
using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;
2424

25-
namespace {
26-
27-
bool check_fast_path_conditions(
28-
ET_UNUSED const Tensor& in,
29-
TensorOptList indices,
30-
size_t* dim) {
31-
bool found_index = false;
32-
for (const auto i : c10::irange(indices.size())) {
33-
if (indices[i].has_value()) {
34-
*dim = i;
35-
// Fast path only supports a single non-null index tensor
36-
if (found_index) {
37-
return false;
38-
}
39-
found_index = true;
40-
const Tensor& index = indices[i].value();
41-
ScalarType ix_type = index.scalar_type();
42-
// Fast path only supports only supports Long or Int index tensors
43-
if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
44-
return false;
45-
}
46-
// Fast path only supports a 1-dimensional index tensor
47-
if (index.dim() != 1) {
48-
return false;
49-
}
50-
}
51-
}
52-
53-
// Fast path only supports needs at least one non-null index tensor
54-
if (!found_index) {
55-
return false;
56-
}
57-
58-
return true;
59-
}
60-
61-
bool check_fast_path_args(
62-
const Tensor& in,
63-
TensorOptList indices,
64-
size_t dim,
65-
Tensor& out) {
66-
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
67-
68-
ET_CHECK_OR_RETURN_FALSE(
69-
static_cast<ssize_t>(indices.size()) <= in.dim(),
70-
"Indexing too many dimensions");
71-
72-
const Tensor& index = indices[dim].value();
73-
74-
bool is_valid_index = true;
75-
ET_SWITCH_TWO_TYPES(
76-
Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
77-
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
78-
for (const auto i : c10::irange(index.numel())) {
79-
if (index_arr[i] < 0 ||
80-
index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
81-
ET_LOG(
82-
Error,
83-
"Index %" PRId64
84-
" out of range for tensor with size %zd"
85-
" at dimension %zu",
86-
static_cast<int64_t>(index_arr[i]),
87-
in.size(dim),
88-
dim);
89-
is_valid_index = false;
90-
break;
91-
}
92-
}
93-
});
94-
95-
ET_CHECK_OR_RETURN_FALSE(
96-
is_valid_index,
97-
"Some index values are not within bounds of input tensor at indexed dim");
98-
99-
return true;
100-
}
101-
102-
Tensor& fast_path(
25+
Tensor& index_Tensor_out(
10326
KernelRuntimeContext& ctx,
10427
const Tensor& in,
10528
TensorOptList indices,
106-
size_t dim,
10729
Tensor& out) {
10830
(void)ctx;
10931

11032
ET_KERNEL_CHECK(
111-
ctx, check_fast_path_args(in, indices, dim, out), InvalidArgument, out);
112-
113-
const Tensor& index = indices[dim].value();
114-
ScalarType index_type = index.scalar_type();
115-
116-
if (out.dim() == 0) {
117-
memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
118-
return out;
119-
}
120-
121-
size_t leading_dims = getLeadingDims(in, dim);
122-
size_t trailing_dims = getTrailingDims(in, dim);
123-
124-
if (leading_dims == 0 || trailing_dims == 0) {
125-
return out;
126-
}
127-
128-
size_t in_dim_length = in.size(dim);
129-
size_t out_dim_length = out.size(dim);
130-
131-
size_t length_per_step = trailing_dims * in.element_size();
132-
133-
const char* in_data = in.const_data_ptr<char>();
134-
char* out_data = out.mutable_data_ptr<char>();
135-
136-
// @lint-ignore CLANGTIDY facebook-hte-CArray
137-
static constexpr const char op_name[] = "index.Tensor_out";
138-
139-
ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
140-
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
141-
for (const auto i : c10::irange(leading_dims)) {
142-
const char* src = in_data + i * in_dim_length * length_per_step;
143-
char* dest = out_data + i * out_dim_length * length_per_step;
144-
for (const auto j : c10::irange(out_dim_length)) {
145-
const char* copy_src = src + index_arr[j] * length_per_step;
146-
char* copy_dest = dest + j * length_per_step;
147-
memcpy(copy_dest, copy_src, length_per_step);
148-
}
149-
}
150-
});
151-
152-
return out;
153-
}
154-
155-
} // namespace
156-
157-
Tensor& index_Tensor_out(
158-
KernelRuntimeContext& ctx,
159-
const Tensor& in,
160-
TensorOptList indices,
161-
Tensor& out) {
162-
(void)ctx;
33+
ctx, check_index_args(in, indices, out), InvalidArgument, out);
16334

16435
ET_KERNEL_CHECK(
16536
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
16637

16738
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
16839

169-
size_t dim = 0;
170-
bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
171-
if (is_fast_path) {
172-
return fast_path(ctx, in, indices, dim, out);
173-
}
174-
175-
ET_KERNEL_CHECK(
176-
ctx, check_index_args(in, indices, out), InvalidArgument, out);
177-
17840
ScalarType in_type = in.scalar_type();
17941
size_t block_count = count_index_blocks(indices);
18042

kernels/test/op_index_test.cpp

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -627,61 +627,3 @@ TEST_F(OpIndexTensorOutTest, UpperBoundOutTensor) {
627627
EXPECT_TENSOR_EQ(out, ret);
628628
EXPECT_TENSOR_EQ(ret, expected);
629629
}
630-
631-
TEST_F(OpIndexTensorOutTest, FastPath) {
632-
TensorFactory<ScalarType::Float> tf;
633-
TensorFactory<ScalarType::Long> tfl;
634-
635-
// clang-format off
636-
Tensor x = tf.make(
637-
{2, 3, 4},
638-
{
639-
// [0, :, :]
640-
1., 2., 3., 4., // [0, 0, :]
641-
5., 6., 7., 8., // [0, 1, :]
642-
9., 10., 11., 12., // [0, 2, :]
643-
644-
// [1, :, :]
645-
-1., -2., -3., -4., // [1, 0, :]
646-
-5., -6., -7., -8., // [1, 1, :]
647-
-9., -10., -11., -12., // [1, 2, :]
648-
});
649-
// clang-format on
650-
651-
optional<Tensor> indices[] = {
652-
optional<Tensor>(),
653-
optional<Tensor>(),
654-
optional<Tensor>(tfl.make({3}, {2, 0, 1}))};
655-
656-
Tensor out = tf.zeros({2, 3, 3});
657-
// clang-format off
658-
Tensor expected = tf.make(
659-
{2, 3, 3},
660-
{
661-
3., 1., 2.,
662-
7., 5., 6.,
663-
11., 9., 10.,
664-
665-
-3., -1., -2.,
666-
-7., -5., -6.,
667-
-11., -9., -10.,
668-
});
669-
// clang-format on
670-
671-
op_index_tensor_out(x, indices, out);
672-
673-
EXPECT_TENSOR_EQ(out, expected);
674-
}
675-
676-
TEST_F(OpIndexTensorOutTest, FastPathZeroDim) {
677-
TensorFactory<ScalarType::Float> tf;
678-
TensorFactory<ScalarType::Long> tfl;
679-
680-
Tensor x = tf.ones({0});
681-
optional<Tensor> indices[] = {optional<Tensor>(tfl.zeros({0}))};
682-
Tensor out = tf.zeros({0});
683-
Tensor expected = tf.ones({0});
684-
op_index_tensor_out(x, indices, out);
685-
686-
EXPECT_TENSOR_EQ(out, expected);
687-
}

0 commit comments

Comments
 (0)