
Commit 0f35412

Merge pull request opencv#19545 from SamFC10:exp
2 parents (c131c12 + 6111935), commit 0f35412

12 files changed: +253 -3 lines changed


modules/dnn/include/opencv2/dnn/all_layers.hpp

Lines changed: 8 additions & 0 deletions
@@ -499,6 +499,14 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<PowerLayer> create(const LayerParams &params);
     };
 
+    class CV_EXPORTS ExpLayer : public ActivationLayer
+    {
+    public:
+        float base, scale, shift;
+
+        static Ptr<ExpLayer> create(const LayerParams &params);
+    };
+
     /* Layers used in semantic segmentation */
 
     class CV_EXPORTS CropLayer : public Layer
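
The new ExpLayer computes y = base^(scale*x + shift) elementwise. As a quick illustration (not part of the patch), the layer can be constructed directly through its factory method; the parameter names match what ExpLayer::create reads further down in this diff, while the concrete values here are only an example:

    #include <opencv2/dnn.hpp>
    #include <opencv2/dnn/all_layers.hpp>

    int main()
    {
        cv::dnn::LayerParams lp;
        lp.type = "Exp";
        lp.name = "exp1";
        lp.set("base", 2.0f);   // y = 2^(0.5*x + 1.0)
        lp.set("scale", 0.5f);
        lp.set("shift", 1.0f);

        cv::Ptr<cv::dnn::ExpLayer> layer = cv::dnn::ExpLayer::create(lp);
        // The public fields mirror the params:
        // layer->base == 2.0f, layer->scale == 0.5f, layer->shift == 1.0f
        return 0;
    }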

modules/dnn/src/cuda/activations.cu

Lines changed: 7 additions & 0 deletions
@@ -145,6 +145,11 @@ void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale,
     generic_op<T, PowerFunctor<T>>(stream, output, input, {exp, scale, shift});
 }
 
+template <class T>
+void exp(const Stream& stream, Span<T> output, View<T> input, T normScale, T normShift) {
+    generic_op<T, ExpFunctor<T>>(stream, output, input, {normScale, normShift});
+}
+
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
 template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
 template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
@@ -156,6 +161,7 @@ template void elu<__half>(const Stream&, Span<__half>, View<__half>);
 template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
 template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
 template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
+template void exp<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
 #endif
 
 
@@ -169,6 +175,7 @@ template void elu<float>(const Stream&, Span<float>, View<float>);
 template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
 template void bnll<float>(const Stream&, Span<float>, View<float>);
 template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
+template void exp<float>(const Stream&, Span<float>, View<float>, float, float);
 
 template <class T, std::size_t N> static
 void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<T> input, std::size_t inner_size, View<T> slope) {

modules/dnn/src/cuda/functors.hpp

Lines changed: 20 additions & 1 deletion
@@ -228,6 +228,25 @@ struct PowerFunctor {
     T exp, scale, shift;
 };
 
+template <class T>
+struct ExpFunctor {
+    struct Params {
+        CUDA4DNN_HOST_DEVICE Params() : normScale(1), normShift(0) { }
+        CUDA4DNN_HOST_DEVICE Params(T nScale_, T nShift_) : normScale(nScale_), normShift(nShift_) { }
+        T normScale, normShift;
+    };
+
+    CUDA4DNN_DEVICE ExpFunctor() : ExpFunctor(Params{}) { }
+    CUDA4DNN_DEVICE ExpFunctor(const Params& params) : normScale{params.normScale}, normShift{params.normShift} { }
+
+    CUDA4DNN_DEVICE T operator()(T value) {
+        using csl::device::fast_exp;
+        return fast_exp(normShift + normScale * value);
+    }
+
+    T normScale, normShift;
+};
+
 template <class T>
 struct MaxFunctor {
     struct Params {
@@ -297,4 +316,4 @@ struct DivFunctor {
 
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
 
-#endif /* OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP */
+#endif /* OPENCV_DNN_SRC_CUDA_FUNCTORS_HPP */

modules/dnn/src/cuda4dnn/kernels/activations.hpp

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
     template <class T>
     void power(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T exp, T scale, T shift);
 
+    template <class T>
+    void exp(const csl::Stream& stream, csl::Span<T> output, csl::View<T> input, T normScale, T normShift);
+
 }}}} /* namespace cv::dnn::cuda4dnn::kernels */
 
 #endif /* OPENCV_DNN_SRC_CUDA4DNN_KERNELS_ACTIVATIONS_HPP */

modules/dnn/src/cuda4dnn/primitives/activation.hpp

Lines changed: 30 additions & 0 deletions
@@ -341,6 +341,36 @@ namespace cv { namespace dnn { namespace cuda4dnn {
         const T exp, scale, shift;
     };
 
+    template <class T>
+    class ExpOp final : public CUDABackendNode {
+    public:
+        using wrapper_type = GetCUDABackendWrapperType<T>;
+
+        ExpOp(csl::Stream stream_, T nScale_, T nShift_)
+            : stream(std::move(stream_)), normScale{ nScale_ }, normShift{ nShift_ } { }
+
+        void forward(
+            const std::vector<cv::Ptr<BackendWrapper>>& inputs,
+            const std::vector<cv::Ptr<BackendWrapper>>& outputs,
+            csl::Workspace& workspace) override
+        {
+            for (int i = 0; i < inputs.size(); i++)
+            {
+                auto input_wrapper = inputs[i].dynamicCast<wrapper_type>();
+                auto input = input_wrapper->getView();
+
+                auto output_wrapper = outputs[i].dynamicCast<wrapper_type>();
+                auto output = output_wrapper->getSpan();
+
+                kernels::exp<T>(stream, output, input, normScale, normShift);
+            }
+        }
+
+    private:
+        csl::Stream stream;
+        const T normScale, normShift;
+    };
+
 }}} /* namespace cv::dnn::cuda4dnn */
 
 #endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_ACTIVATION_HPP */

modules/dnn/src/init.cpp

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(BNLL, BNLLLayer);
     CV_DNN_REGISTER_LAYER_CLASS(AbsVal, AbsLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Power, PowerLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Exp, ExpLayer);
     CV_DNN_REGISTER_LAYER_CLASS(BatchNorm, BatchNormLayer);
     CV_DNN_REGISTER_LAYER_CLASS(MaxUnpool, MaxUnpoolLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Dropout, BlankLayer);
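
With the registration above, the type string "Exp" should also resolve through the generic layer factory. A minimal sketch, assuming the standard LayerFactory API (not shown in this patch):

    #include <opencv2/dnn.hpp>
    #include <opencv2/dnn/layer.hpp>

    void makeExpLayer()
    {
        cv::dnn::LayerParams lp;   // defaults apply: base = -1 (natural exp), scale = 1, shift = 0
        cv::Ptr<cv::dnn::Layer> layer =
            cv::dnn::LayerFactory::createLayerInstance("Exp", lp);
        CV_Assert(!layer.empty()); // non-empty once CV_DNN_REGISTER_LAYER_CLASS(Exp, ExpLayer) has run
    }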

modules/dnn/src/layers/elementwise_layers.cpp

Lines changed: 133 additions & 0 deletions
@@ -1400,6 +1400,125 @@ struct PowerFunctor : public BaseFunctor
     int64 getFLOPSPerElement() const { return power == 1 ? 2 : 10; }
 };
 
+struct ExpFunctor : public BaseFunctor
+{
+    typedef ExpLayer Layer;
+    float base, scale, shift;
+    float normScale, normShift;
+
+    ExpFunctor(float base_ = -1.f, float scale_ = 1.f, float shift_ = 0.f)
+        : base(base_), scale(scale_), shift(shift_)
+    {
+        CV_Check(base, base == -1.f || base > 0.f, "Unsupported 'base' value");
+    }
+
+    bool supportBackend(int backendId, int targetId)
+    {
+        return backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_CUDA ||
+               backendId == DNN_BACKEND_HALIDE || backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
+    }
+
+    void finalize()
+    {
+        // For base > 0 :
+        // y = base^(scale * input + shift)
+        // ln(y) = ln(base)*(scale * input + shift)
+        // y = exp((ln(base)*scale) * input + (ln(base)*shift))
+        // y = exp(normalized_scale * input + normalized_shift)
+
+        float ln_base = (base == -1.f) ? 1.f : log(base);
+        normScale = scale * ln_base;
+        normShift = shift * ln_base;
+    }
+
+    void apply(const float* srcptr, float* dstptr, int len, size_t planeSize, int cn0, int cn1) const
+    {
+        float a = normScale, b = normShift;
+        for( int cn = cn0; cn < cn1; cn++, srcptr += planeSize, dstptr += planeSize )
+        {
+            for( int i = 0; i < len; i++ )
+            {
+                float x = srcptr[i];
+                dstptr[i] = exp(a*x + b);
+            }
+        }
+    }
+
+#ifdef HAVE_OPENCL
+    bool applyOCL(InputArrayOfArrays inps, OutputArrayOfArrays outs, OutputArrayOfArrays internals)
+    {
+        std::vector<UMat> inputs;
+        std::vector<UMat> outputs;
+
+        inps.getUMatVector(inputs);
+        outs.getUMatVector(outputs);
+        String buildopt = oclGetTMacro(inputs[0]);
+
+        for (size_t i = 0; i < inputs.size(); i++)
+        {
+            UMat& src = inputs[i];
+            UMat& dst = outputs[i];
+
+            ocl::Kernel kernel("ExpForward", ocl::dnn::activations_oclsrc, buildopt);
+            kernel.set(0, (int)src.total());
+            kernel.set(1, ocl::KernelArg::PtrReadOnly(src));
+            kernel.set(2, ocl::KernelArg::PtrWriteOnly(dst));
+            kernel.set(3, (float)normScale);
+            kernel.set(4, (float)normShift);
+
+            size_t gSize = src.total();
+            CV_Assert(kernel.run(1, &gSize, NULL, false));
+        }
+        return true;
+    }
+#endif
+
+#ifdef HAVE_CUDA
+    Ptr<BackendNode> initCUDA(int target, csl::Stream stream)
+    {
+        return make_cuda_node<cuda4dnn::ExpOp>(target, stream, normScale, normShift);
+    }
+#endif
+
+#ifdef HAVE_HALIDE
+    void attachHalide(const Halide::Expr& input, Halide::Func& top)
+    {
+        Halide::Var x("x"), y("y"), c("c"), n("n");
+        top(x, y, c, n) = exp(normScale * input + normShift);
+    }
+#endif // HAVE_HALIDE
+
+#ifdef HAVE_DNN_IE_NN_BUILDER_2019
+    InferenceEngine::Builder::Layer initInfEngineBuilderAPI()
+    {
+        CV_Error(Error::StsNotImplemented, "");
+    }
+#endif // HAVE_DNN_IE_NN_BUILDER_2019
+
+#ifdef HAVE_DNN_NGRAPH
+    std::shared_ptr<ngraph::Node> initNgraphAPI(const std::shared_ptr<ngraph::Node>& node)
+    {
+        auto scale_node = std::make_shared<ngraph::op::Constant>(ngraph::element::f32,
+                                                                 ngraph::Shape{1}, &normScale);
+        auto shift_node = std::make_shared<ngraph::op::Constant>(ngraph::element::f32,
+                                                                 ngraph::Shape{1}, &normShift);
+        auto mul = std::make_shared<ngraph::op::v1::Multiply>(scale_node, node, ngraph::op::AutoBroadcastType::NUMPY);
+        auto scale_shift = std::make_shared<ngraph::op::v1::Add>(mul, shift_node, ngraph::op::AutoBroadcastType::NUMPY);
+        return std::make_shared<ngraph::op::v0::Exp>(scale_shift);
+    }
+#endif // HAVE_DNN_NGRAPH
+
+#ifdef HAVE_VULKAN
+    std::shared_ptr<vkcom::OpBase> initVkCom()
+    {
+        // TODO: add vkcom implementation
+        return std::shared_ptr<vkcom::OpBase>();
+    }
+#endif // HAVE_VULKAN
+
+    int64 getFLOPSPerElement() const { return 3; }
+};
+
 struct ChannelsPReLUFunctor : public BaseFunctor
 {
     typedef ChannelsPReLULayer Layer;
@@ -1634,6 +1753,20 @@ Ptr<PowerLayer> PowerLayer::create(const LayerParams& params)
     return l;
 }
 
+Ptr<ExpLayer> ExpLayer::create(const LayerParams& params)
+{
+    float base = params.get<float>("base", -1.0f);
+    float scale = params.get<float>("scale", 1.0f);
+    float shift = params.get<float>("shift", 0.0f);
+    Ptr<ExpLayer> l(new ElementWiseLayer<ExpFunctor>(ExpFunctor(base, scale, shift)));
+    l->setParamsFrom(params);
+    l->base = base;
+    l->scale = scale;
+    l->shift = shift;
+
+    return l;
+}
+
 Ptr<Layer> ChannelsPReLULayer::create(const LayerParams& params)
 {
     CV_Assert(params.blobs.size() == 1);
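
For reference, the normalization done in ExpFunctor::finalize() is the change-of-base identity

    y = b^{s x + c} = e^{\ln(b)(s x + c)} = e^{(s \ln b)\,x + (c \ln b)} = e^{\mathrm{normScale}\cdot x + \mathrm{normShift}}

so, for example, base = 2, scale = 3, shift = 1 gives normScale = 3 ln 2 ≈ 2.079 and normShift = ln 2 ≈ 0.693, while base = -1 is treated as the natural exponential (ln_base = 1).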

modules/dnn/src/opencl/activations.cl

Lines changed: 11 additions & 0 deletions
@@ -140,3 +140,14 @@ __kernel void ELUForward(const int n, __global const T* in, __global T* out)
         out[index] = (src >= 0.f) ? src : exp(src) - 1;
     }
 }
+
+__kernel void ExpForward(const int n, __global const T* in, __global T* out,
+                         const KERNEL_ARG_DTYPE normScale,
+                         const KERNEL_ARG_DTYPE normShift)
+{
+    int index = get_global_id(0);
+    if (index < n)
+    {
+        out[index] = exp(normShift + normScale * in[index]);
+    }
+}

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 1 addition & 1 deletion
@@ -2425,7 +2425,7 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
             connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, num_inputs);
         }
         else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
-                 type == "Relu" || type == "Elu" ||
+                 type == "Relu" || type == "Elu" || type == "Exp" ||
                  type == "Identity" || type == "Relu6")
        {
             CV_CheckGT(num_inputs, 0, "");

modules/dnn/test/test_halide_layers.cpp

Lines changed: 25 additions & 0 deletions
@@ -632,6 +632,31 @@ INSTANTIATE_TEST_CASE_P(Layer_Test_Halide, Power, Combine(
     dnnBackendsAndTargetsWithHalide()
 ));
 
+typedef TestWithParam<tuple<Vec3f, tuple<Backend, Target> > > Exp;
+TEST_P(Exp, Accuracy)
+{
+    float base = get<0>(GetParam())[0];
+    float scale = get<0>(GetParam())[1];
+    float shift = get<0>(GetParam())[2];
+    Backend backendId = get<0>(get<1>(GetParam()));
+    Target targetId = get<1>(get<1>(GetParam()));
+
+    LayerParams lp;
+    lp.set("base", base);
+    lp.set("scale", scale);
+    lp.set("shift", shift);
+    lp.type = "Exp";
+    lp.name = "testLayer";
+    testInPlaceActivation(lp, backendId, targetId);
+}
+
+INSTANTIATE_TEST_CASE_P(Layer_Test_Halide, Exp, Combine(
+/*base, scale, shift*/ Values(Vec3f(0.9f, -1.0f, 1.1f), Vec3f(0.9f, 1.1f, -1.0f),
+                              Vec3f(-1.0f, 0.9f, 1.1f), Vec3f(-1.0f, 1.1f, 0.9f),
+                              Vec3f(1.1f, 0.9f, -1.0f), Vec3f(1.1f, -1.0f, 0.9f)),
+                       dnnBackendsAndTargetsWithHalide()
+));
+
 TEST_P(Test_Halide_layers, ChannelsPReLU)
 {
     LayerParams lp;
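
For completeness, a rough end-to-end sketch of running the new layer on the default OpenCV backend, along the same lines as the test above; the function name, input shape, and values are illustrative and not from the patch:

    #include <vector>
    #include <opencv2/dnn.hpp>

    void runExp()
    {
        cv::dnn::LayerParams lp;
        lp.type = "Exp";
        lp.name = "exp";
        lp.set("base", -1.0f);   // -1 selects the natural exponential
        lp.set("scale", 1.0f);
        lp.set("shift", 0.0f);

        cv::dnn::Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);

        // 1x1x1x3 input blob holding {0, 1, 2}
        cv::Mat input(std::vector<int>{1, 1, 1, 3}, CV_32F);
        input.ptr<float>()[0] = 0.f;
        input.ptr<float>()[1] = 1.f;
        input.ptr<float>()[2] = 2.f;

        net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
        net.setInput(input);
        cv::Mat out = net.forward();   // expected roughly {1.0, 2.718, 7.389}
    }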

0 commit comments
