Skip to content

Commit 9755ab1

Browse files
committed
Merge pull request opencv#17556 from nglee:dev_optFlowTVL1Async
2 parents c7dcdc0 + 2043e06 commit 9755ab1

File tree

5 files changed

+199
-54
lines changed

5 files changed

+199
-54
lines changed

modules/core/include/opencv2/core/cuda/common.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,20 @@ namespace cv { namespace cuda
101101
cudaChannelFormatDesc desc = cudaCreateChannelDesc<T>();
102102
cudaSafeCall( cudaBindTexture2D(0, tex, img.ptr(), &desc, img.cols, img.rows, img.step) );
103103
}
104+
105+
template<class T> inline void createTextureObjectPitch2D(cudaTextureObject_t* tex, PtrStepSz<T>& img, const cudaTextureDesc& texDesc)
106+
{
107+
cudaResourceDesc resDesc;
108+
memset(&resDesc, 0, sizeof(resDesc));
109+
resDesc.resType = cudaResourceTypePitch2D;
110+
resDesc.res.pitch2D.devPtr = static_cast<void*>(img.ptr());
111+
resDesc.res.pitch2D.height = img.rows;
112+
resDesc.res.pitch2D.width = img.cols;
113+
resDesc.res.pitch2D.pitchInBytes = img.step;
114+
resDesc.res.pitch2D.desc = cudaCreateChannelDesc<T>();
115+
116+
cudaSafeCall( cudaCreateTextureObject(tex, &resDesc, &texDesc, NULL) );
117+
}
104118
}
105119
}}
106120

modules/cudaimgproc/src/cuda/canny.cu

