Skip to content

Commit 73dc666

Browse files
committed
Merge pull request #1212 from dkurt:torch_softmax_layer
2 parents 6c9d6d5 + 78ff9d9 commit 73dc666

File tree

5 files changed

+35
-0
lines changed

5 files changed

+35
-0
lines changed

modules/dnn/include/opencv2/dnn/all_layers.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ namespace dnn
251251
class CV_EXPORTS SoftmaxLayer : public Layer
252252
{
253253
public:
254+
bool logSoftMax;
255+
254256
static Ptr<SoftmaxLayer> create(const LayerParams& params);
255257
};
256258

modules/dnn/include/opencv2/dnn/dnn.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
436436
* - nn.SpatialMaxPooling, nn.SpatialAveragePooling
437437
* - nn.ReLU, nn.TanH, nn.Sigmoid
438438
* - nn.Reshape
439+
* - nn.SoftMax, nn.LogSoftMax
439440
*
440441
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
441442
*/

modules/dnn/src/layers/softmax_layer.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class SoftMaxLayerImpl : public SoftmaxLayer
5757
SoftMaxLayerImpl(const LayerParams& params)
5858
{
5959
axisRaw = params.get<int>("axis", 1);
60+
logSoftMax = params.get<int>("log_softmax", false);
6061
setParamsFrom(params);
6162
}
6263

@@ -143,6 +144,14 @@ class SoftMaxLayerImpl : public SoftmaxLayer
143144
for (size_t i = 0; i < innerSize; i++)
144145
dstPtr[srcOffset + cnDim * cnStep + i] /= bufPtr[bufOffset + i];
145146
}
147+
if (logSoftMax)
148+
{
149+
for (size_t cnDim = 0; cnDim < channels; cnDim++)
150+
{
151+
for (size_t i = 0; i < innerSize; i++)
152+
dstPtr[srcOffset + cnDim * cnStep + i] = log(dstPtr[srcOffset + cnDim * cnStep + i]);
153+
}
154+
}
146155
}
147156
}
148157

modules/dnn/src/torch/torch_importer.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,17 @@ struct TorchImporter : public ::cv::dnn::Importer
741741
layerParams.set("indices_blob_id", tensorParams["indices"].first);
742742
curModule->modules.push_back(newModule);
743743
}
744+
else if (nnName == "SoftMax")
745+
{
746+
newModule->apiType = "SoftMax";
747+
curModule->modules.push_back(newModule);
748+
}
749+
else if (nnName == "LogSoftMax")
750+
{
751+
newModule->apiType = "SoftMax";
752+
layerParams.set("log_softmax", true);
753+
curModule->modules.push_back(newModule);
754+
}
744755
else
745756
{
746757
CV_Error(Error::StsNotImplemented, "Unknown nn class \"" + className + "\"");

modules/dnn/test/test_torch_importer.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,18 @@ TEST(Torch_Importer, net_cadd_table)
159159
runTorchNet("net_cadd_table");
160160
}
161161

162+
TEST(Torch_Importer, net_softmax)
163+
{
164+
runTorchNet("net_softmax");
165+
runTorchNet("net_softmax_spatial");
166+
}
167+
168+
TEST(Torch_Importer, net_logsoftmax)
169+
{
170+
runTorchNet("net_logsoftmax");
171+
runTorchNet("net_logsoftmax_spatial");
172+
}
173+
162174
TEST(Torch_Importer, ENet_accuracy)
163175
{
164176
Net net;

0 commit comments

Comments
 (0)