diff --git a/cuda/include/ball_query.h b/cuda/include/ball_query.h
index 99f61a5..0d6b37d 100644
--- a/cuda/include/ball_query.h
+++ b/cuda/include/ball_query.h
@@ -1,11 +1,11 @@
 #pragma once
 #include
 
-std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tensor xyz,
-                                                   const float radius, const int nsample);
+std::pair<torch::Tensor, torch::Tensor> ball_query_dense(torch::Tensor new_xyz, torch::Tensor xyz,
+                                                         const float radius, const int nsample);
 
-std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Tensor y,
-                                                           at::Tensor batch_x, at::Tensor batch_y,
-                                                           const float radius, const int nsample);
+std::pair<torch::Tensor, torch::Tensor>
+ball_query_partial_dense(torch::Tensor x, torch::Tensor y, torch::Tensor batch_x,
+                         torch::Tensor batch_y, const float radius, const int nsample);
 
-at::Tensor degree(at::Tensor row, int64_t num_nodes);
+torch::Tensor degree(torch::Tensor row, int64_t num_nodes);
diff --git a/cuda/include/interpolate.h b/cuda/include/interpolate.h
index 7eeff71..6e96b51 100644
--- a/cuda/include/interpolate.h
+++ b/cuda/include/interpolate.h
@@ -3,7 +3,16 @@
 #include
 #include
 
-std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
-at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, at::Tensor weight);
-at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tensor weight,
+
+
+
+std::vector<torch::Tensor> three_nn(torch::Tensor unknowns, torch::Tensor knows);
+torch::Tensor three_interpolate(torch::Tensor points, torch::Tensor idx, torch::Tensor weight);
+torch::Tensor three_interpolate_grad(torch::Tensor grad_out, torch::Tensor idx, torch::Tensor weight,
                                      const int m);
+
+std::vector<torch::Tensor> three_nn_kernel_wrapper(torch::Tensor unknown, torch::Tensor known);
+torch::Tensor three_interpolate_kernel_wrapper(torch::Tensor points, torch::Tensor idx,
+                                               torch::Tensor weight);
+torch::Tensor three_interpolate_grad_kernel_wrapper(torch::Tensor grad_out, torch::Tensor idx,
+                                                    torch::Tensor weight, const int m);
\ No newline at end of file
diff --git a/cuda/include/sampling.h b/cuda/include/sampling.h
index fc4f13d..c39d872 100644
--- a/cuda/include/sampling.h
+++ b/cuda/include/sampling.h
@@ -1,4 +1,4 @@
 #pragma once
 #include
 
-at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
+torch::Tensor furthest_point_sampling(torch::Tensor points, const int nsamples);
diff --git a/cuda/src/ball_query.cpp b/cuda/src/ball_query.cpp
index 23d7e7f..3b67525 100644
--- a/cuda/src/ball_query.cpp
+++ b/cuda/src/ball_query.cpp
@@ -1,19 +1,18 @@
 #include "ball_query.h"
 #include "compat.h"
 #include "utils.h"
-
-void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
-                                           const float* new_xyz, const float* xyz, int64_t* idx,
-                                           float* dist_out);
-
-void query_ball_point_kernel_partial_wrapper(int64_t batch_size, int size_x, int size_y,
-                                             float radius, int nsample, const float* x,
-                                             const float* y, const int64_t* batch_x,
-                                             const int64_t* batch_y, int64_t* idx_out,
-                                             float* dist_out);
-
-std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tensor xyz,
-                                                   const float radius, const int nsample)
+#include
+
+std::pair<torch::Tensor, torch::Tensor> query_ball_point_kernel_dense_wrapper(float radius,
+                                                                              int nsample,
+                                                                              torch::Tensor new_xyz,
+                                                                              torch::Tensor xyz);
+std::pair<torch::Tensor, torch::Tensor>
+query_ball_point_kernel_partial_wrapper(float radius, int nsample, torch::Tensor x, torch::Tensor y,
+                                        torch::Tensor batch_x, torch::Tensor batch_y);
+
+std::pair<torch::Tensor, torch::Tensor> ball_query_dense(torch::Tensor new_xyz, torch::Tensor xyz,
+                                                         const float radius, const int nsample)
 {
     CHECK_CONTIGUOUS(new_xyz);
     CHECK_CONTIGUOUS(xyz);
@@ -23,28 +22,13 @@ std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor
new_xyz, at::Tenso CHECK_CUDA(xyz); CHECK_CUDA(new_xyz); - at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, - at::device(new_xyz.device()).dtype(at::ScalarType::Long)); - at::Tensor dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1, - at::device(new_xyz.device()).dtype(at::ScalarType::Float)); - - query_ball_point_kernel_dense_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), radius, - nsample, new_xyz.DATA_PTR(), xyz.DATA_PTR(), - idx.DATA_PTR(), dist.DATA_PTR()); - - return std::make_pair(idx, dist); + return query_ball_point_kernel_dense_wrapper(radius, nsample, new_xyz, xyz); } -at::Tensor degree(at::Tensor row, int64_t num_nodes) -{ - auto zero = at::zeros(num_nodes, row.options()); - auto one = at::ones(row.size(0), row.options()); - return zero.scatter_add_(0, row, one); -} -std::pair ball_query_partial_dense(at::Tensor x, at::Tensor y, - at::Tensor batch_x, at::Tensor batch_y, - const float radius, const int nsample) +std::pair +ball_query_partial_dense(torch::Tensor x, torch::Tensor y, torch::Tensor batch_x, + torch::Tensor batch_y, const float radius, const int nsample) { CHECK_CONTIGUOUS(x); CHECK_CONTIGUOUS(y); @@ -55,27 +39,5 @@ std::pair ball_query_partial_dense(at::Tensor x, at::Ten CHECK_CUDA(batch_x); CHECK_CUDA(batch_y); - at::Tensor idx = - torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Long)); - - at::Tensor dist = - torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Float)); - - cudaSetDevice(x.get_device()); - auto batch_sizes = (int64_t*)malloc(sizeof(int64_t)); - cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR(), sizeof(int64_t), - cudaMemcpyDeviceToHost); - auto batch_size = batch_sizes[0] + 1; - - batch_x = degree(batch_x, batch_size); - batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0); - batch_y = degree(batch_y, batch_size); - batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0); - - query_ball_point_kernel_partial_wrapper( - batch_size, x.size(0), y.size(0), radius, nsample, x.DATA_PTR(), y.DATA_PTR(), - batch_x.DATA_PTR(), batch_y.DATA_PTR(), idx.DATA_PTR(), - dist.DATA_PTR()); - - return std::make_pair(idx, dist); + return query_ball_point_kernel_partial_wrapper(radius, nsample, x, y, batch_x, batch_y); } diff --git a/cuda/src/ball_query_gpu.cu b/cuda/src/ball_query_gpu.cu index 1aaa451..226cbb3 100644 --- a/cuda/src/ball_query_gpu.cu +++ b/cuda/src/ball_query_gpu.cu @@ -3,14 +3,15 @@ #include #include "cuda_utils.h" - +#include // input: new_xyz(b, m, 3) xyz(b, n, 3) // output: idx(b, m, nsample) +template __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, int nsample, - const float* __restrict__ new_xyz, - const float* __restrict__ xyz, + const scalar_t* __restrict__ new_xyz, + const scalar_t* __restrict__ xyz, int64_t* __restrict__ idx_out, - float* __restrict__ dist_out) + scalar_t* __restrict__ dist_out) { int batch_index = blockIdx.x; xyz += batch_index * n * 3; @@ -24,15 +25,15 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, float radius2 = radius * radius; for (int j = index; j < m; j += stride) { - float new_x = new_xyz[j * 3 + 0]; - float new_y = new_xyz[j * 3 + 1]; - float new_z = new_xyz[j * 3 + 2]; + scalar_t new_x = new_xyz[j * 3 + 0]; + scalar_t new_y = new_xyz[j * 3 + 1]; + scalar_t new_z = new_xyz[j * 3 + 2]; for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { - float x = xyz[k * 3 + 0]; - float y = xyz[k * 3 + 1]; - 
float z = xyz[k * 3 + 2]; - float d2 = + scalar_t x = xyz[k * 3 + 0]; + scalar_t y = xyz[k * 3 + 1]; + scalar_t z = xyz[k * 3 + 2]; + scalar_t d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); if (d2 < radius2) { @@ -51,13 +52,14 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, } } +template __global__ void query_ball_point_kernel_partial_dense(int size_x, int size_y, float radius, - int nsample, const float* __restrict__ x, - const float* __restrict__ y, + int nsample, const scalar_t* __restrict__ x, + const scalar_t* __restrict__ y, const int64_t* __restrict__ batch_x, const int64_t* __restrict__ batch_y, int64_t* __restrict__ idx_out, - float* __restrict__ dist_out) + scalar_t* __restrict__ dist_out) { // taken from // https://github.com/rusty1s/pytorch_cluster/blob/master/cuda/radius_kernel.cu @@ -75,7 +77,7 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x, int size_y, fl int64_t count = 0; for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) { - float dist = 0; + scalar_t dist = 0; for (ptrdiff_t d = 0; d < 3; d++) { dist += (x[n_x * 3 + d] - y[n_y * 3 + d]) * (x[n_x * 3 + d] - y[n_y * 3 + d]); @@ -94,25 +96,77 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x, int size_y, fl } } -void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample, - const float* new_xyz, const float* xyz, int64_t* idx, - float* dist_out) +std::pair query_ball_point_kernel_dense_wrapper(float radius, + int nsample, + torch::Tensor new_xyz, + torch::Tensor xyz) { + int b = xyz.size(0); + int n = xyz.size(1); + int m = new_xyz.size(1); + auto idx = + torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, torch::CUDA(torch::ScalarType::Long)); + auto dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1, + torch::CUDA(xyz.scalar_type())); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - query_ball_point_kernel_dense<<>>(b, n, m, radius, nsample, - new_xyz, xyz, idx, dist_out); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + new_xyz.scalar_type(), "query_ball_point_kernel_dense_cuda", + ( + [&] + { + query_ball_point_kernel_dense<<>>( + b, n, m, radius, nsample, new_xyz.data_ptr(), + xyz.data_ptr(), idx.data_ptr(), dist.data_ptr()); + })); CUDA_CHECK_ERRORS(); + return std::make_pair(idx, dist); } -void query_ball_point_kernel_partial_wrapper(int64_t batch_size, int size_x, int size_y, - float radius, int nsample, const float* x, - const float* y, const int64_t* batch_x, - const int64_t* batch_y, int64_t* idx_out, - float* dist_out) +torch::Tensor degree(torch::Tensor row, int64_t num_nodes) { - query_ball_point_kernel_partial_dense<<>>( - size_x, size_y, radius, nsample, x, y, batch_x, batch_y, idx_out, dist_out); + auto zero = torch::zeros(num_nodes, row.options()); + auto one = torch::ones(row.size(0), row.options()); + return zero.scatter_add_(0, row, one); +} + +std::pair +query_ball_point_kernel_partial_wrapper(float radius, int nsample, torch::Tensor x, torch::Tensor y, + torch::Tensor batch_x, torch::Tensor batch_y) +{ + + int size_x = x.size(0); + int size_y = y.size(0); + auto idx = torch::full({y.size(0), nsample}, -1, torch::CUDA(torch::ScalarType::Long)); + + auto dist = torch::full({y.size(0), nsample}, -1, torch::CUDA(y.scalar_type())); + + cudaSetDevice(x.get_device()); + auto batch_sizes = (int64_t*)malloc(sizeof(int64_t)); + cudaMemcpy(batch_sizes, batch_x[-1].data_ptr(), sizeof(int64_t), + cudaMemcpyDeviceToHost); + auto batch_size 
= batch_sizes[0] + 1; + + batch_x = degree(batch_x, batch_size); + batch_x = torch::cat({torch::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0); + batch_y = degree(batch_y, batch_size); + batch_y = torch::cat({torch::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + x.scalar_type(), "query_ball_point_kernel_dense_cuda", + ( + [&] + { + query_ball_point_kernel_partial_dense + <<>>( + size_x, size_y, radius, nsample, x.data_ptr(), + y.data_ptr(), batch_x.data_ptr(), + batch_y.data_ptr(), idx.data_ptr(), + dist.data_ptr()); + })); CUDA_CHECK_ERRORS(); + + return std::make_pair(idx, dist); } diff --git a/cuda/src/chamfer_dist_gpu.cu b/cuda/src/chamfer_dist_gpu.cu index 0b3eee3..3aff6b0 100644 --- a/cuda/src/chamfer_dist_gpu.cu +++ b/cuda/src/chamfer_dist_gpu.cu @@ -1,7 +1,7 @@ #include #include #include - +#include #include "cuda_utils.h" #include @@ -159,12 +159,12 @@ std::vector chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch const int batch_size = xyz1.size(0); const int n = xyz1.size(1); // num_points point cloud A const int m = xyz2.size(1); // num_points point cloud B - torch::Tensor dist1 = torch::zeros({batch_size, n}, torch::CUDA(xyz1.scalar_type())); - torch::Tensor dist2 = torch::zeros({batch_size, m}, torch::CUDA(xyz1.scalar_type())); - torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt)); - torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt)); + auto dist1 = torch::zeros({batch_size, n}, torch::CUDA(xyz1.scalar_type())); + auto dist2 = torch::zeros({batch_size, m}, torch::CUDA(xyz1.scalar_type())); + auto idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt)); + auto idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt)); - AT_DISPATCH_FLOATING_TYPES( + AT_DISPATCH_FLOATING_TYPES_AND_HALF( xyz1.scalar_type(), "chamfer_dist_cuda", ([&] { chamfer_dist_kernel<<>>( batch_size, n, xyz1.data_ptr(), m, xyz2.data_ptr(), @@ -202,12 +202,12 @@ __global__ void chamfer_dist_grad_kernel(int b, int n, const scalar_t* __restric scalar_t y2 = xyz2[(i * m + j2) * 3 + 1]; scalar_t z2 = xyz2[(i * m + j2) * 3 + 2]; scalar_t g = grad_dist1[i * n + j] * 2; - atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2)); - atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2)); - atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2)); - atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2))); - atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2))); - atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2))); + gpuAtomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2)); + gpuAtomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2)); + gpuAtomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2)); + gpuAtomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2))); + gpuAtomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2))); + gpuAtomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2))); } } } @@ -220,10 +220,10 @@ std::vector chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1, const int batch_size = xyz1.size(0); const int n = xyz1.size(1); // num_points point cloud A const int m = xyz2.size(1); // num_points point cloud B - torch::Tensor grad_xyz1 = torch::zeros_like(xyz1); - torch::Tensor grad_xyz2 = torch::zeros_like(xyz2); + auto grad_xyz1 = torch::zeros_like(xyz1); + auto grad_xyz2 = torch::zeros_like(xyz2); - AT_DISPATCH_FLOATING_TYPES( + AT_DISPATCH_FLOATING_TYPES_AND_HALF( xyz1.scalar_type(), "chamfer_dist_grad_cuda", ([&] 
{ chamfer_dist_grad_kernel<<>>( batch_size, n, xyz1.data_ptr(), m, xyz2.data_ptr(), diff --git a/cuda/src/cubic_feature_sampling_gpu.cu b/cuda/src/cubic_feature_sampling_gpu.cu index dd017de..07bac3f 100644 --- a/cuda/src/cubic_feature_sampling_gpu.cu +++ b/cuda/src/cubic_feature_sampling_gpu.cu @@ -1,10 +1,10 @@ +#include "cuda_utils.h" +#include #include #include #include #include -#include "cuda_utils.h" - #define CUDA_NUM_THREADS 512 // Computer the number of threads needed in GPU @@ -113,18 +113,20 @@ std::vector cubic_feature_sampling_kernel_wrapper(int scale, int int n_cubic_channels = cubic_features.size(1); int n_vertices = std::pow(neighborhood_size * 2, 3); - torch::Tensor point_features = torch::zeros({batch_size, n_pts, n_vertices, n_cubic_channels}, - torch::CUDA(ptcloud.scalar_type())); - torch::Tensor grid_pt_indexes = - torch::zeros({batch_size, n_pts, n_vertices}, torch::CUDA(torch::kInt)); - - AT_DISPATCH_FLOATING_TYPES( - ptcloud.scalar_type(), "cubic_feature_sampling_cuda", ([&] { - cubic_feature_sampling_kernel<<>>( - scale, neighborhood_size, n_vertices, n_pts, n_cubic_channels, - ptcloud.data_ptr(), cubic_features.data_ptr(), - point_features.data_ptr(), grid_pt_indexes.data_ptr()); - })); + auto point_features = torch::zeros({batch_size, n_pts, n_vertices, n_cubic_channels}, + torch::CUDA(ptcloud.scalar_type())); + auto grid_pt_indexes = torch::zeros({batch_size, n_pts, n_vertices}, torch::CUDA(torch::kInt)); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + ptcloud.scalar_type(), "cubic_feature_sampling_cuda", + ( + [&] + { + cubic_feature_sampling_kernel<<>>( + scale, neighborhood_size, n_vertices, n_pts, n_cubic_channels, + ptcloud.data_ptr(), cubic_features.data_ptr(), + point_features.data_ptr(), grid_pt_indexes.data_ptr()); + })); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) @@ -170,7 +172,7 @@ __global__ void cubic_feature_sampling_grad_kernel(int scale, int neighborhood_s // atomicAdd(&(grad_ptcloud[i * 3 + 0]), grad_val); // atomicAdd(&(grad_ptcloud[i * 3 + 1]), grad_val); // atomicAdd(&(grad_ptcloud[i * 3 + 2]), grad_val); - atomicAdd(&(grad_cubic_features[k * cub_scale + vertex_idx]), grad_val); + gpuAtomicAdd(&(grad_cubic_features[k * cub_scale + vertex_idx]), grad_val); } } } @@ -186,19 +188,21 @@ cubic_feature_sampling_grad_kernel_wrapper(int scale, int neighborhood_size, int n_pts = grid_pt_indexes.size(1); int n_vertices = std::pow(neighborhood_size * 2, 3); - torch::Tensor grad_ptcloud = + auto grad_ptcloud = torch::zeros({batch_size, n_pts, 3}, torch::CUDA(grad_point_features.scalar_type())); - torch::Tensor grad_cubic_features = - torch::zeros({batch_size, n_cubic_channels, scale, scale, scale}, - torch::CUDA(grad_point_features.scalar_type())); - - AT_DISPATCH_FLOATING_TYPES( - grad_point_features.scalar_type(), "cubic_feature_sampling_grad_cuda", ([&] { - cubic_feature_sampling_grad_kernel<<>>( - scale, neighborhood_size, n_vertices, n_pts, n_cubic_channels, - grad_point_features.data_ptr(), grid_pt_indexes.data_ptr(), - grad_ptcloud.data_ptr(), grad_cubic_features.data_ptr()); - })); + auto grad_cubic_features = torch::zeros({batch_size, n_cubic_channels, scale, scale, scale}, + torch::CUDA(grad_point_features.scalar_type())); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_point_features.scalar_type(), "cubic_feature_sampling_grad_cuda", + ( + [&] + { + cubic_feature_sampling_grad_kernel<<>>( + scale, neighborhood_size, n_vertices, n_pts, n_cubic_channels, + grad_point_features.data_ptr(), grid_pt_indexes.data_ptr(), + 
grad_ptcloud.data_ptr(), grad_cubic_features.data_ptr()); + })); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) diff --git a/cuda/src/gridding_gpu.cu b/cuda/src/gridding_gpu.cu index 5fdc68b..24fde23 100644 --- a/cuda/src/gridding_gpu.cu +++ b/cuda/src/gridding_gpu.cu @@ -1,10 +1,10 @@ +#include "cuda_utils.h" +#include #include #include #include #include -#include "cuda_utils.h" - #define CUDA_NUM_THREADS 512 // Computer the number of threads needed in GPU @@ -19,8 +19,7 @@ __device__ int compute_index(int offset_x, int offset_y, int offset_z, int len_y return offset_x * len_y * len_z + offset_y * len_z + offset_z; } -template -__device__ scalar_t compute_weight(scalar_t x, scalar_t x0) +template __device__ scalar_t compute_weight(scalar_t x, scalar_t x0) { return 1 - abs(x - x0); } @@ -127,44 +126,44 @@ gridding_kernel(int n_grid_vertices, int n_pts, float min_x, float min_y, float { // LLL -> Lower X, Lower Y, Lower Z gvtx_idx = grid_pt_indexes[j * 8 + 0]; - atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 0] * - grid_pt_weights[j * 24 + 1] * - grid_pt_weights[j * 24 + 2]); + gpuAtomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 0] * + grid_pt_weights[j * 24 + 1] * + grid_pt_weights[j * 24 + 2]); // LLU -> Lower X, Lower Y, Upper Z gvtx_idx = grid_pt_indexes[j * 8 + 1]; - atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 3] * - grid_pt_weights[j * 24 + 4] * - grid_pt_weights[j * 24 + 5]); + gpuAtomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 3] * + grid_pt_weights[j * 24 + 4] * + grid_pt_weights[j * 24 + 5]); // LUL -> Lower X, Upper Y, Lower Z gvtx_idx = grid_pt_indexes[j * 8 + 2]; - atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 6] * - grid_pt_weights[j * 24 + 7] * - grid_pt_weights[j * 24 + 8]); + gpuAtomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 6] * + grid_pt_weights[j * 24 + 7] * + grid_pt_weights[j * 24 + 8]); // LUU -> Lower X, Upper Y, Upper Z gvtx_idx = grid_pt_indexes[j * 8 + 3]; - atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 9] * - grid_pt_weights[j * 24 + 10] * - grid_pt_weights[j * 24 + 11]); + gpuAtomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 9] * + grid_pt_weights[j * 24 + 10] * + grid_pt_weights[j * 24 + 11]); // ULL -> Upper X, Lower Y, Lower Z gvtx_idx = grid_pt_indexes[j * 8 + 4]; - atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 12] * - grid_pt_weights[j * 24 + 13] * - grid_pt_weights[j * 24 + 14]); + gpuAtomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 12] * + grid_pt_weights[j * 24 + 13] * + grid_pt_weights[j * 24 + 14]); // ULU -> Upper X, Lower Y, Upper Z gvtx_idx = grid_pt_indexes[j * 8 + 5]; - atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 15] * - grid_pt_weights[j * 24 + 16] * - grid_pt_weights[j * 24 + 17]); + gpuAtomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 15] * + grid_pt_weights[j * 24 + 16] * + grid_pt_weights[j * 24 + 17]); // UUL -> Upper X, Upper Y, Lower Z gvtx_idx = grid_pt_indexes[j * 8 + 6]; - atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 18] * - grid_pt_weights[j * 24 + 19] * - grid_pt_weights[j * 24 + 20]); + gpuAtomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 18] * + grid_pt_weights[j * 24 + 19] * + grid_pt_weights[j * 24 + 20]); // UUU -> Upper X, Upper Y, Upper Z gvtx_idx = grid_pt_indexes[j * 8 + 7]; - atomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 21] * - grid_pt_weights[j * 24 + 22] * - grid_pt_weights[j * 24 + 23]); 
+ gpuAtomicAdd(&(grid_weights[gvtx_idx]), grid_pt_weights[j * 24 + 21] * + grid_pt_weights[j * 24 + 22] * + grid_pt_weights[j * 24 + 23]); } } @@ -179,19 +178,22 @@ std::vector gridding_kernel_warpper(float min_x, float max_x, flo int len_z = max_z - min_z + 1; int n_grid_vertices = len_x * len_y * len_z; - torch::Tensor grid_weights = + auto grid_weights = torch::zeros({batch_size, n_grid_vertices}, torch::CUDA(ptcloud.scalar_type())); - torch::Tensor grid_pt_weights = + auto grid_pt_weights = torch::zeros({batch_size, n_pts, 8, 3}, torch::CUDA(ptcloud.scalar_type())); - torch::Tensor grid_pt_indexes = torch::zeros({batch_size, n_pts, 8}, torch::CUDA(torch::kInt)); - - AT_DISPATCH_FLOATING_TYPES( - ptcloud.scalar_type(), "gridding_cuda", ([&] { - gridding_kernel<<>>( - n_grid_vertices, n_pts, min_x, min_y, min_z, len_y, len_z, - ptcloud.data_ptr(), grid_weights.data_ptr(), - grid_pt_weights.data_ptr(), grid_pt_indexes.data_ptr()); - })); + auto grid_pt_indexes = torch::zeros({batch_size, n_pts, 8}, torch::CUDA(torch::kInt)); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + ptcloud.scalar_type(), "gridding_cuda", + ( + [&] + { + gridding_kernel<<>>( + n_grid_vertices, n_pts, min_x, min_y, min_z, len_y, len_z, + ptcloud.data_ptr(), grid_weights.data_ptr(), + grid_pt_weights.data_ptr(), grid_pt_indexes.data_ptr()); + })); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) @@ -227,9 +229,9 @@ gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict_ x_weights = grid_pt_weights[j * 24 + 0]; y_weights = grid_pt_weights[j * 24 + 1]; z_weights = grid_pt_weights[j * 24 + 2]; - atomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); // LLU -> Lower X, Lower Y, Upper Z gvtx_idx = grid_pt_indexes[j * 8 + 1]; @@ -237,9 +239,9 @@ gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict_ x_weights = grid_pt_weights[j * 24 + 3]; y_weights = grid_pt_weights[j * 24 + 4]; z_weights = grid_pt_weights[j * 24 + 5]; - atomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); // LUL -> Lower X, Upper Y, Lower Z gvtx_idx = grid_pt_indexes[j * 8 + 2]; @@ -247,9 +249,9 @@ gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict_ x_weights = grid_pt_weights[j * 24 + 6]; y_weights = grid_pt_weights[j * 24 + 7]; z_weights = grid_pt_weights[j * 24 + 8]; - atomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); + 
gpuAtomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); // LUU -> Lower X, Upper Y, Upper Z gvtx_idx = grid_pt_indexes[j * 8 + 3]; @@ -257,9 +259,9 @@ gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict_ x_weights = grid_pt_weights[j * 24 + 9]; y_weights = grid_pt_weights[j * 24 + 10]; z_weights = grid_pt_weights[j * 24 + 11]; - atomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 0]), -grad_vtx * y_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); // ULL -> Upper X, Lower Y, Lower Z gvtx_idx = grid_pt_indexes[j * 8 + 4]; @@ -267,9 +269,9 @@ gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict_ x_weights = grid_pt_weights[j * 24 + 12]; y_weights = grid_pt_weights[j * 24 + 13]; z_weights = grid_pt_weights[j * 24 + 14]; - atomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); // ULU -> Upper X, Lower Y, Upper Z gvtx_idx = grid_pt_indexes[j * 8 + 5]; @@ -277,9 +279,9 @@ gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict_ x_weights = grid_pt_weights[j * 24 + 15]; y_weights = grid_pt_weights[j * 24 + 16]; z_weights = grid_pt_weights[j * 24 + 17]; - atomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 1]), -grad_vtx * x_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); // UUL -> Upper X, Upper Y, Lower Z gvtx_idx = grid_pt_indexes[j * 8 + 6]; @@ -287,9 +289,9 @@ gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict_ x_weights = grid_pt_weights[j * 24 + 18]; y_weights = grid_pt_weights[j * 24 + 19]; z_weights = grid_pt_weights[j * 24 + 20]; - atomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 2]), -grad_vtx * x_weights * y_weights); // UUU -> Upper X, Upper Y, Upper Z gvtx_idx = grid_pt_indexes[j * 8 + 7]; @@ -297,9 +299,9 @@ gridding_grad_kernel(int n_grid_vertices, int n_pts, const scalar_t* __restrict_ x_weights = grid_pt_weights[j * 24 + 21]; y_weights = grid_pt_weights[j * 24 + 22]; z_weights = grid_pt_weights[j * 24 + 23]; - atomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 1]), 
grad_vtx * x_weights * z_weights); - atomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 0]), grad_vtx * y_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 1]), grad_vtx * x_weights * z_weights); + gpuAtomicAdd(&(grad_ptcloud[j * 3 + 2]), grad_vtx * x_weights * y_weights); } } @@ -311,16 +313,19 @@ torch::Tensor gridding_grad_kernel_warpper(torch::Tensor grid_pt_weights, int n_grid_vertices = grad_grid.size(1); int n_pts = grid_pt_indexes.size(1); - torch::Tensor grad_ptcloud = + auto grad_ptcloud = torch::zeros({batch_size, n_pts, 3}, torch::CUDA(grid_pt_weights.scalar_type())); - AT_DISPATCH_FLOATING_TYPES( - grid_pt_weights.scalar_type(), "gridding_grad_cuda", ([&] { - gridding_grad_kernel<<>>( - n_grid_vertices, n_pts, grid_pt_weights.data_ptr(), - grid_pt_indexes.data_ptr(), grad_grid.data_ptr(), - grad_ptcloud.data_ptr()); - })); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grid_pt_weights.scalar_type(), "gridding_grad_cuda", + ( + [&] + { + gridding_grad_kernel<<>>( + n_grid_vertices, n_pts, grid_pt_weights.data_ptr(), + grid_pt_indexes.data_ptr(), grad_grid.data_ptr(), + grad_ptcloud.data_ptr()); + })); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) diff --git a/cuda/src/interpolate.cpp b/cuda/src/interpolate.cpp index f341922..d1e7712 100644 --- a/cuda/src/interpolate.cpp +++ b/cuda/src/interpolate.cpp @@ -2,14 +2,7 @@ #include "compat.h" #include "utils.h" -void three_nn_kernel_wrapper(int b, int n, int m, const float* unknown, const float* known, - float* dist2, int* idx); -void three_interpolate_kernel_wrapper(int b, int c, int m, int n, const float* points, - const int* idx, const float* weight, float* out); -void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, const float* grad_out, - const int* idx, const float* weight, float* grad_points); - -std::vector three_nn(at::Tensor unknowns, at::Tensor knows) +std::vector three_nn(torch::Tensor unknowns, torch::Tensor knows) { CHECK_CONTIGUOUS(unknowns); CHECK_CONTIGUOUS(knows); @@ -19,19 +12,10 @@ std::vector three_nn(at::Tensor unknowns, at::Tensor knows) CHECK_CUDA(knows); CHECK_CUDA(unknowns); - at::Tensor idx = torch::zeros({unknowns.size(0), unknowns.size(1), 3}, - at::device(unknowns.device()).dtype(at::ScalarType::Int)); - at::Tensor dist2 = torch::zeros({unknowns.size(0), unknowns.size(1), 3}, - at::device(unknowns.device()).dtype(at::ScalarType::Float)); - - three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), - unknowns.DATA_PTR(), knows.DATA_PTR(), - dist2.DATA_PTR(), idx.DATA_PTR()); - - return {dist2, idx}; + return three_nn_kernel_wrapper(unknowns, knows); } -at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, at::Tensor weight) +torch::Tensor three_interpolate(torch::Tensor points, torch::Tensor idx, torch::Tensor weight) { CHECK_CONTIGUOUS(points); CHECK_CONTIGUOUS(idx); @@ -43,17 +27,11 @@ at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, at::Tensor weigh CHECK_CUDA(idx); CHECK_CUDA(weight); - at::Tensor output = torch::zeros({points.size(0), points.size(1), idx.size(1)}, - at::device(points.device()).dtype(at::ScalarType::Float)); - - three_interpolate_kernel_wrapper(points.size(0), points.size(1), points.size(2), idx.size(1), - points.DATA_PTR(), idx.DATA_PTR(), - weight.DATA_PTR(), output.DATA_PTR()); - - return output; + return three_interpolate_kernel_wrapper(points, idx, weight); } -at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tensor 
weight, - const int m) + +torch::Tensor three_interpolate_grad(torch::Tensor grad_out, torch::Tensor idx, + torch::Tensor weight, const int m) { CHECK_CONTIGUOUS(grad_out); CHECK_CONTIGUOUS(idx); @@ -65,12 +43,5 @@ at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tenso CHECK_CUDA(weight); CHECK_CUDA(grad_out); - at::Tensor output = torch::zeros({grad_out.size(0), grad_out.size(1), m}, - at::device(grad_out.device()).dtype(at::ScalarType::Float)); - - three_interpolate_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), grad_out.size(2), m, - grad_out.DATA_PTR(), idx.DATA_PTR(), - weight.DATA_PTR(), output.DATA_PTR()); - - return output; + return three_interpolate_grad_kernel_wrapper(grad_out, idx, weight, m); } diff --git a/cuda/src/interpolate_gpu.cu b/cuda/src/interpolate_gpu.cu index db38ba2..4946adb 100644 --- a/cuda/src/interpolate_gpu.cu +++ b/cuda/src/interpolate_gpu.cu @@ -1,13 +1,17 @@ +#include "cuda_utils.h" +#include #include #include #include - -#include "cuda_utils.h" +#include +#include // input: unknown(b, n, 3) known(b, m, 3) // output: dist2(b, n, 3), idx(b, n, 3) -__global__ void three_nn_kernel(int b, int n, int m, const float* __restrict__ unknown, - const float* __restrict__ known, float* __restrict__ dist2, +template +__global__ void three_nn_kernel(int b, int n, int m, const double upper_bd, + const scalar_t* __restrict__ unknown, + const scalar_t* __restrict__ known, scalar_t* __restrict__ dist2, int* __restrict__ idx) { int batch_index = blockIdx.x; @@ -20,18 +24,18 @@ __global__ void three_nn_kernel(int b, int n, int m, const float* __restrict__ u int stride = blockDim.x; for (int j = index; j < n; j += stride) { - float ux = unknown[j * 3 + 0]; - float uy = unknown[j * 3 + 1]; - float uz = unknown[j * 3 + 2]; + scalar_t ux = unknown[j * 3 + 0]; + scalar_t uy = unknown[j * 3 + 1]; + scalar_t uz = unknown[j * 3 + 2]; + scalar_t best1 = upper_bd, best2 = upper_bd, best3 = upper_bd; - double best1 = 1e40, best2 = 1e40, best3 = 1e40; int besti1 = 0, besti2 = 0, besti3 = 0; for (int k = 0; k < m; ++k) { - float x = known[k * 3 + 0]; - float y = known[k * 3 + 1]; - float z = known[k * 3 + 2]; - float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + scalar_t x = known[k * 3 + 0]; + scalar_t y = known[k * 3 + 1]; + scalar_t z = known[k * 3 + 2]; + scalar_t d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); if (d < best1) { best3 = best2; @@ -64,21 +68,48 @@ __global__ void three_nn_kernel(int b, int n, int m, const float* __restrict__ u } } -void three_nn_kernel_wrapper(int b, int n, int m, const float* unknown, const float* known, - float* dist2, int* idx) +std::vector three_nn_kernel_wrapper(torch::Tensor unknowns, torch::Tensor knows) { + int b = unknowns.size(0); + int n = unknowns.size(1); + int m = knows.size(1); + + auto idx = torch::zeros({b, n, 3}, torch::CUDA(torch::kInt)); + auto dist2 = torch::zeros({b, n, 3}, torch::CUDA(unknowns.scalar_type())); + double upper_bd = 0; + switch (unknowns.scalar_type()) + { + case torch::ScalarType::Double: + upper_bd = 1e40; + break; + case torch::ScalarType::Half: + upper_bd = 65504; + break; + default: + upper_bd = 1e20; + break; + } cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - three_nn_kernel<<>>(b, n, m, unknown, known, dist2, idx); - + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + unknowns.scalar_type(), "three_nn_kernel_cuda", + ( + [&] + { + three_nn_kernel<<>>( + b, n, m, upper_bd, unknowns.data_ptr(), knows.data_ptr(), + dist2.data_ptr(), 
idx.data_ptr()); + })); CUDA_CHECK_ERRORS(); + return {dist2, idx}; } // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) // output: out(b, c, n) -__global__ void three_interpolate_kernel(int b, int c, int m, int n, - const float* __restrict__ points, - const int* __restrict__ idx, - const float* __restrict__ weight, float* __restrict__ out) +template +__global__ void +three_interpolate_kernel(int b, int c, int m, int n, const scalar_t* __restrict__ points, + const int* __restrict__ idx, const scalar_t* __restrict__ weight, + scalar_t* __restrict__ out) { int batch_index = blockIdx.x; points += batch_index * m * c; @@ -94,9 +125,9 @@ __global__ void three_interpolate_kernel(int b, int c, int m, int n, { const int l = i / n; const int j = i % n; - float w1 = weight[j * 3 + 0]; - float w2 = weight[j * 3 + 1]; - float w3 = weight[j * 3 + 2]; + scalar_t w1 = weight[j * 3 + 0]; + scalar_t w2 = weight[j * 3 + 1]; + scalar_t w3 = weight[j * 3 + 2]; int i1 = idx[j * 3 + 0]; int i2 = idx[j * 3 + 1]; @@ -106,24 +137,38 @@ __global__ void three_interpolate_kernel(int b, int c, int m, int n, } } -void three_interpolate_kernel_wrapper(int b, int c, int m, int n, const float* points, - const int* idx, const float* weight, float* out) +torch::Tensor three_interpolate_kernel_wrapper(torch::Tensor points, torch::Tensor idx, + torch::Tensor weight) { + int b = points.size(0); + int c = points.size(1); + int m = points.size(2); + int n = idx.size(1); + + auto out = torch::zeros({b, c, n}, torch::CUDA(points.scalar_type())); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - three_interpolate_kernel<<>>(b, c, m, n, points, idx, - weight, out); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "three_interpolate_kernel_cuda", + ( + [&] + { + three_interpolate_kernel<<>>( + b, c, m, n, points.data_ptr(), idx.data_ptr(), + weight.data_ptr(), out.data_ptr()); + })); CUDA_CHECK_ERRORS(); + return out; } // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) // output: grad_points(b, c, m) - -__global__ void three_interpolate_grad_kernel(int b, int c, int n, int m, - const float* __restrict__ grad_out, - const int* __restrict__ idx, - const float* __restrict__ weight, - float* __restrict__ grad_points) +template +__global__ void +three_interpolate_grad_kernel(int b, int c, int n, int m, const scalar_t* __restrict__ grad_out, + const int* __restrict__ idx, const scalar_t* __restrict__ weight, + scalar_t* __restrict__ grad_points) { int batch_index = blockIdx.x; grad_out += batch_index * n * c; @@ -137,26 +182,39 @@ __global__ void three_interpolate_grad_kernel(int b, int c, int n, int m, { const int l = i / n; const int j = i % n; - float w1 = weight[j * 3 + 0]; - float w2 = weight[j * 3 + 1]; - float w3 = weight[j * 3 + 2]; + scalar_t w1 = weight[j * 3 + 0]; + scalar_t w2 = weight[j * 3 + 1]; + scalar_t w3 = weight[j * 3 + 2]; int i1 = idx[j * 3 + 0]; int i2 = idx[j * 3 + 1]; int i3 = idx[j * 3 + 2]; - atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); - atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); - atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); + gpuAtomicAdd(grad_points + l * m + i1, grad_out[i] * w1); + gpuAtomicAdd(grad_points + l * m + i2, grad_out[i] * w2); + gpuAtomicAdd(grad_points + l * m + i3, grad_out[i] * w3); } } -void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, const float* grad_out, - const int* idx, const float* weight, float* grad_points) +torch::Tensor three_interpolate_grad_kernel_wrapper(torch::Tensor grad_out, 
torch::Tensor idx, + torch::Tensor weight, const int m) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - three_interpolate_grad_kernel<<>>( - b, c, n, m, grad_out, idx, weight, grad_points); + int b = grad_out.size(0); + int c = grad_out.size(1); + int n = grad_out.size(2); + auto grad_points = torch::zeros({b, c, m}, torch::CUDA(grad_out.scalar_type())); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_out.scalar_type(), "three_interpolate_grad_kernel_cuda", + ( + [&] + { + three_interpolate_grad_kernel<<>>( + b, c, n, m, grad_out.data_ptr(), idx.data_ptr(), + weight.data_ptr(), grad_points.data_ptr()); + })); CUDA_CHECK_ERRORS(); + return grad_points; } diff --git a/cuda/src/sampling.cpp b/cuda/src/sampling.cpp index e2a4dbe..cbe1692 100644 --- a/cuda/src/sampling.cpp +++ b/cuda/src/sampling.cpp @@ -2,24 +2,13 @@ #include "compat.h" #include "utils.h" -void furthest_point_sampling_kernel_wrapper(int b, int n, int m, const float* dataset, float* temp, - int* idxs); +torch::Tensor furthest_point_sampling_kernel_wrapper(torch::Tensor points, const int nsamples); -at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) +torch::Tensor furthest_point_sampling(torch::Tensor points, const int nsamples) { CHECK_CONTIGUOUS(points); CHECK_IS_FLOAT(points); CHECK_CUDA(points); - at::Tensor output = torch::zeros({points.size(0), nsamples}, - at::device(points.device()).dtype(at::ScalarType::Int)); - - at::Tensor tmp = torch::full({points.size(0), points.size(1)}, 1e10, - at::device(points.device()).dtype(at::ScalarType::Float)); - - furthest_point_sampling_kernel_wrapper(points.size(0), points.size(1), nsamples, - points.DATA_PTR(), tmp.DATA_PTR(), - output.DATA_PTR()); - - return output; + return furthest_point_sampling_kernel_wrapper(points, nsamples); } diff --git a/cuda/src/sampling_gpu.cu b/cuda/src/sampling_gpu.cu index db631ee..0b43d80 100644 --- a/cuda/src/sampling_gpu.cu +++ b/cuda/src/sampling_gpu.cu @@ -1,11 +1,13 @@ +#include "cuda_utils.h" #include #include +#include -#include "cuda_utils.h" - -__device__ void __update(float* __restrict__ dists, int* __restrict__ dists_i, int idx1, int idx2) +template +__device__ void __update(scalar_t* __restrict__ dists, int* __restrict__ dists_i, int idx1, + int idx2) { - const float v1 = dists[idx1], v2 = dists[idx2]; + const scalar_t v1 = dists[idx1], v2 = dists[idx2]; const int i1 = dists_i[idx1], i2 = dists_i[idx2]; dists[idx1] = max(v1, v2); dists_i[idx1] = v2 > v1 ? 
i2 : i1; @@ -13,14 +15,14 @@ __device__ void __update(float* __restrict__ dists, int* __restrict__ dists_i, i // Input dataset: (b, n, 3), tmp: (b, n) // Ouput idxs (b, m) -template +template __global__ void furthest_point_sampling_kernel(int b, int n, int m, - const float* __restrict__ dataset, - float* __restrict__ temp, int* __restrict__ idxs) + const scalar_t* __restrict__ dataset, + scalar_t* __restrict__ temp, int* __restrict__ idxs) { if (m <= 0) return; - __shared__ float dists[block_size]; + __shared__ scalar_t dists[block_size]; __shared__ int dists_i[block_size]; int batch_index = blockIdx.x; @@ -36,26 +38,26 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, idxs[0] = old; __syncthreads(); - for (int j = 1; j < m; j++) + for (int j = 0; j < m; j++) { int besti = 0; - float best = -1; - float x1 = dataset[old * 3 + 0]; - float y1 = dataset[old * 3 + 1]; - float z1 = dataset[old * 3 + 2]; + scalar_t best = -1; + scalar_t x1 = dataset[old * 3 + 0]; + scalar_t y1 = dataset[old * 3 + 1]; + scalar_t z1 = dataset[old * 3 + 2]; for (int k = tid; k < n; k += stride) { - float x2, y2, z2; + scalar_t x2, y2, z2; x2 = dataset[k * 3 + 0]; y2 = dataset[k * 3 + 1]; z2 = dataset[k * 3 + 2]; - float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); + scalar_t mag = (x2 * x2) + (y2 * y2) + (z2 * z2); if (mag <= 1e-3) continue; - float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + scalar_t d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); - float d2 = min(d, temp[k]); + scalar_t d2 = min(d, temp[k]); temp[k] = d2; besti = d2 > best ? k : besti; best = d2 > best ? d2 : best; @@ -68,7 +70,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 256) { - __update(dists, dists_i, tid, tid + 256); + __update(dists, dists_i, tid, tid + 256); } __syncthreads(); } @@ -76,7 +78,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 128) { - __update(dists, dists_i, tid, tid + 128); + __update(dists, dists_i, tid, tid + 128); } __syncthreads(); } @@ -84,7 +86,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 64) { - __update(dists, dists_i, tid, tid + 64); + __update(dists, dists_i, tid, tid + 64); } __syncthreads(); } @@ -92,7 +94,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 32) { - __update(dists, dists_i, tid, tid + 32); + __update(dists, dists_i, tid, tid + 32); } __syncthreads(); } @@ -100,7 +102,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 16) { - __update(dists, dists_i, tid, tid + 16); + __update(dists, dists_i, tid, tid + 16); } __syncthreads(); } @@ -108,7 +110,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 8) { - __update(dists, dists_i, tid, tid + 8); + __update(dists, dists_i, tid, tid + 8); } __syncthreads(); } @@ -116,7 +118,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 4) { - __update(dists, dists_i, tid, tid + 4); + __update(dists, dists_i, tid, tid + 4); } __syncthreads(); } @@ -124,7 +126,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 2) { - __update(dists, dists_i, tid, tid + 2); + __update(dists, dists_i, tid, tid + 2); } __syncthreads(); } @@ -132,7 +134,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, { if (tid < 1) { - __update(dists, dists_i, tid, tid + 1); + __update(dists, dists_i, tid, tid + 1); } 
__syncthreads(); } @@ -143,59 +145,157 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m, } } -void furthest_point_sampling_kernel_wrapper(int b, int n, int m, const float* dataset, float* temp, - int* idxs) +torch::Tensor furthest_point_sampling_kernel_wrapper(torch::Tensor points, const int nsamples) { - unsigned int n_threads = opt_n_threads(n); + + int b = points.size(0); + int n = points.size(1); + int m = nsamples; + auto idxs = + torch::zeros({points.size(0), nsamples}, torch::CUDA(torch::ScalarType::Int)); + + float init_num = 0; + switch (points.scalar_type()) + { + case torch::ScalarType::Half: + init_num = 65504; + break; + default: + init_num = 1e10; + break; + } + + auto temp = + torch::full({points.size(0), points.size(1)}, init_num, torch::CUDA(points.scalar_type())); + const unsigned int n_threads = opt_n_threads(n); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (n_threads) { case 512: - furthest_point_sampling_kernel<512> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 256: - furthest_point_sampling_kernel<256> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 128: - furthest_point_sampling_kernel<128> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 64: - furthest_point_sampling_kernel<64> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 32: - furthest_point_sampling_kernel<32> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 16: - furthest_point_sampling_kernel<16> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 8: - furthest_point_sampling_kernel<8> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 4: - furthest_point_sampling_kernel<4> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 2: - furthest_point_sampling_kernel<2> - <<>>(b, n, 
m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; case 1: - furthest_point_sampling_kernel<1> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); break; default: - furthest_point_sampling_kernel<512> - <<>>(b, n, m, dataset, temp, idxs); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "furthest_point_sampling_kernel_cuda", + ( + [&] + { + furthest_point_sampling_kernel<<>>( + b, n, m, points.data_ptr(), temp.data_ptr(), + idxs.data_ptr()); + })); + break; } CUDA_CHECK_ERRORS(); + return idxs; } diff --git a/test/test_chamfer_dist.py b/test/test_chamfer_dist.py index bba0612..843803c 100644 --- a/test/test_chamfer_dist.py +++ b/test/test_chamfer_dist.py @@ -6,14 +6,14 @@ from torch.autograd import gradcheck -from . import run_if_cuda + ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") sys.path.insert(0, ROOT) from torch_points_kernels.chamfer_dist import ChamferFunction, chamfer_dist - +from test import run_if_cuda class TestChamferDistance(unittest.TestCase): @run_if_cuda diff --git a/test/test_fps.py b/test/test_fps.py index 31bf8f3..fe3a2c5 100644 --- a/test/test_fps.py +++ b/test/test_fps.py @@ -3,10 +3,14 @@ import os import sys + + ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") sys.path.insert(0, ROOT) from torch_points_kernels.points_cpu import fps +from torch_points_kernels import furthest_point_sample +from test import run_if_cuda class TestFps(unittest.TestCase): @@ -20,6 +24,15 @@ def test_random(self): idx = fps(points, 2, True) self.assertNotEqual(idx[0][0], 0) + @run_if_cuda + def test_gpu(self): + points = torch.randn([16, 100, 3]).cuda() + nsamples = 2 + idx = furthest_point_sample(points,nsamples) + idx_cpu = furthest_point_sample(points.cpu(),nsamples) + sorted_idx, _ = torch.sort(idx.cpu()) + sorted_idx_cpu, _ = torch.sort(idx_cpu) + torch.testing.assert_allclose(sorted_idx,sorted_idx_cpu) if __name__ == "__main__": unittest.main() diff --git a/test/test_interpolate.py b/test/test_interpolate.py index 67f364b..874f5fc 100644 --- a/test/test_interpolate.py +++ b/test/test_interpolate.py @@ -2,9 +2,11 @@ import torch from torch.autograd import gradcheck from torch_points_kernels import three_interpolate, three_nn - -from . 
import run_if_cuda - +import sys +import os +ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") +sys.path.insert(0, ROOT) +from test import run_if_cuda class TestInterpolate(unittest.TestCase): @run_if_cuda diff --git a/torch_points_kernels/chamfer_dist.py b/torch_points_kernels/chamfer_dist.py index 528216d..6ff1faf 100644 --- a/torch_points_kernels/chamfer_dist.py +++ b/torch_points_kernels/chamfer_dist.py @@ -1,11 +1,12 @@ import torch - +from torch.cuda.amp import custom_bwd,custom_fwd if torch.cuda.is_available(): import torch_points_kernels.points_cuda as tpcuda class ChamferFunction(torch.autograd.Function): @staticmethod + @custom_fwd(cast_inputs=torch.half) def forward(ctx, xyz1, xyz2): if not torch.cuda.is_available(): raise NotImplementedError("CPU version is not available for Chamfer Distance") @@ -16,6 +17,7 @@ def forward(ctx, xyz1, xyz2): return dist1, dist2 @staticmethod + @custom_bwd def backward(ctx, grad_dist1, grad_dist2): xyz1, xyz2, idx1, idx2 = ctx.saved_tensors grad_xyz1, grad_xyz2 = tpcuda.chamfer_dist_grad(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2) diff --git a/torch_points_kernels/cubic_feature_sampling.py b/torch_points_kernels/cubic_feature_sampling.py index 18e178b..1deab73 100644 --- a/torch_points_kernels/cubic_feature_sampling.py +++ b/torch_points_kernels/cubic_feature_sampling.py @@ -1,11 +1,12 @@ import torch - +from torch.cuda.amp import custom_bwd,custom_fwd if torch.cuda.is_available(): import torch_points_kernels.points_cuda as tpcuda class CubicFeatureSamplingFunction(torch.autograd.Function): @staticmethod + @custom_fwd(cast_inputs=torch.half) def forward(ctx, ptcloud, cubic_features, neighborhood_size=1): scale = cubic_features.size(2) if not torch.cuda.is_available(): @@ -18,6 +19,7 @@ def forward(ctx, ptcloud, cubic_features, neighborhood_size=1): return point_features @staticmethod + @custom_bwd def backward(ctx, grad_point_features): scale, neighborhood_size, grid_pt_indexes = ctx.saved_tensors scale = int(scale.item()) diff --git a/torch_points_kernels/gridding.py b/torch_points_kernels/gridding.py index a2b187a..6b399e2 100644 --- a/torch_points_kernels/gridding.py +++ b/torch_points_kernels/gridding.py @@ -1,4 +1,5 @@ import torch +from torch.cuda.amp import custom_bwd,custom_fwd if torch.cuda.is_available(): import torch_points_kernels.points_cuda as tpcuda @@ -6,6 +7,7 @@ class GriddingFunction(torch.autograd.Function): @staticmethod + @custom_fwd(cast_inputs=torch.half) def forward(ctx, ptcloud, scale): if not torch.cuda.is_available(): raise NotImplementedError("CPU version is not available for Chamfer Distance") @@ -21,6 +23,7 @@ def forward(ctx, ptcloud, scale): return grid @staticmethod + @custom_bwd def backward(ctx, grad_grid): grid_pt_weights, grid_pt_indexes = ctx.saved_tensors grad_ptcloud = tpcuda.gridding_grad(grid_pt_weights, grid_pt_indexes, grad_grid) diff --git a/torch_points_kernels/torchpoints.py b/torch_points_kernels/torchpoints.py index ebef1d3..ce7e755 100644 --- a/torch_points_kernels/torchpoints.py +++ b/torch_points_kernels/torchpoints.py @@ -1,6 +1,7 @@ import torch from torch.autograd import Function import torch.nn as nn +from torch.cuda.amp import custom_bwd,custom_fwd import sys from typing import Optional, Any, Tuple @@ -66,19 +67,20 @@ def three_nn(unknown, known): class ThreeInterpolate(Function): @staticmethod + @custom_fwd(cast_inputs=torch.half) def forward(ctx, features, idx, weight): # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor B, c, m = 
features.size()
         n = idx.size(1)
 
         ctx.three_interpolate_for_backward = (idx, weight, m)
-
         if features.is_cuda:
             return tpcuda.three_interpolate(features, idx, weight)
         else:
             return tpcpu.knn_interpolate(features, idx, weight)
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_out):
         # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
         r"""
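
Usage note (not part of the patch): the sketch below shows one way the half-precision / AMP path added in this diff could be exercised. It assumes the extension has been rebuilt with these changes, a CUDA device is available, and the public entry points keep the signatures used by the tests above (furthest_point_sample, chamfer_dist); treat it as an illustration under those assumptions, not as code from the repository.

import torch

from torch_points_kernels import furthest_point_sample
from torch_points_kernels.chamfer_dist import chamfer_dist

xyz1 = torch.rand(2, 1024, 3, device="cuda")
xyz2 = torch.rand(2, 1024, 3, device="cuda")

# Half inputs now reach the kernels through AT_DISPATCH_FLOATING_TYPES_AND_HALF.
idx_half = furthest_point_sample(xyz1.half(), 128)

# When autocast is enabled, custom_fwd(cast_inputs=torch.half) casts the
# floating-point inputs to half before the forward kernel runs; custom_bwd
# keeps the matching autocast state for the backward pass.
with torch.cuda.amp.autocast():
    dist = chamfer_dist(xyz1, xyz2)

print(idx_half.dtype, dist.dtype)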