Add Otsu's method to cv::cuda::threshold #3943


Open · wants to merge 3 commits into base: 4.x

Conversation

@troelsy (Contributor) commented May 22, 2025

I implemented Otsu's method in CUDA for a separate project and want to add it to cv::cuda::threshold.

I have made an effort to use existing OpenCV functions in my code, but I had some trouble with ThresholdTypes and cv::cuda::calcHist. I couldn't figure out how to include precomp.hpp to get the definition of ThresholdTypes. For cv::cuda::calcHist, I tried adding opencv_cudaimgproc, but it creates a circular dependency on cudaarithm. I have included a simple implementation of calcHist so the code runs, but I would like input on how to use cv::cuda::calcHist instead.
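
For reference, such a stand-in typically follows the standard shared-memory histogram pattern sketched below (illustrative only: the kernel name, signature and the assumption of a 256-thread block are placeholders, not the exact code in this PR):

#include <opencv2/core/cuda_types.hpp>

// Sketch of a 256-bin histogram for a CV_8UC1 GpuMat; assumes a 256-thread block
// (e.g. 16x16) so every thread owns exactly one bin of the per-block histogram.
__global__ void histogram_sketch(const cv::cuda::PtrStepSzb src, unsigned int* histogram)
{
    __shared__ unsigned int local_hist[256];

    const int tid = threadIdx.y * blockDim.x + threadIdx.x;
    local_hist[tid] = 0;
    __syncthreads();

    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x < src.cols && y < src.rows)
        atomicAdd(&local_hist[src(y, x)], 1u);
    __syncthreads();

    // Fold the per-block histogram into the global 256-bin histogram.
    atomicAdd(&histogram[tid], local_hist[tid]);
}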

Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

  • I agree to contribute to the project under Apache 2 License.
  • To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
  • The PR is proposed to the proper branch
  • There is a reference to the original bug report and related work
  • There is accuracy test, performance test and test data in opencv_extra repository, if applicable
    Patch to opencv_extra has the same branch name.
  • The feature is well documented and sample code can be built with the project CMake

@asmorkalov (Contributor)

cc @cudawarped

@cudawarped (Contributor) commented May 27, 2025

@troelsy Your test currently doesn't use the Otsu method, and when it does, it fails the assert. I can't review the PR until this is fixed.

How does the performance of this compare to the CPU version?

@troelsy (Contributor, Author) commented May 27, 2025

I'm sorry, I didn't realize I hadn't run the test correctly. It should be resolved now; please have another look.

There was a numerical error when computing the variance, which could make the result off by one. Both implementations should get the same result now.

I'm testing on an AWS g5.xlarge instance with a 1920x1080 GpuMat of CV_8UC1 and get the following timings (average of 10,000 iterations):

Otsu threshold CPU:   2371.643000µs
Otsu threshold GPU:   68.385000µs
Binary threshold GPU: 5.330000µs

So the current GPU version is about 35× faster than the CPU version on this instance type (which is biased towards the GPU). The Otsu computation adds about 63µs to the execution compared to a static binary threshold.

From my measurements, about half of the time is spent on allocating and deallocating memory, so if there is a way to reuse the allocations we could reduce the execution time significantly.

These are the individual kernel execution timings reported by Nsight Compute:

Kernel name            Minimum    Maximum    Average
histogram_kernel       4.86µs     4.90µs     4.88µs
otsu_mean              4.67µs     4.83µs     4.75µs
otsu_variance          5.63µs     5.66µs     5.65µs
otsu_score             4.26µs     4.29µs     4.27µs
cv::transformSmart     3.23µs     3.23µs     3.23µs

See also the screenshot of Nsight Systems for a visualization of the execution timeline. Note that it says ~80µs, but that is due to profiling overhead.
(Screenshot: Nsight Systems execution timeline)

@asmorkalov self-requested a review May 27, 2025 18:32
@cudawarped (Contributor)

> See also the screenshot of Nsight Systems for a visualization of the execution timeline. Note that it says ~80µs, but that is due to profiling overhead.

You're measuring the API calls there, not the execution time of the kernels/memory allocations etc. As you are including synchronization the two should be about the same, but in general they are not the same thing. If there were no synchronization, the time for the API calls would be much, much smaller than the execution time.

> From my measurements, about half of the time is spent on allocating and deallocating memory, so if there is a way to reuse the allocations we could reduce the execution time significantly.

You should be able to reduce the allocations/deallocations by using BufferPool. The other main bottleneck is the copy of the threshold from device to host, but because the NPP routines take the threshold on the host, there isn't anything you can do about this.
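
For reference, a minimal sketch of the BufferPool route (the buffer sizes and names are placeholders mirroring the temporaries in this PR, not a tested change):

#include <opencv2/core/cuda.hpp>

void otsu_buffers_sketch()
{
    using namespace cv::cuda;

    // Must be called before the first Stream is created to enable the per-stream pool.
    setBufferPoolUsage(true);

    Stream stream;
    BufferPool pool(stream);

    // Stack-like allocations served from the stream's memory pool instead of
    // hitting cudaMalloc/cudaFree on every call (sizes mirror the temporaries above).
    GpuMat gpu_histogram      = pool.getBuffer(1, 256, CV_32SC1);
    GpuMat gpu_otsu_threshold = pool.getBuffer(1, 1, CV_32SC1);
}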

{
extern __shared__ unsigned long long shared_memory_u64[];

uint bin_idx = blockIdx.x * blockDim.x + threadIdx.x;
@cudawarped (Contributor) commented May 28, 2025

This is confusing. As you have 256 threads per block and 256 blocks, I would suggest the following launch parameters:

dim3 block_all = dim3(n_bins);
dim3 grid_all = dim3(n_thresholds);
dim3 block_score = dim3(n_thresholds);
dim3 grid_score = dim3(1);

with

uint bin_idx = threadIdx.x;
uint threshold = blockIdx.x;

Also, int is usually more performant than uint.
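
A sketch of how the launches would then look (the kernel body is a dummy that only records which (threshold, bin) pair each thread handles; it is not one of the PR's kernels, and the names are placeholders):

constexpr int n_bins = 256;
constexpr int n_thresholds = 256;

// Dummy kernel: one block per candidate threshold, one thread per histogram bin.
__global__ void per_threshold_sketch(unsigned int* out)
{
    const int bin_idx   = threadIdx.x;
    const int threshold = blockIdx.x;
    out[threshold * n_bins + bin_idx] = threshold * n_bins + bin_idx;
}

void launch_sketch(unsigned int* d_out, cudaStream_t stream)
{
    const dim3 block_all(n_bins);
    const dim3 grid_all(n_thresholds);

    // otsu_mean / otsu_variance style: 256 blocks of 256 threads each.
    per_threshold_sketch<<<grid_all, block_all, 0, stream>>>(d_out);

    // otsu_score would then run as a single 256-thread block: <<<1, n_thresholds>>>.
}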

template <uint n_bins, uint n_thresholds>
void compute_otsu_async(uint *histogram, uint *otsu_threshold, Stream &stream)
{
CV_Assert(n_bins == 256);
Contributor:

I'm not sure there is any reason to template n_bins and n_thresholds because your algorithm only supports 256 for both. I would remove the template and use statically allocated shared memory everywhere.
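
i.e. roughly the following (a sketch only; the kernel name, arguments and body are placeholders, not the PR's actual kernels):

// With n_bins fixed at 256, shared memory can be sized statically and the dynamic
// shared-memory argument dropped from the <<<...>>> launch configuration.
__global__ void otsu_mean_sketch(const unsigned int* histogram, unsigned long long* out)
{
    __shared__ unsigned long long shared_memory_u64[256];   // was: extern __shared__ ...

    const int bin_idx = threadIdx.x;
    shared_memory_u64[bin_idx] = histogram[bin_idx];
    __syncthreads();

    out[bin_idx] = shared_memory_u64[bin_idx];   // placeholder; the real kernel reduces here
}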


GpuMat gpu_threshold_sums = GpuMat(1, n_bins, CV_32SC1);
GpuMat gpu_sums = GpuMat(1, n_bins, CV_64FC1);
GpuMat gpu_variances = GpuMat(1, n_bins, CV_32SC4);
Contributor:

Can this be CV_64FC2?


// TODO: Replace this with cv::cuda::calcHist
template <uint n_bins>
__global__ void histogram_kernel(
Contributor:

@asmorkalov Due to a circular dependency when adding opencv_cudaimgproc to opencv_cudaarithm in CMakeLists.txt, cv::cuda::calcHist cannot be called from here. To avoid refactoring, would it make sense at a minimum to keep only a single implementation of calcHist even if it is duplicated (i.e. copy the implementation from cv::cuda::calcHist here), or do you have a better suggestion?

calcHist(src, gpu_histogram, stream);

GpuMat gpu_otsu_threshold(1, 1, CV_32SC1);
compute_otsu_async<n_bins, n_thresholds>(gpu_histogram.ptr<uint>(), gpu_otsu_threshold.ptr<uint>(), stream);
Contributor:

All kernel launches are async so I would rename this compute_otsu.

signed long long threshold_variance_above_i64 = (signed long long)(threshold_variance_above_f32 * bin_count);
signed long long threshold_variance_below_i64 = (signed long long)(threshold_variance_below_f32 * bin_count);
blockReduce<n_bins>((signed long long *)shared_memory_i64, threshold_variance_above_i64, bin_idx, plus<signed long long>());
__syncthreads();
Contributor:

Is there enough shared memory to have two arrays and avoid __syncthreads(); without affecting the occupancy?

The (signed long long *) cast is also redundant.
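
Something along these lines is what is meant (a sketch under the assumption that the extra 2 KB of shared memory per block does not hurt occupancy; the inputs and names are placeholders, not the PR's kernel):

#include <opencv2/core/cuda/reduce.hpp>
#include <opencv2/core/cuda/functional.hpp>

using namespace cv::cuda::device;

// Separate shared buffers for the two reductions so the __syncthreads() between
// them can be dropped; each blockReduce synchronizes internally on its own buffer.
__global__ void variance_reduce_sketch(const long long* above_in, const long long* below_in,
                                       long long* out)
{
    const int bin_idx = threadIdx.x;

    __shared__ long long smem_above[256];
    __shared__ long long smem_below[256];

    long long above = above_in[bin_idx];   // stands in for threshold_variance_above_i64
    long long below = below_in[bin_idx];   // stands in for threshold_variance_below_i64

    blockReduce<256>(smem_above, above, bin_idx, plus<long long>());
    blockReduce<256>(smem_below, below, bin_idx, plus<long long>());

    if (bin_idx == 0)
    {
        out[0] = above;
        out[1] = below;
    }
}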

{
CV_Assert(depth == CV_8U);
CV_Assert(src.channels() == 1);
CV_Assert(maxVal == 255.0);
Contributor:

Does maxVal have to be 255.0?

cv::Mat dst_gold;
double otsu_cpu = cv::threshold(src, dst_gold, 0, 255, threshOp);

EXPECT_NEAR(otsu_gpu, otsu_cpu, 1e-5);
Contributor:

If the calculated threshold is always an integer which is cast to a double, shouldn't the abs_diff be zero? Alternatively, you could use ASSERT_DOUBLE_EQ instead.
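
For example (a hypothetical test line, reusing the otsu_cpu/otsu_gpu values above):

// ASSERT_DOUBLE_EQ compares to within 4 ULPs, which is effectively exact for
// small integer-valued doubles such as a threshold in [0, 255].
ASSERT_DOUBLE_EQ(otsu_cpu, otsu_gpu);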
