@@ -320,34 +320,78 @@ void SegmentationModel::segment(InputArray frame, OutputArray mask)
320
320
}
321
321
}
322
322
323
- void disableRegionNMS (Net& net)
323
+ class DetectionModel_Impl : public Model ::Impl
324
324
{
325
- for (String& name : net.getUnconnectedOutLayersNames ())
325
+ public:
326
+ virtual ~DetectionModel_Impl () {}
327
+ DetectionModel_Impl () : Impl() {}
328
+ DetectionModel_Impl (const DetectionModel_Impl&) = delete ;
329
+ DetectionModel_Impl (DetectionModel_Impl&&) = delete ;
330
+
331
+ void disableRegionNMS (Net& net)
326
332
{
327
- int layerId = net.getLayerId (name);
328
- Ptr<RegionLayer> layer = net.getLayer (layerId).dynamicCast <RegionLayer>();
329
- if (!layer.empty ())
333
+ for (String& name : net.getUnconnectedOutLayersNames ())
330
334
{
331
- layer->nmsThreshold = 0 ;
335
+ int layerId = net.getLayerId (name);
336
+ Ptr<RegionLayer> layer = net.getLayer (layerId).dynamicCast <RegionLayer>();
337
+ if (!layer.empty ())
338
+ {
339
+ layer->nmsThreshold = 0 ;
340
+ }
332
341
}
333
342
}
334
- }
343
+
344
+ void setNmsAcrossClasses (bool value) {
345
+ nmsAcrossClasses = value;
346
+ }
347
+
348
+ bool getNmsAcrossClasses () {
349
+ return nmsAcrossClasses;
350
+ }
351
+
352
+ private:
353
+ bool nmsAcrossClasses = false ;
354
+ };
335
355
336
356
DetectionModel::DetectionModel (const String& model, const String& config)
337
- : Model(model, config)
357
+ : DetectionModel(readNet(model, config))
358
+ {
359
+ // nothing
360
+ }
361
+
362
+ DetectionModel::DetectionModel (const Net& network) : Model()
338
363
{
339
- disableRegionNMS (getNetwork_ ()); // FIXIT Move to DetectionModel::Impl::initNet()
364
+ impl = makePtr<DetectionModel_Impl>();
365
+ impl->initNet (network);
366
+ impl.dynamicCast <DetectionModel_Impl>()->disableRegionNMS (getNetwork_ ()); // FIXIT Move to DetectionModel::Impl::initNet()
367
+ }
368
+
369
+ DetectionModel::DetectionModel () : Model()
370
+ {
371
+ // nothing
372
+ }
373
+
374
+ DetectionModel& DetectionModel::setNmsAcrossClasses (bool value)
375
+ {
376
+ CV_Assert (impl != nullptr && impl.dynamicCast <DetectionModel_Impl>() != nullptr ); // remove once default constructor is removed
377
+
378
+ impl.dynamicCast <DetectionModel_Impl>()->setNmsAcrossClasses (value);
379
+ return *this ;
340
380
}
341
381
342
- DetectionModel::DetectionModel ( const Net& network) : Model(network )
382
+ bool DetectionModel::getNmsAcrossClasses ( )
343
383
{
344
- disableRegionNMS (getNetwork_ ()); // FIXIT Move to DetectionModel::Impl::initNet()
384
+ CV_Assert (impl != nullptr && impl.dynamicCast <DetectionModel_Impl>() != nullptr ); // remove once default constructor is removed
385
+
386
+ return impl.dynamicCast <DetectionModel_Impl>()->getNmsAcrossClasses ();
345
387
}
346
388
347
389
void DetectionModel::detect (InputArray frame, CV_OUT std::vector<int >& classIds,
348
390
CV_OUT std::vector<float >& confidences, CV_OUT std::vector<Rect>& boxes,
349
391
float confThreshold, float nmsThreshold)
350
392
{
393
+ CV_Assert (impl != nullptr && impl.dynamicCast <DetectionModel_Impl>() != nullptr ); // remove once default constructor is removed
394
+
351
395
std::vector<Mat> detections;
352
396
impl->processFrame (frame, detections);
353
397
@@ -413,7 +457,7 @@ void DetectionModel::detect(InputArray frame, CV_OUT std::vector<int>& classIds,
413
457
{
414
458
std::vector<int > predClassIds;
415
459
std::vector<Rect> predBoxes;
416
- std::vector<float > predConf ;
460
+ std::vector<float > predConfidences ;
417
461
for (int i = 0 ; i < detections.size (); ++i)
418
462
{
419
463
// Network produces output blob with a shape NxC where N is a number of
@@ -442,45 +486,59 @@ void DetectionModel::detect(InputArray frame, CV_OUT std::vector<int>& classIds,
442
486
height = std::max (1 , std::min (height, frameHeight - top));
443
487
444
488
predClassIds.push_back (classIdPoint.x );
445
- predConf .push_back (static_cast <float >(conf));
489
+ predConfidences .push_back (static_cast <float >(conf));
446
490
predBoxes.emplace_back (left, top, width, height);
447
491
}
448
492
}
449
493
450
494
if (nmsThreshold)
451
495
{
452
- std::map<int , std::vector<size_t > > class2indices;
453
- for (size_t i = 0 ; i < predClassIds.size (); i++)
496
+ if (getNmsAcrossClasses ())
454
497
{
455
- if (predConf[i] >= confThreshold)
498
+ std::vector<int > indices;
499
+ NMSBoxes (predBoxes, predConfidences, confThreshold, nmsThreshold, indices);
500
+ for (int idx : indices)
456
501
{
457
- class2indices[predClassIds[i]].push_back (i);
502
+ boxes.push_back (predBoxes[idx]);
503
+ confidences.push_back (predConfidences[idx]);
504
+ classIds.push_back (predClassIds[idx]);
458
505
}
459
506
}
460
- for ( const auto & it : class2indices)
507
+ else
461
508
{
462
- std::vector<Rect> localBoxes;
463
- std::vector<float > localConfidences;
464
- for (size_t idx : it.second )
509
+ std::map<int , std::vector<size_t > > class2indices;
510
+ for (size_t i = 0 ; i < predClassIds.size (); i++)
465
511
{
466
- localBoxes.push_back (predBoxes[idx]);
467
- localConfidences.push_back (predConf[idx]);
512
+ if (predConfidences[i] >= confThreshold)
513
+ {
514
+ class2indices[predClassIds[i]].push_back (i);
515
+ }
468
516
}
469
- std::vector<int > indices;
470
- NMSBoxes (localBoxes, localConfidences, confThreshold, nmsThreshold, indices);
471
- classIds.resize (classIds.size () + indices.size (), it.first );
472
- for (int idx : indices)
517
+ for (const auto & it : class2indices)
473
518
{
474
- boxes.push_back (localBoxes[idx]);
475
- confidences.push_back (localConfidences[idx]);
519
+ std::vector<Rect> localBoxes;
520
+ std::vector<float > localConfidences;
521
+ for (size_t idx : it.second )
522
+ {
523
+ localBoxes.push_back (predBoxes[idx]);
524
+ localConfidences.push_back (predConfidences[idx]);
525
+ }
526
+ std::vector<int > indices;
527
+ NMSBoxes (localBoxes, localConfidences, confThreshold, nmsThreshold, indices);
528
+ classIds.resize (classIds.size () + indices.size (), it.first );
529
+ for (int idx : indices)
530
+ {
531
+ boxes.push_back (localBoxes[idx]);
532
+ confidences.push_back (localConfidences[idx]);
533
+ }
476
534
}
477
535
}
478
536
}
479
537
else
480
538
{
481
539
boxes = std::move (predBoxes);
482
540
classIds = std::move (predClassIds);
483
- confidences = std::move (predConf );
541
+ confidences = std::move (predConfidences );
484
542
}
485
543
}
486
544
else
0 commit comments