Skip to content

Commit 209802c

Browse files
committed
Leaky RELU support for TFLite.
1 parent 79faf85 commit 209802c

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

modules/dnn/src/tflite/tflite_importer.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
271271
dispatch["DEPTHWISE_CONV_2D"] = &TFLiteImporter::parseDWConvolution;
272272
dispatch["ADD"] = dispatch["MUL"] = &TFLiteImporter::parseEltwise;
273273
dispatch["RELU"] = dispatch["PRELU"] = dispatch["HARD_SWISH"] =
274-
dispatch["LOGISTIC"] = &TFLiteImporter::parseActivation;
274+
dispatch["LOGISTIC"] = dispatch["LEAKY_RELU"] = &TFLiteImporter::parseActivation;
275275
dispatch["MAX_POOL_2D"] = dispatch["AVERAGE_POOL_2D"] = &TFLiteImporter::parsePooling;
276276
dispatch["MaxPoolingWithArgmax2D"] = &TFLiteImporter::parsePoolingWithArgmax;
277277
dispatch["MaxUnpooling2D"] = &TFLiteImporter::parseUnpooling;
@@ -1029,6 +1029,7 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
10291029
}
10301030

10311031
void TFLiteImporter::parseActivation(const Operator& op, const std::string& opcode, LayerParams& activParams, bool isFused) {
1032+
float slope = 0.;
10321033
if (opcode == "NONE")
10331034
return;
10341035
else if (opcode == "RELU6")
@@ -1041,6 +1042,13 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
10411042
activParams.type = "HardSwish";
10421043
else if (opcode == "LOGISTIC")
10431044
activParams.type = "Sigmoid";
1045+
else if (opcode == "LEAKY_RELU")
1046+
{
1047+
activParams.type = "ReLU";
1048+
auto options = reinterpret_cast<const LeakyReluOptions*>(op.builtin_options());
1049+
slope = options->alpha();
1050+
activParams.set("negative_slope", slope);
1051+
}
10441052
else
10451053
CV_Error(Error::StsNotImplemented, "Unsupported activation " + opcode);
10461054

@@ -1072,6 +1080,8 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
10721080
y = 1.0f / (1.0f + std::exp(-x));
10731081
else if (opcode == "HARD_SWISH")
10741082
y = x * max(0.f, min(1.f, x / 6.f + 0.5f));
1083+
else if (opcode == "LEAKY_RELU")
1084+
y = x >= 0.f ? x : slope*x;
10751085
else
10761086
CV_Error(Error::StsNotImplemented, "Lookup table for " + opcode);
10771087

modules/dnn/test/test_tflite_importer.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ TEST_P(Test_TFLite, global_max_pooling_2d) {
268268
testLayer("global_max_pooling_2d");
269269
}
270270

271+
TEST_P(Test_TFLite, leakyRelu) {
272+
testLayer("leakyRelu");
273+
}
274+
271275
INSTANTIATE_TEST_CASE_P(/**/, Test_TFLite, dnnBackendsAndTargets());
272276

273277
}} // namespace

0 commit comments

Comments
 (0)