@@ -22,159 +22,21 @@ namespace native {
 using Tensor = executorch::aten::Tensor;
 using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;
 
-namespace {
-
-bool check_fast_path_conditions(
-    ET_UNUSED const Tensor& in,
-    TensorOptList indices,
-    size_t* dim) {
-  bool found_index = false;
-  for (const auto i : c10::irange(indices.size())) {
-    if (indices[i].has_value()) {
-      *dim = i;
-      // Fast path only supports a single non-null index tensor
-      if (found_index) {
-        return false;
-      }
-      found_index = true;
-      const Tensor& index = indices[i].value();
-      ScalarType ix_type = index.scalar_type();
-      // Fast path only supports Long or Int index tensors
-      if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
-        return false;
-      }
-      // Fast path only supports a 1-dimensional index tensor
-      if (index.dim() != 1) {
-        return false;
-      }
-    }
-  }
-
-  // Fast path needs at least one non-null index tensor
-  if (!found_index) {
-    return false;
-  }
-
-  return true;
-}
-
-bool check_fast_path_args(
-    const Tensor& in,
-    TensorOptList indices,
-    size_t dim,
-    Tensor& out) {
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
-
-  ET_CHECK_OR_RETURN_FALSE(
-      static_cast<ssize_t>(indices.size()) <= in.dim(),
-      "Indexing too many dimensions");
-
-  const Tensor& index = indices[dim].value();
-
-  bool is_valid_index = true;
-  ET_SWITCH_TWO_TYPES(
-      Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
-        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
-        for (const auto i : c10::irange(index.numel())) {
-          if (index_arr[i] < 0 ||
-              index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
-            ET_LOG(
-                Error,
-                "Index %" PRId64
-                " out of range for tensor with size %zd"
-                " at dimension %zu",
-                static_cast<int64_t>(index_arr[i]),
-                in.size(dim),
-                dim);
-            is_valid_index = false;
-            break;
-          }
-        }
-      });
-
-  ET_CHECK_OR_RETURN_FALSE(
-      is_valid_index,
-      "Some index values are not within bounds of input tensor at indexed dim");
-
-  return true;
-}
-
-Tensor& fast_path(
+Tensor& index_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     TensorOptList indices,
-    size_t dim,
     Tensor& out) {
   (void)ctx;
 
   ET_KERNEL_CHECK(
-      ctx, check_fast_path_args(in, indices, dim, out), InvalidArgument, out);
-
-  const Tensor& index = indices[dim].value();
-  ScalarType index_type = index.scalar_type();
-
-  if (out.dim() == 0) {
-    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
-    return out;
-  }
-
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  if (leading_dims == 0 || trailing_dims == 0) {
-    return out;
-  }
-
-  size_t in_dim_length = in.size(dim);
-  size_t out_dim_length = out.size(dim);
-
-  size_t length_per_step = trailing_dims * in.element_size();
-
-  const char* in_data = in.const_data_ptr<char>();
-  char* out_data = out.mutable_data_ptr<char>();
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "index.Tensor_out";
-
-  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
-    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
-    for (const auto i : c10::irange(leading_dims)) {
-      const char* src = in_data + i * in_dim_length * length_per_step;
-      char* dest = out_data + i * out_dim_length * length_per_step;
-      for (const auto j : c10::irange(out_dim_length)) {
-        const char* copy_src = src + index_arr[j] * length_per_step;
-        char* copy_dest = dest + j * length_per_step;
-        memcpy(copy_dest, copy_src, length_per_step);
-      }
-    }
-  });
-
-  return out;
-}
-
-} // namespace
-
-Tensor& index_Tensor_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& in,
-    TensorOptList indices,
-    Tensor& out) {
-  (void)ctx;
+      ctx, check_index_args(in, indices, out), InvalidArgument, out);
 
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
-  size_t dim = 0;
-  bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
-  if (is_fast_path) {
-    return fast_path(ctx, in, indices, dim, out);
-  }
-
-  ET_KERNEL_CHECK(
-      ctx, check_index_args(in, indices, out), InvalidArgument, out);
-
   ScalarType in_type = in.scalar_type();
   size_t block_count = count_index_blocks(indices);
 
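The removed fast path amounted to a blockwise gather: with a single 1-D Long/Int index tensor at dimension dim, each output slice along dim is one memcpy of the input slice selected by the corresponding index entry. Below is a minimal standalone sketch of that loop structure over plain contiguous arrays; gather_along_dim, its parameters, and the float element type are illustrative choices, not part of the ExecuTorch kernel API.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Gather slices along `dim` of a contiguous tensor, mirroring the removed
// fast path: out[..., j, ...] = in[..., index[j], ...].
std::vector<float> gather_along_dim(
    const std::vector<float>& in,
    const std::vector<size_t>& shape,
    size_t dim,
    const std::vector<int64_t>& index) {
  // leading_dims = product of sizes before `dim`; trailing_dims = after.
  size_t leading_dims = 1;
  for (size_t d = 0; d < dim; ++d) {
    leading_dims *= shape[d];
  }
  size_t trailing_dims = 1;
  for (size_t d = dim + 1; d < shape.size(); ++d) {
    trailing_dims *= shape[d];
  }
  size_t in_dim_length = shape[dim];
  size_t out_dim_length = index.size();
  size_t length_per_step = trailing_dims; // counted in elements here, not bytes

  std::vector<float> out(leading_dims * out_dim_length * trailing_dims);
  for (size_t i = 0; i < leading_dims; ++i) {
    const float* src = in.data() + i * in_dim_length * length_per_step;
    float* dest = out.data() + i * out_dim_length * length_per_step;
    for (size_t j = 0; j < out_dim_length; ++j) {
      // One contiguous copy per index entry, as the kernel's memcpy did.
      std::memcpy(
          dest + j * length_per_step,
          src + static_cast<size_t>(index[j]) * length_per_step,
          length_per_step * sizeof(float));
    }
  }
  return out;
}

int main() {
  // A 2x3 input; index {2, 0} along dim 1 yields a 2x2 output.
  std::vector<float> in = {0, 1, 2, 3, 4, 5};
  std::vector<float> out = gather_along_dim(in, {2, 3}, 1, {2, 0});
  for (float v : out) {
    std::cout << v << ' ';
  }
  std::cout << '\n'; // prints: 2 0 5 3
  return 0;
}

Run against the 2x3 input above, this prints 2 0 5 3, reproducing the leading_dims-by-out_dim_length copy loop that fast_path performed; after this change, index_Tensor_out always takes the general check_index_args-validated indexing path instead.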