Skip to content

Commit dda33bf

Browse files
fannymonorialalek
authored andcommitted
Merge pull request #2229 from fannymonori:gsoc_dnn_superres
* Adding dnn based super resolution module. * Fixed whitespace error in unit test * Fixed errors with time measuring functions. * Updated unit tests in dnn superres * Deleted unnecessary indents in dnn superres * Refactored includes in dnn superres * Moved video upsampling functions to sample code in dnn superres. * Replaced couts with CV_Error in dnn superres * Moved benchmarking functionality to sample codes in dnn superres. * Added performance test to dnn superres * Resolve buildbot errors * update dnn_superres - avoid highgui dependency - cleanup public API - use InputArray/OutputArray - test: avoid legacy test API
1 parent 26129cf commit dda33bf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1618
-0
lines changed

modules/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ $ cmake -D OPENCV_EXTRA_MODULES_PATH=<opencv_contrib>/modules -D BUILD_opencv_<r
2424

2525
- **dnn_objdetect**: Object Detection using CNNs -- Implements compact CNN Model for object detection. Trained using Caffe but uses opencv_dnn modeule.
2626

27+
- **dnn_superres**: Superresolution using CNNs -- Contains four trained convolutional neural networks to upscale images.
28+
2729
- **dnns_easily_fooled**: Subvert DNNs -- This code can use the activations in a network to fool the networks into recognizing something else.
2830

2931
- **dpm**: Deformable Part Model -- Felzenszwalb's Cascade with deformable parts object recognition code.

modules/dnn_superres/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(the_description "Super Resolution using CNNs")
2+
3+
ocv_define_module(dnn_superres opencv_core opencv_imgproc opencv_dnn
4+
OPTIONAL opencv_datasets opencv_quality # samples
5+
)

