Skip to content

Commit 41a9c8a

Browse files
Porting Reverse_V2 operator from TFLite (#3123)
* Sync files related to Reverse_V2 from TFLite #3110 * Added Reverse_V2 changes BUG=fixes #3110 * Using stable_sort instead of sort & format fixes * Replace std::stable_sort with qsort * fix format issues * fix format issues * fix format issues * Updated the changes as per the review --------- Co-authored-by: suleshahid <110432064+suleshahid@users.noreply.github.com>
1 parent f50a6ea commit 41a9c8a

File tree

9 files changed

+646
-5
lines changed

9 files changed

+646
-5
lines changed

python/tflite_micro/python_ops_resolver.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -104,6 +104,7 @@ PythonOpsResolver::PythonOpsResolver() {
104104
AddReshape();
105105
AddResizeBilinear();
106106
AddResizeNearestNeighbor();
107+
AddReverseV2();
107108
AddRfft();
108109
AddRound();
109110
AddRsqrt();

tensorflow/lite/micro/kernels/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ tflm_kernel_cc_library(
292292
"reshape_common.cc",
293293
"resize_bilinear.cc",
294294
"resize_nearest_neighbor.cc",
295+
"reverse.cc",
295296
"round.cc",
296297
"select.cc",
297298
"shape.cc",
@@ -1224,6 +1225,20 @@ tflm_cc_test(
12241225
],
12251226
)
12261227

1228+
tflm_cc_test(
1229+
name = "reverse_test",
1230+
srcs = [
1231+
"reverse_test.cc",
1232+
],
1233+
deps = [
1234+
":kernel_runner",
1235+
"//tensorflow/lite/c:common",
1236+
"//tensorflow/lite/micro:op_resolvers",
1237+
"//tensorflow/lite/micro:test_helpers",
1238+
"//tensorflow/lite/micro/testing:micro_test",
1239+
],
1240+
)
1241+
12271242
tflm_cc_test(
12281243
name = "round_test",
12291244
srcs = [

tensorflow/lite/micro/kernels/Makefile.inc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -160,6 +160,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/reduce_test.cc \
160160
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/reshape_test.cc \
161161
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/resize_bilinear_test.cc \
162162
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc \
163+
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/reverse_test.cc \
163164
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/round_test.cc \
164165
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/select_test.cc \
165166
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/shape_test.cc \

tensorflow/lite/micro/kernels/micro_ops.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -105,6 +105,7 @@ TFLMRegistration Register_RELU6();
105105
TFLMRegistration Register_RESHAPE();
106106
TFLMRegistration Register_RESIZE_BILINEAR();
107107
TFLMRegistration Register_RESIZE_NEAREST_NEIGHBOR();
108+
TFLMRegistration Register_REVERSE_V2();
108109
TFLMRegistration Register_ROUND();
109110
TFLMRegistration Register_RSQRT();
110111
TFLMRegistration Register_SELECT_V2();
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/lite/kernels/internal/reference/reverse.h"
16+
17+
#include <stdint.h>
18+
19+
#include <cstdlib>
20+
#include <cstring>
21+
22+
#include "tensorflow/lite/c/common.h"
23+
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24+
#include "tensorflow/lite/kernels/kernel_util.h"
25+
#include "tensorflow/lite/micro/kernels/kernel_util.h"
26+
#include "tensorflow/lite/micro/micro_log.h"
27+
#include "tensorflow/lite/micro/micro_utils.h"
28+
29+
namespace tflite {
30+
namespace {
31+
32+
constexpr int kMaxDimensions = RuntimeShape::kMaxSmallSize;
33+
constexpr int kInputTensor = 0;
34+
constexpr int kAxisTensor = 1;
35+
constexpr int kOutputTensor = 0;
36+
37+
int comp(const void* a, const void* b) {
38+
const int* int_a = static_cast<const int*>(a);
39+
const int* int_b = static_cast<const int*>(b);
40+
41+
return (*int_a - *int_b);
42+
}
43+
44+
TfLiteStatus ReverseV2Prepare(TfLiteContext* context, TfLiteNode* node) {
45+
MicroContext* micro_context = GetMicroContext(context);
46+
47+
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
48+
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
49+
50+
// Ensure inputs and outputs exist.
51+
TfLiteTensor* input =
52+
micro_context->AllocateTempInputTensor(node, kInputTensor);
53+
TF_LITE_ENSURE(context, input != nullptr);
54+
TfLiteTensor* axis =
55+
micro_context->AllocateTempInputTensor(node, kAxisTensor);
56+
TF_LITE_ENSURE(context, axis != nullptr);
57+
TfLiteTensor* output =
58+
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
59+
TF_LITE_ENSURE(context, output != nullptr);
60+
TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
61+
TF_LITE_ENSURE(context, NumDimensions(input) <= kMaxDimensions);
62+
TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
63+
64+
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
65+
input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
66+
input->type != kTfLiteInt16 && input->type != kTfLiteInt64 &&
67+
input->type != kTfLiteBool) {
68+
MicroPrintf("Type '%s' is not supported by reverse.",
69+
TfLiteTypeGetName(input->type));
70+
return kTfLiteError;
71+
}
72+
73+
if (axis->type != kTfLiteInt32) {
74+
MicroPrintf("Axis Type '%s' is not supported by reverse.",
75+
TfLiteTypeGetName(axis->type));
76+
return kTfLiteError;
77+
}
78+
// The value type and output type must match.
79+
TF_LITE_ENSURE_EQ(context, input->type, output->type);
80+
81+
micro_context->DeallocateTempTfLiteTensor(input);
82+
micro_context->DeallocateTempTfLiteTensor(axis);
83+
micro_context->DeallocateTempTfLiteTensor(output);
84+
return kTfLiteOk;
85+
}
86+
87+
TfLiteStatus ReverseV2Eval(TfLiteContext* context, TfLiteNode* node) {
88+
const TfLiteEvalTensor* input =
89+
micro::GetEvalInput(context, node, kInputTensor);
90+
const TfLiteEvalTensor* axis =
91+
micro::GetEvalInput(context, node, kAxisTensor);
92+
TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor);
93+
94+
const int num_axes = static_cast<int>(ElementCount(*axis->dims));
95+
96+
// TFLite reverse implementation is expecting fixed size 8,
97+
// so using 8 below.
98+
std::array<int32_t, 8> axes_data;
99+
std::memcpy(axes_data.data(), axis->data.data, sizeof(int32_t) * num_axes);
100+
const int rank = tflite::micro::GetTensorShape(input).DimensionsCount();
101+
for (int i = 0; i < num_axes; ++i) {
102+
if (axes_data[i] < 0) {
103+
axes_data[i] += rank;
104+
}
105+
TF_LITE_ENSURE(context, axes_data[i] >= 0 && axes_data[i] < rank);
106+
}
107+
std::qsort(axes_data.data(), num_axes, sizeof(int32_t), comp);
108+
109+
bool is_contiguous = true;
110+
for (int i = 1; i < num_axes; ++i) {
111+
if (axes_data[i - 1] + 1 != axes_data[i]) {
112+
is_contiguous = false;
113+
break;
114+
}
115+
}
116+
if (!is_contiguous) {
117+
MicroPrintf("Non-contiguous `axes` not supported");
118+
return kTfLiteError;
119+
}
120+
121+
switch (output->type) {
122+
case kTfLiteFloat32:
123+
reference_ops::Reverse<float>(
124+
axes_data, num_axes, tflite::micro::GetTensorShape(input),
125+
tflite::micro::GetTensorData<float>(input),
126+
tflite::micro::GetTensorData<float>(output));
127+
break;
128+
case kTfLiteInt32:
129+
reference_ops::Reverse<int32_t>(
130+
axes_data, num_axes, tflite::micro::GetTensorShape(input),
131+
tflite::micro::GetTensorData<int32_t>(input),
132+
tflite::micro::GetTensorData<int32_t>(output));
133+
break;
134+
case kTfLiteInt16:
135+
reference_ops::Reverse<int16_t>(
136+
axes_data, num_axes, tflite::micro::GetTensorShape(input),
137+
tflite::micro::GetTensorData<int16_t>(input),
138+
tflite::micro::GetTensorData<int16_t>(output));
139+
break;
140+
case kTfLiteInt8:
141+
case kTfLiteUInt8:
142+
reference_ops::Reverse<uint8_t>(
143+
axes_data, num_axes, tflite::micro::GetTensorShape(input),
144+
tflite::micro::GetTensorData<uint8_t>(input),
145+
tflite::micro::GetTensorData<uint8_t>(output));
146+
break;
147+
case kTfLiteInt64:
148+
reference_ops::Reverse<int64_t>(
149+
axes_data, num_axes, tflite::micro::GetTensorShape(input),
150+
tflite::micro::GetTensorData<int64_t>(input),
151+
tflite::micro::GetTensorData<int64_t>(output));
152+
break;
153+
case kTfLiteBool:
154+
reference_ops::Reverse<bool>(axes_data, num_axes,
155+
tflite::micro::GetTensorShape(input),
156+
tflite::micro::GetTensorData<bool>(input),
157+
tflite::micro::GetTensorData<bool>(output));
158+
break;
159+
default:
160+
MicroPrintf("Output type '%s' (%d) is not supported.",
161+
TfLiteTypeGetName(output->type), output->type);
162+
return kTfLiteError;
163+
}
164+
165+
return kTfLiteOk;
166+
}
167+
168+
} // namespace
169+
170+
TFLMRegistration Register_REVERSE_V2() {
171+
return tflite::micro::RegisterOp(nullptr, ReverseV2Prepare, ReverseV2Eval);
172+
}
173+
174+
} // namespace tflite

0 commit comments

Comments
 (0)