Skip to content

Commit a42d4da

Browse files
Added Spatial Attention Module in Darknet Importer
1 parent 0689c70 commit a42d4da

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

modules/dnn/src/darknet/darknet_io.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,29 @@ namespace cv {
558558
fused_layer_names.push_back(last_layer);
559559
}
560560

561+
void setSAM(int from)
562+
{
563+
cv::dnn::LayerParams eltwise_param;
564+
eltwise_param.name = "SAM-name";
565+
eltwise_param.type = "Eltwise";
566+
567+
eltwise_param.set<std::string>("operation", "prod");
568+
eltwise_param.set<std::string>("output_channels_mode", "same");
569+
570+
darknet::LayerParameter lp;
571+
std::string layer_name = cv::format("sam_%d", layer_id);
572+
lp.layer_name = layer_name;
573+
lp.layer_type = eltwise_param.type;
574+
lp.layerParams = eltwise_param;
575+
lp.bottom_indexes.push_back(last_layer);
576+
lp.bottom_indexes.push_back(fused_layer_names.at(from));
577+
last_layer = layer_name;
578+
net->layers.push_back(lp);
579+
580+
layer_id++;
581+
fused_layer_names.push_back(last_layer);
582+
}
583+
561584
void setUpsample(int scaleFactor)
562585
{
563586
cv::dnn::LayerParams param;
@@ -837,6 +860,14 @@ namespace cv {
837860
from = from < 0 ? from + layers_counter : from;
838861
setParams.setScaleChannels(from);
839862
}
863+
else if (layer_type == "sam")
864+
{
865+
std::string bottom_layer = getParam<std::string>(layer_params, "from", "");
866+
CV_Assert(!bottom_layer.empty());
867+
int from = std::atoi(bottom_layer.c_str());
868+
from = from < 0 ? from + layers_counter : from;
869+
setParams.setSAM(from);
870+
}
840871
else if (layer_type == "upsample")
841872
{
842873
int scaleFactor = getParam<int>(layer_params, "stride", 1);

modules/dnn/test/test_darknet_importer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,11 @@ TEST_P(Test_Darknet_layers, relu)
770770
testDarknetLayer("relu");
771771
}
772772

773+
TEST_P(Test_Darknet_layers, sam)
774+
{
775+
testDarknetLayer("sam", true);
776+
}
777+
773778
INSTANTIATE_TEST_CASE_P(/**/, Test_Darknet_layers, dnnBackendsAndTargets());
774779

775780
}} // namespace

0 commit comments

Comments
 (0)