Skip to content

Commit b85d79e

Browse files
benizmergify[bot]
authored andcommitted
feat: support for batches for NCNN image models
1 parent 4c89a1c commit b85d79e

File tree

3 files changed

+211
-165
lines changed

3 files changed

+211
-165
lines changed

src/backends/ncnn/ncnninputconns.h

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -98,40 +98,48 @@ namespace dd
9898
{
9999
throw;
100100
}
101-
cv::Mat bgr = this->_images.at(0);
102-
_height = bgr.rows;
103-
_width = bgr.cols;
104101

105-
_in = ncnn::Mat::from_pixels(bgr.data, ncnn::Mat::PIXEL_BGR, bgr.cols,
106-
bgr.rows);
107-
if (_has_mean_scalar)
102+
for (size_t i = 0; i < _images.size(); ++i)
108103
{
109-
if (_std.empty())
110-
{
111-
_in.substract_mean_normalize(_mean.data(), 0);
112-
}
113-
else
114-
{
115-
_in.substract_mean_normalize(_mean.data(), _std.data());
116-
}
117-
}
118-
else if (_scale != 1.0)
119-
{
120-
float norm = 1.0 / _scale;
121-
if (_bw)
122-
{
123-
std::vector<float> vscale = { norm };
124-
_in.substract_mean_normalize(0, vscale.data());
125-
}
126-
else
127-
{
128-
std::vector<float> vscale = { norm, norm, norm };
129-
_in.substract_mean_normalize(0, vscale.data());
130-
}
104+
cv::Mat bgr = this->_images.at(i);
105+
_height = bgr.rows;
106+
_width = bgr.cols;
107+
108+
_in.push_back(ncnn::Mat::from_pixels(bgr.data, ncnn::Mat::PIXEL_BGR,
109+
bgr.cols, bgr.rows));
110+
{
111+
if (_has_mean_scalar)
112+
{
113+
if (_std.empty())
114+
{
115+
_in.at(i).substract_mean_normalize(_mean.data(), 0);
116+
}
117+
else
118+
{
119+
_in.at(i).substract_mean_normalize(_mean.data(),
120+
_std.data());
121+
}
122+
}
123+
else if (_scale != 1.0)
124+
{
125+
float norm = 1.0 / _scale;
126+
if (_bw)
127+
{
128+
std::vector<float> vscale = { norm };
129+
_in.at(i).substract_mean_normalize(0, vscale.data());
130+
}
131+
else
132+
{
133+
std::vector<float> vscale = { norm, norm, norm };
134+
_in.at(i).substract_mean_normalize(0, vscale.data());
135+
}
136+
}
137+
_ids.push_back(this->_uris.at(i));
138+
_imgs_size.insert(std::pair<std::string, std::pair<int, int>>(
139+
this->_ids.at(i), this->_images_size.at(i)));
140+
}
141+
_out = std::vector<ncnn::Mat>(_ids.size(), ncnn::Mat());
131142
}
132-
_ids.push_back(this->_uris.at(0));
133-
_imgs_size.insert(std::pair<std::string, std::pair<int, int>>(
134-
this->_ids.at(0), this->_images_size.at(0)));
135143
}
136144

137145
double unscale_res(double res, int nout)
@@ -140,8 +148,8 @@ namespace dd
140148
}
141149

142150
public:
143-
ncnn::Mat _in;
144-
ncnn::Mat _out;
151+
std::vector<ncnn::Mat> _in;
152+
std::vector<ncnn::Mat> _out;
145153
std::vector<std::string> _ids; /**< input ids (e.g. image ids) */
146154
};
147155

@@ -207,7 +215,8 @@ namespace dd
207215
_height += l;
208216
}
209217
// Mat(w,h)
210-
_in.create(_width, _height);
218+
_in.emplace_back(_width, _height);
219+
_out.emplace_back();
211220

212221
int mati = 0;
213222

@@ -222,20 +231,20 @@ namespace dd
222231
for (unsigned int si = 0; si < this->_csvtsdata.size(); ++si)
223232
{
224233
if (_continuation)
225-
_in[mati++] = 1.0;
234+
_in.at(0)[mati++] = 1.0;
226235
else
227-
_in[mati++] = 0.0;
236+
_in.at(0)[mati++] = 0.0;
228237
for (int di = 0; di < _ntargets; ++di)
229-
_in[mati++] = 0.0;
238+
_in.at(0)[mati++] = 0.0;
230239
for (unsigned int di : input_pos)
231-
_in[mati++] = this->_csvtsdata[si][0]._v[di];
240+
_in.at(0)[mati++] = this->_csvtsdata[si][0]._v[di];
232241
for (unsigned int ti = 1; ti < this->_csvtsdata[si].size(); ++ti)
233242
{
234-
_in[mati++] = 1.0;
243+
_in.at(0)[mati++] = 1.0;
235244
for (int di = 0; di < _ntargets; ++di)
236-
_in[mati++] = 0.0;
245+
_in.at(0)[mati++] = 0.0;
237246
for (unsigned int di : input_pos)
238-
_in[mati++] = this->_csvtsdata[si][ti]._v[di];
247+
_in.at(0)[mati++] = this->_csvtsdata[si][ti]._v[di];
239248
}
240249
}
241250
_ids.push_back(this->_uris.at(0));
@@ -255,8 +264,8 @@ namespace dd
255264
}
256265

257266
public:
258-
ncnn::Mat _in;
259-
ncnn::Mat _out;
267+
std::vector<ncnn::Mat> _in;
268+
std::vector<ncnn::Mat> _out;
260269
int _height;
261270
int _width;
262271
std::vector<std::string> _ids; /**< input ids (e.g. image ids) */

0 commit comments

Comments
 (0)