Skip to content

Commit b9f8525

Browse files
benizmergify[bot]
authored andcommitted
feat(ml): torch image basic data augmentation
1 parent 566e5fb commit b9f8525

File tree

8 files changed

+278
-7
lines changed

8 files changed

+278
-7
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ if (USE_TORCH)
101101
backends/torch/torchmodule.cc
102102
backends/torch/torchutils.cc
103103
backends/torch/optim/ranger.cc
104+
backends/torch/torchdataaug.cc
104105
)
105106
endif()
106107

src/backends/torch/torchdataaug.cc

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/**
2+
* DeepDetect
3+
* Copyright (c) 2021 Jolibrain
4+
* Authors: Emmanuel Benazera <emmanuel.benazera@jolibrain.com>
5+
*
6+
* This file is part of deepdetect.
7+
*
8+
* deepdetect is free software: you can redistribute it and/or modify
9+
* it under the terms of the GNU Lesser General Public License as published by
10+
* the Free Software Foundation, either version 3 of the License, or
11+
* (at your option) any later version.
12+
*
13+
* deepdetect is distributed in the hope that it will be useful,
14+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
15+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16+
* GNU Lesser General Public License for more details.
17+
*
18+
* You should have received a copy of the GNU Lesser General Public License
19+
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
20+
*/
21+
22+
#include "torchdataaug.h"
23+
24+
namespace dd
25+
{
26+
27+
void TorchImgRandAugCV::augment(cv::Mat &src)
28+
{
29+
// apply augmentation
30+
if (_mirror)
31+
applyMirror(src);
32+
if (_rotate)
33+
applyRotate(src);
34+
35+
// should be last, in this order
36+
if (_cutout > 0.0)
37+
applyCutout(src);
38+
if (_crop_size > 0)
39+
applyCrop(src);
40+
}
41+
42+
void TorchImgRandAugCV::applyMirror(cv::Mat &src)
43+
{
44+
#pragma omp critical
45+
{
46+
if (_bernouilli(_rnd_gen))
47+
{
48+
cv::Mat dst;
49+
cv::flip(src, dst, 1);
50+
src = dst;
51+
}
52+
}
53+
}
54+
55+
void TorchImgRandAugCV::applyRotate(cv::Mat &src)
56+
{
57+
int rot = 0;
58+
#pragma omp critical
59+
{
60+
rot = _uniform_int_rotate(_rnd_gen);
61+
}
62+
if (rot == 0)
63+
return;
64+
else if (rot == 1) // 90
65+
{
66+
cv::Mat dst;
67+
cv::transpose(src, dst);
68+
cv::flip(dst, src, 1);
69+
}
70+
else if (rot == 2) // 180
71+
{
72+
cv::Mat dst;
73+
cv::flip(src, dst, -1);
74+
src = dst;
75+
}
76+
else if (rot == 3) // 270
77+
{
78+
cv::Mat dst;
79+
cv::transpose(src, dst);
80+
cv::flip(dst, src, 0);
81+
}
82+
}
83+
84+
void TorchImgRandAugCV::applyCrop(cv::Mat &src)
85+
{
86+
int crop_x = 0;
87+
int crop_y = 0;
88+
#pragma omp critical
89+
{
90+
crop_x = _uniform_int_crop_x(_rnd_gen);
91+
crop_y = _uniform_int_crop_y(_rnd_gen);
92+
}
93+
cv::Rect crop(crop_x, crop_y, _crop_size, _crop_size);
94+
cv::Mat dst = src(crop).clone();
95+
src = dst;
96+
}
97+
98+
void TorchImgRandAugCV::applyCutout(cv::Mat &src)
99+
{
100+
// Draw random between 0 and 1
101+
float r1 = 0.0;
102+
#pragma omp critical
103+
{
104+
r1 = _uniform_real_1(_rnd_gen);
105+
}
106+
if (r1 > _cutout)
107+
return;
108+
109+
#pragma omp critical
110+
{
111+
// get shape and area to erase
112+
float s = _uniform_real_cutout_s(_rnd_gen) * _img_width
113+
* _img_height; // area
114+
float r = _uniform_real_cutout_r(_rnd_gen); // aspect ratio
115+
116+
int w = std::min(_img_width,
117+
static_cast<int>(std::floor(std::sqrt(s / r))));
118+
int h = std::min(_img_height,
119+
static_cast<int>(std::floor(std::sqrt(s * r))));
120+
std::uniform_int_distribution<int> distx(0, _img_width - w);
121+
std::uniform_int_distribution<int> disty(0, _img_height - h);
122+
int rect_x = distx(_rnd_gen);
123+
int rect_y = disty(_rnd_gen);
124+
125+
// erase
126+
cv::Rect rect(rect_x, rect_y, w, h);
127+
cv::Mat selected_area = src(rect);
128+
cv::randu(selected_area, cv::Scalar(_cutout_vl, _cutout_vl, _cutout_vl),
129+
cv::Scalar(_cutout_vh, _cutout_vh, _cutout_vh)); // TODO: bw
130+
}
131+
}
132+
}