modules/dnn_superres/README.md

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Super Resolution using Convolutional Neural Networks
2+
3+
This module contains several learning-based algorithms for upscaling an image.
4+
5+
## Usage
6+
7+
Run the following command to build this module:
8+
9+
```make
10+
cmake -DOPENCV_EXTRA_MODULES_PATH=<opencv_contrib>/modules -Dopencv_dnn_superres=ON <opencv_source_dir>
11+
```
12+
13+
Refer to the tutorials to understand how to use this module.
14+
15+
## Models
16+
17+
There are four models which are trained.
18+
19+
#### EDSR
20+
21+
Trained models can be downloaded from [here](https://github.com/Saafke/EDSR_Tensorflow/tree/master/models).
22+
23+
- Size of the model: ~38.5MB. This is a quantized version, so that it can be uploaded to GitHub. (Original was 150MB.)
24+
- This model was trained for 3 days with a batch size of 16
25+
- Link to implementation code: https://github.com/Saafke/EDSR_Tensorflow
26+
- x2, x3, x4 trained models available
27+
- Advantage: Highly accurate
28+
- Disadvantage: Slow and large filesize
29+
- Speed: < 3 sec for every scaling factor on 256x256 images on an Intel i7-9700K CPU.
30+
- Original paper: [Enhanced Deep Residual Networks for Single Image Super-Resolution](https://arxiv.org/pdf/1707.02921.pdf) [1]
31+
32+
#### ESPCN
33+
34+
Trained models can be downloaded from [here](https://github.com/fannymonori/TF-ESPCN/tree/master/export).
35+
36+
- Size of the model: ~100kb
37+
- This model was trained for ~100 iterations with a batch size of 32
38+
- Link to implementation code: https://github.com/fannymonori/TF-ESPCN
39+
- x2, x3, x4 trained models available
40+
- Advantage: It is tiny and fast, and still performs well.
41+
- Disadvantage: Perform worse visually than newer, more robust models.
42+
- Speed: < 0.01 sec for every scaling factor on 256x256 images on an Intel i7-9700K CPU.
43+
- Original paper: [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](<https://arxiv.org/abs/1609.05158>) [2]
44+
45+
#### FSRCNN
46+
47+
Trained models can be downloaded from [here](https://github.com/Saafke/FSRCNN_Tensorflow/tree/master/models).
48+
49+
- Size of the model: ~40KB (~9kb for FSRCNN-small)
50+
- This model was trained for ~30 iterations with a batch size of 1
51+
- Link to implementation code: https://github.com/Saafke/FSRCNN_Tensorflow
52+
- Advantage: Fast, small and accurate
53+
- Disadvantage: Not state-of-the-art accuracy
54+
- Speed: < 0.01 sec for every scaling factor on 256x256 images on an Intel i7-9700K CPU.
55+
- Notes: FSRCNN-small has fewer parameters, thus less accurate but faster.
56+
- Original paper: [Accelerating the Super-Resolution Convolutional Neural Network](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html) [3]
57+
58+
#### LapSRN
59+
60+
Trained models can be downloaded from [here](https://github.com/fannymonori/TF-LapSRN/tree/master/export).
61+
62+
- Size of the model: between 1-5Mb
63+
- This model was trained for ~50 iterations with a batch size of 32
64+
- Link to implementation code: https://github.com/fannymonori/TF-LAPSRN
65+
- x2, x4, x8 trained models available
66+
- Advantage: The model can do multi-scale super-resolution with one forward pass. It can now support 2x, 4x, 8x, and [2x, 4x] and [2x, 4x, 8x] super-resolution.
67+
- Disadvantage: It is slower than ESPCN and FSRCNN, and the accuracy is worse than EDSR.
68+
- Speed: < 0.1 sec for every scaling factor on 256x256 images on an Intel i7-9700K CPU.
69+
- Original paper: [Deep laplacian pyramid networks for fast and accurate super-resolution](<https://arxiv.org/abs/1710.01992>) [4]
70+
71+
### Benchmarks
72+
73+
Comparing different algorithms. Scale x4 on monarch.png.
74+
75+
| | Inference time in seconds (CPU)| PSNR | SSIM |
76+
| ------------- |:-------------------:| ---------:|--------:|
77+
| ESPCN |0.01159 | 26.5471 | 0.88116 |
78+
| EDSR |3.26758 |**29.2404** |**0.92112** |
79+
| FSRCNN | 0.01298 | 26.5646 | 0.88064 |
80+
| LapSRN |0.28257 |26.7330 |0.88622 |
81+
| Bicubic |0.00031 |26.0635 |0.87537 |
82+
| Nearest neighbor |**0.00014** |23.5628 |0.81741 |
83+
| Lanczos |0.00101 |25.9115 |0.87057 |
84+
85+
### References
86+
[1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution"**, <i> 2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. </i> [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)]
87+
88+
[2] Shi, W., Caballero, J., Huszár, F., Totz, J., Aitken, A., Bishop, R., Rueckert, D. and Wang, Z., **"Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network"**, <i>Proceedings of the IEEE conference on computer vision and pattern recognition</i> **CVPR 2016**. [[PDF](http://openaccess.thecvf.com/content_cvpr_2016/papers/Shi_Real-Time_Single_Image_CVPR_2016_paper.pdf)] [[arXiv](https://arxiv.org/abs/1609.05158)]
89+
90+
[3] Chao Dong, Chen Change Loy, Xiaoou Tang. **"Accelerating the Super-Resolution Convolutional Neural Network"**, <i> in Proceedings of European Conference on Computer Vision </i>**ECCV 2016**. [[PDF](http://personal.ie.cuhk.edu.hk/~ccloy/files/eccv_2016_accelerating.pdf)]
91+
[[arXiv](https://arxiv.org/abs/1608.00367)] [[Project Page](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html)]
92+
93+
[4] Lai, W. S., Huang, J. B., Ahuja, N., and Yang, M. H., **"Deep laplacian pyramid networks for fast and accurate super-resolution"**, <i> In Proceedings of the IEEE conference on computer vision and pattern recognition </i>**CVPR 2017**. [[PDF](http://openaccess.thecvf.com/content_cvpr_2017/papers/Lai_Deep_Laplacian_Pyramid_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1710.01992)] [[Project Page](http://vllab.ucmerced.edu/wlai24/LapSRN/)]
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html.
4+
5+
#ifndef _OPENCV_DNN_SUPERRES_HPP_
6+
#define _OPENCV_DNN_SUPERRES_HPP_
7+
8+
/** @defgroup dnn_superres DNN used for super resolution
9+
10+
This module contains functionality for upscaling an image via convolutional neural networks.
11+
The following four models are implemented:
12+
13+
- EDSR <https://arxiv.org/abs/1707.02921>
14+
- ESPCN <https://arxiv.org/abs/1609.05158>
15+
- FSRCNN <https://arxiv.org/abs/1608.00367>
16+
- LapSRN <https://arxiv.org/abs/1710.01992>
17+
18+
*/
19+
20+
#include "opencv2/core.hpp"
21+
#include "opencv2/dnn.hpp"
22+
23+
namespace cv
24+
{
25+
namespace dnn_superres
26+
{
27+
28+
//! @addtogroup dnn_superres
29+
//! @{
30+
31+
/** @brief A class to upscale images via convolutional neural networks.
32+
The following four models are implemented:
33+
34+
- edsr
35+
- espcn
36+
- fsrcnn
37+
- lapsrn
38+
*/
39+
40+
class CV_EXPORTS DnnSuperResImpl
41+
{
42+
private:
43+
44+
/** @brief Net which holds the desired neural network
45+
*/
46+
dnn::Net net;
47+
48+
std::string alg; //algorithm
49+
50+
int sc; //scale factor
51+
52+
void preprocess(InputArray inpImg, OutputArray outpImg);
53+
54+
void reconstruct_YCrCb(InputArray inpImg, InputArray origImg, OutputArray outpImg, int scale);
55+
56+
void reconstruct_YCrCb(InputArray inpImg, InputArray origImg, OutputArray outpImg);
57+
58+
void preprocess_YCrCb(InputArray inpImg, OutputArray outpImg);
59+
60+
public:
61+
62+
/** @brief Empty constructor
63+
*/
64+
DnnSuperResImpl();
65+
66+
/** @brief Constructor which immediately sets the desired model
67+
@param algo String containing one of the desired models:
68+
- __edsr__
69+
- __espcn__
70+
- __fsrcnn__
71+
- __lapsrn__
72+
@param scale Integer specifying the upscale factor
73+
*/
74+
DnnSuperResImpl(const std::string& algo, int scale);
75+
76+
/** @brief Read the model from the given path
77+
@param path Path to the model file.
78+
*/
79+
void readModel(const std::string& path);
80+
81+
/** @brief Read the model from the given path
82+
@param weights Path to the model weights file.
83+
@param definition Path to the model definition file.
84+
*/
85+
void readModel(const std::string& weights, const std::string& definition);
86+
87+
/** @brief Set desired model
88+
@param algo String containing one of the desired models:
89+
- __edsr__
90+
- __espcn__
91+
- __fsrcnn__
92+
- __lapsrn__
93+
@param scale Integer specifying the upscale factor
94+
*/
95+
void setModel(const std::string& algo, int scale);
96+
97+
/** @brief Upsample via neural network
98+
@param img Image to upscale
99+
@param result Destination upscaled image
100+
*/
101+
void upsample(InputArray img, OutputArray result);
102+
103+
/** @brief Upsample via neural network of multiple outputs
104+
@param img Image to upscale
105+
@param imgs_new Destination upscaled images
106+
@param scale_factors Scaling factors of the output nodes
107+
@param node_names Names of the output nodes in the neural network
108+
*/
109+
void upsampleMultioutput(InputArray img, std::vector<Mat> &imgs_new, const std::vector<int>& scale_factors, const std::vector<String>& node_names);
110+
111+
/** @brief Returns the scale factor of the model:
112+
@return Current scale factor.
113+
*/
114+
int getScale();
115+
116+
/** @brief Returns the scale factor of the model:
117+
@return Current algorithm.
118+
*/
119+
std::string getAlgorithm();
120+
};
121+
122+
//! @} dnn_superres
123+
124+
}} // cv::dnn_superres::
125+
#endif
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html.
4+
5+
#include "perf_precomp.hpp"
6+
7+
using namespace std;
8+
using namespace cv;
9+
using namespace perf;
10+
11+
namespace opencv_test { namespace {
12+
13+
typedef perf::TestBaseWithParam<tuple<tuple<string,string,int>, string> > dnn_superres;
14+
15+
#define MODEL testing::Values(tuple<string,string,int> {"espcn","ESPCN_x2.pb",2}, \
16+
tuple<string,string,int> {"lapsrn","LapSRN_x4.pb",4})
17+
#define IMAGES testing::Values("cv/dnn_superres/butterfly.png", "cv/shared/baboon.png", "cv/shared/lena.png")
18+
19+
const string TEST_DIR = "cv/dnn_superres";
20+
21+
PERF_TEST_P(dnn_superres, upsample, testing::Combine(MODEL, IMAGES))
22+
{
23+
tuple<string,string,int> model = get<0>( GetParam() );
24+
string image_name = get<1>( GetParam() );
25+
26+
string model_name = get<0>(model);
27+
string model_filename = get<1>(model);
28+
int scale = get<2>(model);
29+
30+
string model_path = getDataPath( TEST_DIR + "/" + model_filename );
31+
string image_path = getDataPath( image_name );
32+
33+
DnnSuperResImpl sr;
34+
sr.readModel(model_path);
35+
sr.setModel(model_name, scale);
36+
37+
Mat img = imread(image_path);
38+
39+
Mat img_new(img.rows * scale, img.cols * scale, CV_8UC3);
40+
41+
declare.in(img, WARMUP_RNG).out(img_new).iterations(10);
42+
43+
TEST_CYCLE() { sr.upsample(img, img_new); }
44+
45+
SANITY_CHECK_NOTHING();
46+
}
47+
48+
}}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html.
4+
5+
#include "perf_precomp.hpp"
6+
7+
CV_PERF_TEST_MAIN( dnn_superres )
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html.
4+
5+
#ifndef __OPENCV_PERF_PRECOMP_HPP__
6+
#define __OPENCV_PERF_PRECOMP_HPP__
7+
8+
#include "opencv2/ts.hpp"
9+
#include "opencv2/dnn_superres.hpp"
10+
11+
namespace opencv_test {
12+
using namespace cv::dnn_superres;
13+
}
14+
15+
#endif
125 KB
Loading

0 commit comments

Comments
 (0)