Lines changed: 26 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -90,53 +90,47 @@ namespace cv { namespace cuda { namespace device
9090

9191
namespace canny
9292
{
93-
texture<uchar, cudaTextureType2D, cudaReadModeElementType> tex_src(false, cudaFilterModePoint, cudaAddressModeClamp);
9493
struct SrcTex
9594
{
95+
virtual ~SrcTex() {}
96+
97+
__host__ SrcTex(int _xoff, int _yoff) : xoff(_xoff), yoff(_yoff) {}
98+
99+
__device__ __forceinline__ virtual int operator ()(int y, int x) const = 0;
100+
96101
int xoff;
97102
int yoff;
98-
__host__ SrcTex(int _xoff, int _yoff) : xoff(_xoff), yoff(_yoff) {}
103+
};
104+
105+
texture<uchar, cudaTextureType2D, cudaReadModeElementType> tex_src(false, cudaFilterModePoint, cudaAddressModeClamp);
106+
struct SrcTexRef : SrcTex
107+
{
108+
__host__ SrcTexRef(int _xoff, int _yoff) : SrcTex(_xoff, _yoff) {}
99109

100-
__device__ __forceinline__ int operator ()(int y, int x) const
110+
__device__ __forceinline__ int operator ()(int y, int x) const override
101111
{
102112
return tex2D(tex_src, x + xoff, y + yoff);
103113
}
104114
};
105115

106-
struct SrcTexObject
116+
struct SrcTexObj : SrcTex
107117
{
108-
int xoff;
109-
int yoff;
110-
cudaTextureObject_t tex_src_object;
111-
__host__ SrcTexObject(int _xoff, int _yoff, cudaTextureObject_t _tex_src_object) : xoff(_xoff), yoff(_yoff), tex_src_object(_tex_src_object) { }
118+
__host__ SrcTexObj(int _xoff, int _yoff, cudaTextureObject_t _tex_src_object) : SrcTex(_xoff, _yoff), tex_src_object(_tex_src_object) { }
112119

113-
__device__ __forceinline__ int operator ()(int y, int x) const
120+
__device__ __forceinline__ int operator ()(int y, int x) const override
114121
{
115122
return tex2D<uchar>(tex_src_object, x + xoff, y + yoff);
116123
}
117124

125+
cudaTextureObject_t tex_src_object;
118126
};
119127

120-
template <class Norm> __global__
121-
void calcMagnitudeKernel(const SrcTex src, PtrStepi dx, PtrStepi dy, PtrStepSzf mag, const Norm norm)
122-
{
123-
const int x = blockIdx.x * blockDim.x + threadIdx.x;
124-
const int y = blockIdx.y * blockDim.y + threadIdx.y;
125-
126-
if (y >= mag.rows || x >= mag.cols)
127-
return;
128-
129-
int dxVal = (src(y - 1, x + 1) + 2 * src(y, x + 1) + src(y + 1, x + 1)) - (src(y - 1, x - 1) + 2 * src(y, x - 1) + src(y + 1, x - 1));
130-
int dyVal = (src(y + 1, x - 1) + 2 * src(y + 1, x) + src(y + 1, x + 1)) - (src(y - 1, x - 1) + 2 * src(y - 1, x) + src(y - 1, x + 1));
131-
132-
dx(y, x) = dxVal;
133-
dy(y, x) = dyVal;
134-
135-
mag(y, x) = norm(dxVal, dyVal);
136-
}
137-
138-
template <class Norm> __global__
139-
void calcMagnitudeKernel(const SrcTexObject src, PtrStepi dx, PtrStepi dy, PtrStepSzf mag, const Norm norm)
128+
template <
129+
class T,
130+
class Norm,
131+
typename = std::enable_if_t<std::is_base_of<SrcTex, T>::value>
132+
>
133+
__global__ void calcMagnitudeKernel(const T src, PtrStepi dx, PtrStepi dy, PtrStepSzf mag, const Norm norm)
140134
{
141135
const int x = blockIdx.x * blockDim.x + threadIdx.x;
142136
const int y = blockIdx.y * blockDim.y + threadIdx.y;
@@ -162,25 +156,16 @@ namespace canny
162156

163157
if (cc30)
164158
{
165-
cudaResourceDesc resDesc;
166-
memset(&resDesc, 0, sizeof(resDesc));
167-
resDesc.resType = cudaResourceTypePitch2D;
168-
resDesc.res.pitch2D.devPtr = srcWhole.ptr();
169-
resDesc.res.pitch2D.height = srcWhole.rows;
170-
resDesc.res.pitch2D.width = srcWhole.cols;
171-
resDesc.res.pitch2D.pitchInBytes = srcWhole.step;
172-
resDesc.res.pitch2D.desc = cudaCreateChannelDesc<uchar>();
173-
174159
cudaTextureDesc texDesc;
175160
memset(&texDesc, 0, sizeof(texDesc));
176161
texDesc.addressMode[0] = cudaAddressModeClamp;
177162
texDesc.addressMode[1] = cudaAddressModeClamp;
178163
texDesc.addressMode[2] = cudaAddressModeClamp;
179164

180165
cudaTextureObject_t tex = 0;
181-
cudaCreateTextureObject(&tex, &resDesc, &texDesc, NULL);
166+
createTextureObjectPitch2D(&tex, srcWhole, texDesc);
182167

183-
SrcTexObject src(xoff, yoff, tex);
168+
SrcTexObj src(xoff, yoff, tex);
184169

185170
if (L2Grad)
186171
{
@@ -205,7 +190,7 @@ namespace canny
205190
else
206191
{
207192
bindTexture(&tex_src, srcWhole);
208-
SrcTex src(xoff, yoff);
193+
SrcTexRef src(xoff, yoff);
209194

210195
if (L2Grad)
211196
{

modules/cudaimgproc/test/test_canny.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class CannyAsyncParallelLoopBody : public cv::ParallelLoopBody
116116
bool useL2gradient;
117117
};
118118

119-
#define NUM_STREAMS 64
119+
#define NUM_STREAMS 128
120120

121121
CUDA_TEST_P(Canny, Async)
122122
{

modules/cudaoptflow/src/cuda/tvl1flow.cu

Lines changed: 96 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "opencv2/core/cuda/common.hpp"
4646
#include "opencv2/core/cuda/border_interpolate.hpp"
4747
#include "opencv2/core/cuda/limits.hpp"
48+
#include "opencv2/core/cuda.hpp"
4849

4950
using namespace cv::cuda;
5051
using namespace cv::cuda::device;
@@ -101,11 +102,64 @@ namespace tvl1flow
101102
}
102103
}
103104

105+
struct SrcTex
106+
{
107+
virtual ~SrcTex() {}
108+
109+
__device__ __forceinline__ virtual float I1(float x, float y) const = 0;
110+
__device__ __forceinline__ virtual float I1x(float x, float y) const = 0;
111+
__device__ __forceinline__ virtual float I1y(float x, float y) const = 0;
112+
};
113+
104114
texture<float, cudaTextureType2D, cudaReadModeElementType> tex_I1 (false, cudaFilterModePoint, cudaAddressModeClamp);
105115
texture<float, cudaTextureType2D, cudaReadModeElementType> tex_I1x(false, cudaFilterModePoint, cudaAddressModeClamp);
106116
texture<float, cudaTextureType2D, cudaReadModeElementType> tex_I1y(false, cudaFilterModePoint, cudaAddressModeClamp);
117+
struct SrcTexRef : SrcTex
118+
{
119+
__device__ __forceinline__ float I1(float x, float y) const override
120+
{
121+
return tex2D(tex_I1, x, y);
122+
}
123+
__device__ __forceinline__ float I1x(float x, float y) const override
124+
{
125+
return tex2D(tex_I1x, x, y);
126+
}
127+
__device__ __forceinline__ float I1y(float x, float y) const override
128+
{
129+
return tex2D(tex_I1y, x, y);
130+
}
131+
};
132+
133+
struct SrcTexObj : SrcTex
134+
{
135+
__host__ SrcTexObj(cudaTextureObject_t tex_obj_I1_, cudaTextureObject_t tex_obj_I1x_, cudaTextureObject_t tex_obj_I1y_)
136+
: tex_obj_I1(tex_obj_I1_), tex_obj_I1x(tex_obj_I1x_), tex_obj_I1y(tex_obj_I1y_) {}
137+
138+
__device__ __forceinline__ float I1(float x, float y) const override
139+
{
140+
return tex2D<float>(tex_obj_I1, x, y);
141+
}
142+
__device__ __forceinline__ float I1x(float x, float y) const override
143+
{
144+
return tex2D<float>(tex_obj_I1x, x, y);
145+
}
146+
__device__ __forceinline__ float I1y(float x, float y) const override
147+
{
148+
return tex2D<float>(tex_obj_I1y, x, y);
149+
}
107150

108-
__global__ void warpBackwardKernel(const PtrStepSzf I0, const PtrStepf u1, const PtrStepf u2, PtrStepf I1w, PtrStepf I1wx, PtrStepf I1wy, PtrStepf grad, PtrStepf rho)
151+
cudaTextureObject_t tex_obj_I1;
152+
cudaTextureObject_t tex_obj_I1x;
153+
cudaTextureObject_t tex_obj_I1y;
154+
};
155+
156+
template <
157+
typename T,
158+
typename = std::enable_if_t<std::is_base_of<SrcTex, T>::value>
159+
>
160+
__global__ void warpBackwardKernel(
161+
const PtrStepSzf I0, const T src, const PtrStepf u1, const PtrStepf u2,
162+
PtrStepf I1w, PtrStepf I1wx, PtrStepf I1wy, PtrStepf grad, PtrStepf rho)
109163
{
110164
const int x = blockIdx.x * blockDim.x + threadIdx.x;
111165
const int y = blockIdx.y * blockDim.y + threadIdx.y;
@@ -136,9 +190,9 @@ namespace tvl1flow
136190
{
137191
const float w = bicubicCoeff(wx - cx) * bicubicCoeff(wy - cy);
138192

139-
sum += w * tex2D(tex_I1 , cx, cy);
140-
sumx += w * tex2D(tex_I1x, cx, cy);
141-
sumy += w * tex2D(tex_I1y, cx, cy);
193+
sum += w * src.I1(cx, cy);
194+
sumx += w * src.I1x(cx, cy);
195+
sumy += w * src.I1y(cx, cy);
142196

143197
wsum += w;
144198
}
@@ -173,15 +227,46 @@ namespace tvl1flow
173227
const dim3 block(32, 8);
174228
const dim3 grid(divUp(I0.cols, block.x), divUp(I0.rows, block.y));
175229

176-
bindTexture(&tex_I1 , I1);
177-
bindTexture(&tex_I1x, I1x);
178-
bindTexture(&tex_I1y, I1y);
230+
bool cc30 = deviceSupports(FEATURE_SET_COMPUTE_30);
179231

180-
warpBackwardKernel<<<grid, block, 0, stream>>>(I0, u1, u2, I1w, I1wx, I1wy, grad, rho);
181-
cudaSafeCall( cudaGetLastError() );
232+
if (cc30)
233+
{
234+
cudaTextureDesc texDesc;
235+
memset(&texDesc, 0, sizeof(texDesc));
236+
texDesc.addressMode[0] = cudaAddressModeClamp;
237+
texDesc.addressMode[1] = cudaAddressModeClamp;
238+
texDesc.addressMode[2] = cudaAddressModeClamp;
182239

183-
if (!stream)
184-
cudaSafeCall( cudaDeviceSynchronize() );
240+
cudaTextureObject_t texObj_I1 = 0, texObj_I1x = 0, texObj_I1y = 0;
241+
242+
createTextureObjectPitch2D(&texObj_I1, I1, texDesc);
243+
createTextureObjectPitch2D(&texObj_I1x, I1x, texDesc);
244+
createTextureObjectPitch2D(&texObj_I1y, I1y, texDesc);
245+
246+
warpBackwardKernel << <grid, block, 0, stream >> > (I0, SrcTexObj(texObj_I1, texObj_I1x, texObj_I1y), u1, u2, I1w, I1wx, I1wy, grad, rho);
247+
cudaSafeCall(cudaGetLastError());
248+
249+
if (!stream)
250+
cudaSafeCall(cudaDeviceSynchronize());
251+
else
252+
cudaSafeCall(cudaStreamSynchronize(stream));
253+
254+
cudaSafeCall(cudaDestroyTextureObject(texObj_I1));
255+
cudaSafeCall(cudaDestroyTextureObject(texObj_I1x));
256+
cudaSafeCall(cudaDestroyTextureObject(texObj_I1y));
257+
}
258+
else
259+
{
260+
bindTexture(&tex_I1, I1);
261+
bindTexture(&tex_I1x, I1x);
262+
bindTexture(&tex_I1y, I1y);
263+
264+
warpBackwardKernel << <grid, block, 0, stream >> > (I0, SrcTexRef(), u1, u2, I1w, I1wx, I1wy, grad, rho);
265+
cudaSafeCall(cudaGetLastError());
266+
267+
if (!stream)
268+
cudaSafeCall(cudaDeviceSynchronize());
269+
}
185270
}
186271
}
187272

modules/cudaoptflow/test/test_optflow.cpp

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,10 +405,71 @@ CUDA_TEST_P(OpticalFlowDual_TVL1, Accuracy)
405405
EXPECT_MAT_SIMILAR(flow, d_flow, 4e-3);
406406
}
407407

408+
class TVL1AsyncParallelLoopBody : public cv::ParallelLoopBody
409+
{
410+
public:
411+
TVL1AsyncParallelLoopBody(const cv::cuda::GpuMat& d_img1_, const cv::cuda::GpuMat& d_img2_, cv::cuda::GpuMat* d_flow_, int iterations_, double gamma_)
412+
: d_img1(d_img1_), d_img2(d_img2_), d_flow(d_flow_), iterations(iterations_), gamma(gamma_) {}
413+
~TVL1AsyncParallelLoopBody() {}
414+
void operator()(const cv::Range& r) const
415+
{
416+
for (int i = r.start; i < r.end; i++) {
417+
cv::cuda::Stream stream;
418+
cv::Ptr<cv::cuda::OpticalFlowDual_TVL1> d_alg = cv::cuda::OpticalFlowDual_TVL1::create();
419+
d_alg->setNumIterations(iterations);
420+
d_alg->setGamma(gamma);
421+
d_alg->calc(d_img1, d_img2, d_flow[i], stream);
422+
stream.waitForCompletion();
423+
}
424+
}
425+
protected:
426+
const cv::cuda::GpuMat& d_img1;
427+
const cv::cuda::GpuMat& d_img2;
428+
cv::cuda::GpuMat* d_flow;
429+
int iterations;
430+
double gamma;
431+
};
432+
433+
#define NUM_STREAMS 16
434+
435+
CUDA_TEST_P(OpticalFlowDual_TVL1, Async)
436+
{
437+
if (!supportFeature(devInfo, cv::cuda::FEATURE_SET_COMPUTE_30))
438+
{
439+
throw SkipTestException("CUDA device doesn't support texture objects");
440+
}
441+
else
442+
{
443+
cv::Mat frame0 = readImage("opticalflow/rubberwhale1.png", cv::IMREAD_GRAYSCALE);
444+
ASSERT_FALSE(frame0.empty());
445+
446+
cv::Mat frame1 = readImage("opticalflow/rubberwhale2.png", cv::IMREAD_GRAYSCALE);
447+
ASSERT_FALSE(frame1.empty());
448+
449+
const int iterations = 10;
450+
451+
// Synchronous call
452+
cv::Ptr<cv::cuda::OpticalFlowDual_TVL1> d_alg =
453+
cv::cuda::OpticalFlowDual_TVL1::create();
454+
d_alg->setNumIterations(iterations);
455+
d_alg->setGamma(gamma);
456+
457+
cv::cuda::GpuMat d_flow_gold;
458+
d_alg->calc(loadMat(frame0), loadMat(frame1), d_flow_gold);
459+
460+
// Asynchronous call
461+
cv::cuda::GpuMat d_flow[NUM_STREAMS];
462+
cv::parallel_for_(cv::Range(0, NUM_STREAMS), TVL1AsyncParallelLoopBody(loadMat(frame0), loadMat(frame1), d_flow, iterations, gamma));
463+
464+
// Compare the results of synchronous call and asynchronous call
465+
for (int i = 0; i < NUM_STREAMS; i++)
466+
EXPECT_MAT_NEAR(d_flow_gold, d_flow[i], 0.0);
467+
}
468+
}
469+
408470
INSTANTIATE_TEST_CASE_P(CUDA_OptFlow, OpticalFlowDual_TVL1, testing::Combine(
409471
ALL_DEVICES,
410472
testing::Values(Gamma(0.0), Gamma(1.0))));
411473

412-
413474
}} // namespace
414475
#endif // HAVE_CUDA

0 commit comments

Comments
 (0)