src/backends/torch/torchdataaug.h

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/**
2+
* DeepDetect
3+
* Copyright (c) 2021 Jolibrain
4+
* Authors: Emmanuel Benazera <emmanuel.benazera@jolibrain.com>
5+
*
6+
* This file is part of deepdetect.
7+
*
8+
* deepdetect is free software: you can redistribute it and/or modify
9+
* it under the terms of the GNU Lesser General Public License as published by
10+
* the Free Software Foundation, either version 3 of the License, or
11+
* (at your option) any later version.
12+
*
13+
* deepdetect is distributed in the hope that it will be useful,
14+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
15+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16+
* GNU Lesser General Public License for more details.
17+
*
18+
* You should have received a copy of the GNU Lesser General Public License
19+
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
20+
*/
21+
22+
#ifndef TORCHDATAAUG_H
23+
#define TORCHDATAAUG_H
24+
25+
#include <opencv2/opencv.hpp>
26+
#include <random>
27+
28+
namespace dd
29+
{
30+
class TorchImgRandAugCV
31+
{
32+
public:
33+
TorchImgRandAugCV()
34+
{
35+
}
36+
37+
TorchImgRandAugCV(const int &img_width, const int &img_height,
38+
const bool &mirror, const bool &rotate,
39+
const int &crop_size, const bool &cutout)
40+
: _img_width(img_width), _img_height(img_height), _mirror(mirror),
41+
_rotate(rotate), _crop_size(crop_size), _cutout(cutout),
42+
_uniform_real_1(0.0, 1.0), _bernouilli(0.5),
43+
_uniform_int_rotate(0, 3)
44+
{
45+
if (_crop_size > 0)
46+
{
47+
_uniform_int_crop_x
48+
= std::uniform_int_distribution<int>(0, _img_width - _crop_size);
49+
_uniform_int_crop_y = std::uniform_int_distribution<int>(
50+
0, _img_height - _crop_size);
51+
}
52+
if (_cutout > 0.0)
53+
{
54+
_uniform_real_cutout_s
55+
= std::uniform_real_distribution<float>(_cutout_sl, _cutout_sh);
56+
_uniform_real_cutout_r
57+
= std::uniform_real_distribution<float>(_cutout_rl, _cutout_rh);
58+
}
59+
}
60+
61+
~TorchImgRandAugCV()
62+
{
63+
}
64+
65+
void augment(cv::Mat &src);
66+
67+
protected:
68+
void applyMirror(cv::Mat &src);
69+
void applyRotate(cv::Mat &src);
70+
void applyCrop(cv::Mat &src);
71+
void applyCutout(cv::Mat &src);
72+
73+
private:
74+
int _img_width = 224;
75+
int _img_height = 224;
76+
77+
// augmentation options & parameter
78+
bool _mirror = false;
79+
bool _rotate = false;
80+
int _crop_size = -1;
81+
float _cutout = 0.0;
82+
float _cutout_sl = 0.02; /**< min proportion of erased area wrt image. */
83+
float _cutout_sh = 0.4; /**< max proportion of erased area wrt image. */
84+
float _cutout_rl = 0.3; /**< min aspect ratio of erased area. */
85+
float _cutout_rh = 3.0; /**< max aspect ratio of erased area. */
86+
int _cutout_vl = 0; /**< min erased area pixel value. */
87+
int _cutout_vh = 255; /**< max erased area pixel value. */
88+
89+
// random generators
90+
std::default_random_engine _rnd_gen;
91+
std::uniform_real_distribution<float>
92+
_uniform_real_1; /**< random real uniform between 0 and 1. */
93+
std::bernoulli_distribution _bernouilli;
94+
std::uniform_int_distribution<int> _uniform_int_rotate;
95+
std::uniform_int_distribution<int> _uniform_int_crop_x;
96+
std::uniform_int_distribution<int> _uniform_int_crop_y;
97+
std::uniform_real_distribution<float> _uniform_real_cutout_s;
98+
std::uniform_real_distribution<float> _uniform_real_cutout_r;
99+
};
100+
}
101+
102+
#endif

src/backends/torch/torchdataset.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ namespace dd
137137
_dbData = std::shared_ptr<db::DB>(db::GetDB(_backend));
138138
_dbData->Open(_dbFullName, db::NEW);
139139
_txn = std::shared_ptr<db::Transaction>(_dbData->NewTransaction());
140+
_logger->info("Preparing db of {}x{} images", bgr.cols, bgr.rows);
140141
}
141142

142143
// data & target keys
@@ -296,7 +297,7 @@ namespace dd
296297
std::vector<BatchToStack> data, target;
297298
bool first_iter = true;
298299

