@@ -271,7 +271,7 @@ TFLiteImporter::DispatchMap TFLiteImporter::buildDispatchMap()
271
271
dispatch[" DEPTHWISE_CONV_2D" ] = &TFLiteImporter::parseDWConvolution;
272
272
dispatch[" ADD" ] = dispatch[" MUL" ] = &TFLiteImporter::parseEltwise;
273
273
dispatch[" RELU" ] = dispatch[" PRELU" ] = dispatch[" HARD_SWISH" ] =
274
- dispatch[" LOGISTIC" ] = &TFLiteImporter::parseActivation;
274
+ dispatch[" LOGISTIC" ] = dispatch[ " LEAKY_RELU " ] = &TFLiteImporter::parseActivation;
275
275
dispatch[" MAX_POOL_2D" ] = dispatch[" AVERAGE_POOL_2D" ] = &TFLiteImporter::parsePooling;
276
276
dispatch[" MaxPoolingWithArgmax2D" ] = &TFLiteImporter::parsePoolingWithArgmax;
277
277
dispatch[" MaxUnpooling2D" ] = &TFLiteImporter::parseUnpooling;
@@ -1029,6 +1029,7 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
1029
1029
}
1030
1030
1031
1031
void TFLiteImporter::parseActivation (const Operator& op, const std::string& opcode, LayerParams& activParams, bool isFused) {
1032
+ float slope = 0 .;
1032
1033
if (opcode == " NONE" )
1033
1034
return ;
1034
1035
else if (opcode == " RELU6" )
@@ -1041,6 +1042,13 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
1041
1042
activParams.type = " HardSwish" ;
1042
1043
else if (opcode == " LOGISTIC" )
1043
1044
activParams.type = " Sigmoid" ;
1045
+ else if (opcode == " LEAKY_RELU" )
1046
+ {
1047
+ activParams.type = " ReLU" ;
1048
+ auto options = reinterpret_cast <const LeakyReluOptions*>(op.builtin_options ());
1049
+ slope = options->alpha ();
1050
+ activParams.set (" negative_slope" , slope);
1051
+ }
1044
1052
else
1045
1053
CV_Error (Error::StsNotImplemented, " Unsupported activation " + opcode);
1046
1054
@@ -1072,6 +1080,8 @@ void TFLiteImporter::parseActivation(const Operator& op, const std::string& opco
1072
1080
y = 1 .0f / (1 .0f + std::exp (-x));
1073
1081
else if (opcode == " HARD_SWISH" )
1074
1082
y = x * max (0 .f , min (1 .f , x / 6 .f + 0 .5f ));
1083
+ else if (opcode == " LEAKY_RELU" )
1084
+ y = x >= 0 .f ? x : slope*x;
1075
1085
else
1076
1086
CV_Error (Error::StsNotImplemented, " Lookup table for " + opcode);
1077
1087
0 commit comments