Skip to content

Support for DECODE operator #3132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tflite_micro/python_ops_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddConv2D();
AddCos();
AddCumSum();
AddDecode();
AddDelay();
AddDepthToSpace();
AddDepthwiseConv2D();
Expand Down
20 changes: 20 additions & 0 deletions tensorflow/lite/micro/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ tflm_kernel_cc_library(
"conv.cc",
"conv_common.cc",
"cumsum.cc",
"decode.cc",
"decode_state.cc",
"decode_state_lut.cc",
"depth_to_space.cc",
"depthwise_conv.cc",
"depthwise_conv_common.cc",
Expand Down Expand Up @@ -327,6 +330,8 @@ tflm_kernel_cc_library(
"batch_matmul.h",
"circular_buffer.h",
"conv.h",
"decode_state.h",
"decode_state_lut.h",
"depthwise_conv.h",
"dequantize.h",
"ethosu.h",
Expand Down Expand Up @@ -643,6 +648,21 @@ tflm_cc_test(
],
)

# Unit test target for the DECODE operator kernel (decode.cc).
tflm_cc_test(
    name = "decode_test",
    srcs = [
        "decode_test.cc",
    ],
    deps = [
        ":kernel_runner",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/micro:debug_log",
        "//tensorflow/lite/micro:op_resolvers",
        "//tensorflow/lite/micro:test_helpers",
        "//tensorflow/lite/micro/testing:micro_test",
    ],
)

tflm_cc_test(
name = "decompress_test",
srcs = [
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/Makefile.inc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \
Expand Down
148 changes: 148 additions & 0 deletions tensorflow/lite/micro/kernels/decode.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/decode_state.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {
namespace {

// Validates the node's tensor layout and builds one DecodeState per output.
//
// Inputs arrive in (encoded, ancillary) pairs: input 2k is the compressed
// data and input 2k+1 its metadata (DCM); each pair produces output k. The
// DecodeState pointers are stored in persistent arena memory on
// node->user_data for use by Eval().
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const size_t num_inputs = NumInputs(node);
  const size_t num_outputs = NumOutputs(node);
  TF_LITE_ENSURE(context, num_outputs > 0);
  TF_LITE_ENSURE_EQ(context, num_inputs, num_outputs * 2);

  MicroContext* const micro_context = GetMicroContext(context);

  // One DecodeState* slot per output, persisted for the model's lifetime.
  node->user_data = micro_context->AllocatePersistentBuffer(
      num_outputs * sizeof(DecodeState*));
  TF_LITE_ENSURE(context, node->user_data != nullptr);
  DecodeState** const dsp_arr =
      reinterpret_cast<DecodeState**>(node->user_data);

  TfLiteTensor* input = nullptr;
  TfLiteTensor* ancillary = nullptr;
  TfLiteTensor* output = nullptr;
  TfLiteStatus status = kTfLiteOk;

  for (size_t i = 0; i < num_inputs; i += 2) {
    input = micro_context->AllocateTempInputTensor(node, i);
    if (input == nullptr) {
      // Cast size_t args for "%u": MicroPrintf has no %zu guarantee and
      // passing size_t to %u is undefined on LP64 targets.
      MicroPrintf("failed to allocate input tensor %u",
                  static_cast<unsigned int>(i));
      status = kTfLiteError;
      break;
    }
    ancillary = micro_context->AllocateTempInputTensor(node, i + 1);
    if (ancillary == nullptr) {
      MicroPrintf("failed to allocate ancillary tensor %u",
                  static_cast<unsigned int>(i + 1));
      status = kTfLiteError;
      break;
    }
    output = micro_context->AllocateTempOutputTensor(node, i / 2);
    if (output == nullptr) {
      MicroPrintf("failed to allocate output tensor %u",
                  static_cast<unsigned int>(i / 2));
      status = kTfLiteError;
      break;
    }

    // Only DCM version 1 is understood by this kernel.
    if (DecodeState::Version(*ancillary) != 1) {
      MicroPrintf("version %u != 1", DecodeState::Version(*ancillary));
      status = kTfLiteError;
      break;
    }

    DecodeState* dsp = nullptr;
    switch (DecodeState::Type(*ancillary)) {
      case DecodeState::kDcmTypeLUT:
        dsp = DecodeState::CreateDecodeStateLUT(
            context, micro_context->GetAlternateProfiler());
        break;
      case DecodeState::kDcmTypeCustom:
        MicroPrintf("Custom decode type not yet supported");
        break;
      default:
        MicroPrintf("unsupported decode type %u",
                    DecodeState::Type(*ancillary));
        break;
    }

    if (dsp != nullptr) {
      status = dsp->Setup(*input, *ancillary, *output);
      if (status != kTfLiteOk) {
        break;
      }
      dsp_arr[i / 2] = dsp;
    } else {
      MicroPrintf("failed to allocate DecodeState[%u]",
                  static_cast<unsigned int>(i / 2));
      // BUG FIX: previously `status` was left at kTfLiteOk here, so an
      // unsupported or failed decode type made Prepare() report success and
      // left dsp_arr[i / 2] uninitialized for Eval() to dereference.
      status = kTfLiteError;
      break;
    }

    // Success path for this pair: release temps and null the pointers so the
    // post-loop cleanup below does not double-deallocate them.
    micro_context->DeallocateTempTfLiteTensor(input);
    micro_context->DeallocateTempTfLiteTensor(ancillary);
    micro_context->DeallocateTempTfLiteTensor(output);
    input = nullptr;
    ancillary = nullptr;
    output = nullptr;
  }

  // Error-path cleanup: deallocate whichever temp tensors were still live
  // when the loop broke out early.
  if (input != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(input);
  }
  if (ancillary != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(ancillary);
  }
  if (output != nullptr) {
    micro_context->DeallocateTempTfLiteTensor(output);
  }

  return status;
}

// Runs every DecodeState built in Prepare(): decodes input pair
// (2k, 2k + 1) into output k using the state stored on node->user_data.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  DecodeState** const states =
      reinterpret_cast<DecodeState**>(node->user_data);
  const size_t pair_count = NumInputs(node) / 2;

  for (size_t idx = 0; idx < pair_count; ++idx) {
    const TfLiteEvalTensor* encoded =
        tflite::micro::GetEvalInput(context, node, idx * 2);
    TF_LITE_ENSURE(context, encoded != nullptr);
    const TfLiteEvalTensor* metadata =
        tflite::micro::GetEvalInput(context, node, (idx * 2) + 1);
    TF_LITE_ENSURE(context, metadata != nullptr);
    const TfLiteEvalTensor* decoded =
        tflite::micro::GetEvalOutput(context, node, idx);
    TF_LITE_ENSURE(context, decoded != nullptr);

    const TfLiteStatus decode_status =
        states[idx]->Decode(*encoded, *metadata, *decoded);
    TF_LITE_ENSURE(context, decode_status == kTfLiteOk);
  }

  return kTfLiteOk;
}

} // namespace

// Registration entry point for the DECODE operator. No init/free callbacks
// are needed: all per-node state is arena-allocated in Prepare() and is
// never individually freed.
TFLMRegistration Register_DECODE() {
  return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}

} // namespace tflite
36 changes: 36 additions & 0 deletions tensorflow/lite/micro/kernels/decode_state.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/kernels/decode_state.h"

#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
#include "tensorflow/lite/micro/micro_context.h"

namespace tflite {

// Factory for the LUT decode strategy. The object is placement-new
// constructed in persistent arena memory, so it is never deleted; returns
// nullptr when the arena allocation fails.
DecodeState* DecodeState::CreateDecodeStateLUT(
    const TfLiteContext* context, MicroProfilerInterface* profiler) {
  MicroContext* const micro_context = GetMicroContext(context);
  void* const storage =
      micro_context->AllocatePersistentBuffer(sizeof(DecodeStateLUT));
  return (storage == nullptr)
             ? nullptr
             : new (storage) DecodeStateLUT(context, profiler);
}

} // namespace tflite
87 changes: 87 additions & 0 deletions tensorflow/lite/micro/kernels/decode_state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_

#include <cstdint>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_profiler_interface.h"

namespace tflite {

// Abstract base class for per-output state of the DECODE operator.
//
// One instance is created in the kernel's Prepare() for each
// (input, ancillary, output) tensor triple. Instances are placement-new
// constructed in persistent arena memory (see CreateDecodeStateLUT) and are
// never deleted, hence the protected destructor plus
// TF_LITE_REMOVE_VIRTUAL_DELETE below.
struct DecodeState {
  DecodeState() = delete;

  DecodeState(const TfLiteContext* context, MicroProfilerInterface* profiler)
      : context_(context), micro_profiler_(profiler) {}

  // One-time validation/setup with full TfLiteTensor metadata (Prepare time).
  virtual TfLiteStatus Setup(const TfLiteTensor& input,
                             const TfLiteTensor& ancillary,
                             const TfLiteTensor& output) = 0;
  // Decompresses `input` into `output` using `ancillary` metadata (Eval time).
  virtual TfLiteStatus Decode(const TfLiteEvalTensor& input,
                              const TfLiteEvalTensor& ancillary,
                              const TfLiteEvalTensor& output) = 0;

  // Arena-allocates a DecodeStateLUT; returns nullptr on allocation failure.
  static DecodeState* CreateDecodeStateLUT(const TfLiteContext* context,
                                           MicroProfilerInterface* profiler);

  // Decode type tag, read from byte kDcmDecodeTypeOffset of the ancillary
  // tensor's data.
  static uint8_t Type(const TfLiteTensor& ancillary) {
    return GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
  }

  static uint8_t Type(const TfLiteEvalTensor& ancillary) {
    return micro::GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
  }

  // DCM format version, read from byte kDcmVersionOffset of the ancillary
  // tensor's data.
  static uint8_t Version(const TfLiteTensor& ancillary) {
    return GetTensorData<uint8_t>(&ancillary)[kDcmVersionOffset];
  }

  static uint8_t Version(const TfLiteEvalTensor& ancillary) {
    return micro::GetTensorData<uint8_t>(&ancillary)[kDcmVersionOffset];
  }

 protected:
  // Protected virtual destructor: instances live in the persistent arena and
  // must never be deleted through a pointer by client code.
  virtual ~DecodeState() = default;

  // Decode Common Metadata constants
 public:
  // Decode type tags understood by Prepare()'s dispatch.
  static constexpr uint8_t kDcmTypeLUT = 0;
  static constexpr uint8_t kDcmTypeCustom = 127;

  // Fixed size of the DCM header, in bytes.
  static constexpr size_t kDcmSizeInBytes = 16;

 private:
  // Byte offsets of the fields within the DCM header.
  static constexpr size_t kDcmDecodeTypeOffset = 0;
  static constexpr size_t kDcmVersionOffset = 1;

  // DecodeState vars
 protected:
  const TfLiteContext* context_;
  MicroProfilerInterface* micro_profiler_;

 private:
  TF_LITE_REMOVE_VIRTUAL_DELETE
};

} // namespace tflite

#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECODE_STATE_H_
Loading
Loading