299-
if (!_db)
300+
if (!_db) // Note: no data augmentation if no db
300301
{
301302
if (!_lfiles.empty()) // prefetch batch from file list
302303
{
@@ -428,6 +429,9 @@ namespace dd
428429
torch::Tensor targett;
429430
read_image_from_db(datas, targets, bgr, targett, inputc->_bw);
430431

432+
// data augmentation can apply here, with OpenCV
433+
_img_rand_aug_cv.augment(bgr);
434+
431435
torch::Tensor imgt
432436
= image_to_tensor(bgr, inputc->height(), inputc->width());
433437

src/backends/torch/torchdataset.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "backends/torch/db_lmdb.hpp"
3434

3535
#include "inputconnectorstrategy.h"
36+
#include "torchdataaug.h"
3637
#include "torchutils.h"
3738

3839
#include <opencv2/opencv.hpp>
@@ -79,7 +80,8 @@ namespace dd
7980
= nullptr; /**< back ptr to input connector. */
8081
bool _classification = true; /**< whether a classification dataset. */
8182

82-
bool _image = false; /**< whether an image dataset. */
83+
bool _image = false; /**< whether an image dataset. */
84+
TorchImgRandAugCV _img_rand_aug_cv; /**< image data augmentation policy. */
8385

8486
/**
8587
* \brief empty constructor
@@ -98,7 +100,8 @@ namespace dd
98100
_logger(d._logger), _shuffle(d._shuffle), _dbData(d._dbData),
99101
_indices(d._indices), _lfiles(d._lfiles), _batches(d._batches),
100102
_dbFullName(d._dbFullName), _inputc(d._inputc),
101-
_classification(d._classification), _image(d._image)
103+
_classification(d._classification), _image(d._image),
104+
_img_rand_aug_cv(d._img_rand_aug_cv)
102105
{
103106
}
104107

src/backends/torch/torchinputconns.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ namespace dd
228228
if (shouldLoad)
229229
{
230230
if (_db)
231-
_tilogger->info("Load from db");
231+
_tilogger->info("Preparation for training from db");
232232
// Get files paths
233233
try
234234
{

src/backends/torch/torchlib.cc

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,35 @@ namespace dd
485485
throw;
486486
}
487487

488+
// TODO: set inputc dataset data augmentation options
488489
APIData ad_mllib = ad.getobj("parameters").getobj("mllib");
490+
bool has_data_augmentation
491+
= ad_mllib.has("mirror") || ad_mllib.has("rotate")
492+
|| ad_mllib.has("crop_size") || ad_mllib.has("cutout");
493+
if (has_data_augmentation)
494+
{
495+
bool has_mirror
496+
= ad_mllib.has("mirror") && ad_mllib.get("mirror").get<bool>();
497+
this->_logger->info("mirror: {}", has_mirror);
498+
bool has_rotate
499+
= ad_mllib.has("rotate") && ad_mllib.get("rotate").get<bool>();
500+
this->_logger->info("rotate: {}", has_rotate);
501+
int crop_size = -1;
502+
if (ad_mllib.has("crop_size"))
503+
{
504+
crop_size = ad_mllib.get("crop_size").get<int>();
505+
this->_logger->info("crop_size : {}", crop_size);
506+
}
507+
float cutout = 0.0;
508+
if (ad_mllib.has("cutout"))
509+
{
510+
cutout = ad_mllib.get("cutout").get<double>();
511+
this->_logger->info("cutout: {}", cutout);
512+
}
513+
inputc._dataset._img_rand_aug_cv
514+
= TorchImgRandAugCV(inputc.width(), inputc.height(), has_mirror,
515+
has_rotate, crop_size, cutout);
516+
}
489517

490518
// solver params
491519
int64_t iterations = 1;
@@ -610,7 +638,6 @@ namespace dd
610638

611639
for (TorchBatch batch : *dataloader)
612640
{
613-
614641
auto tstart = steady_clock::now();
615642
if (_masked_lm)
616643
{

tests/ut-torchapi.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ TEST(torchapi, service_train_images)
333333
"\"supervised\",\"model\":{\"repository\":\""
334334
+ resnet50_train_repo
335335
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
336-
"\"width\":224,\"height\":224,\"db\":true},\"mllib\":{\"nclasses\":"
336+
"\"width\":256,\"height\":256,\"db\":true},\"mllib\":{\"nclasses\":"
337337
"2,\"finetuning\":true,\"gpu\":true}}}";
338338
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
339339
ASSERT_EQ(created_str, joutstr);
@@ -345,7 +345,9 @@ TEST(torchapi, service_train_images)
345345
+ iterations_resnet50 + ",\"base_lr\":" + torch_lr
346346
+ ",\"iter_size\":4,\"solver_type\":\"ADAM\",\"test_"
347347
"interval\":200},\"net\":{\"batch_size\":4},"
348-
"\"resume\":false},"
348+
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
349+
"\"cutout\":0.5}"
350+
","
349351
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true},"
350352
"\"output\":{\"measure\":[\"f1\",\"acc\"]}},\"data\":[\""
351353
+ resnet50_train_data + "\",\"" + resnet50_test_data + "\"]}";

0 commit comments

Comments
 (0)