Skip to content

Add support for Half dtype and mixed precision training. #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions cuda/include/ball_query.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#pragma once
#include <torch/extension.h>

std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tensor xyz,
const float radius, const int nsample);
std::pair<torch::Tensor, torch::Tensor> ball_query_dense(torch::Tensor new_xyz, torch::Tensor xyz,
const float radius, const int nsample);

std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Tensor y,
at::Tensor batch_x, at::Tensor batch_y,
const float radius, const int nsample);
std::pair<torch::Tensor, torch::Tensor>
ball_query_partial_dense(torch::Tensor x, torch::Tensor y, torch::Tensor batch_x,
torch::Tensor batch_y, const float radius, const int nsample);

at::Tensor degree(at::Tensor row, int64_t num_nodes);
torch::Tensor degree(torch::Tensor row, int64_t num_nodes);
15 changes: 12 additions & 3 deletions cuda/include/interpolate.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
#include <torch/extension.h>
#include <vector>

std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, at::Tensor weight);
at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tensor weight,



std::vector<torch::Tensor> three_nn(torch::Tensor unknowns, torch::Tensor knows);
torch::Tensor three_interpolate(torch::Tensor points, torch::Tensor idx, torch::Tensor weight);
torch::Tensor three_interpolate_grad(torch::Tensor grad_out, torch::Tensor idx, torch::Tensor weight,
const int m);

std::vector<torch::Tensor> three_nn_kernel_wrapper(torch::Tensor unknown, torch::Tensor known);
torch::Tensor three_interpolate_kernel_wrapper(torch::Tensor points, torch::Tensor idx,
torch::Tensor weight);
torch::Tensor three_interpolate_grad_kernel_wrapper(torch::Tensor grad_out, torch::Tensor idx,
torch::Tensor weight, const int m);
2 changes: 1 addition & 1 deletion cuda/include/sampling.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#pragma once
#include <torch/extension.h>

at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
torch::Tensor furthest_point_sampling(torch::Tensor points, const int nsamples);
72 changes: 17 additions & 55 deletions cuda/src/ball_query.cpp
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
#include "ball_query.h"
#include "compat.h"
#include "utils.h"

void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
const float* new_xyz, const float* xyz, int64_t* idx,
float* dist_out);

void query_ball_point_kernel_partial_wrapper(int64_t batch_size, int size_x, int size_y,
float radius, int nsample, const float* x,
const float* y, const int64_t* batch_x,
const int64_t* batch_y, int64_t* idx_out,
float* dist_out);

std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tensor xyz,
const float radius, const int nsample)
#include <torch/extension.h>

std::pair<torch::Tensor, torch::Tensor> query_ball_point_kernel_dense_wrapper(float radius,
int nsample,
torch::Tensor new_xyz,
torch::Tensor xyz);
std::pair<torch::Tensor, torch::Tensor>
query_ball_point_kernel_partial_wrapper(float radius, int nsample, torch::Tensor x, torch::Tensor y,
torch::Tensor batch_x, torch::Tensor batch_y);

std::pair<torch::Tensor, torch::Tensor> ball_query_dense(torch::Tensor new_xyz, torch::Tensor xyz,
const float radius, const int nsample)
{
CHECK_CONTIGUOUS(new_xyz);
CHECK_CONTIGUOUS(xyz);
Expand All @@ -23,28 +22,13 @@ std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tenso
CHECK_CUDA(xyz);
CHECK_CUDA(new_xyz);

at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
at::device(new_xyz.device()).dtype(at::ScalarType::Long));
at::Tensor dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1,
at::device(new_xyz.device()).dtype(at::ScalarType::Float));

query_ball_point_kernel_dense_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), radius,
nsample, new_xyz.DATA_PTR<float>(), xyz.DATA_PTR<float>(),
idx.DATA_PTR<int64_t>(), dist.DATA_PTR<float>());

return std::make_pair(idx, dist);
return query_ball_point_kernel_dense_wrapper(radius, nsample, new_xyz, xyz);
}

at::Tensor degree(at::Tensor row, int64_t num_nodes)
{
auto zero = at::zeros(num_nodes, row.options());
auto one = at::ones(row.size(0), row.options());
return zero.scatter_add_(0, row, one);
}

std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Tensor y,
at::Tensor batch_x, at::Tensor batch_y,
const float radius, const int nsample)
std::pair<torch::Tensor, torch::Tensor>
ball_query_partial_dense(torch::Tensor x, torch::Tensor y, torch::Tensor batch_x,
torch::Tensor batch_y, const float radius, const int nsample)
{
CHECK_CONTIGUOUS(x);
CHECK_CONTIGUOUS(y);
Expand All @@ -55,27 +39,5 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Ten
CHECK_CUDA(batch_x);
CHECK_CUDA(batch_y);

at::Tensor idx =
torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Long));

at::Tensor dist =
torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Float));

cudaSetDevice(x.get_device());
auto batch_sizes = (int64_t*)malloc(sizeof(int64_t));
cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto batch_size = batch_sizes[0] + 1;

batch_x = degree(batch_x, batch_size);
batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
batch_y = degree(batch_y, batch_size);
batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

query_ball_point_kernel_partial_wrapper(
batch_size, x.size(0), y.size(0), radius, nsample, x.DATA_PTR<float>(), y.DATA_PTR<float>(),
batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(), idx.DATA_PTR<int64_t>(),
dist.DATA_PTR<float>());

return std::make_pair(idx, dist);
return query_ball_point_kernel_partial_wrapper(radius, nsample, x, y, batch_x, batch_y);
}
108 changes: 81 additions & 27 deletions cuda/src/ball_query_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
#include <stdlib.h>

#include "cuda_utils.h"

#include <torch/extension.h>
// input: new_xyz(b, m, 3) xyz(b, n, 3)
// output: idx(b, m, nsample)
template <typename scalar_t>
__global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, int nsample,
const float* __restrict__ new_xyz,
const float* __restrict__ xyz,
const scalar_t* __restrict__ new_xyz,
const scalar_t* __restrict__ xyz,
int64_t* __restrict__ idx_out,
float* __restrict__ dist_out)
scalar_t* __restrict__ dist_out)
{
int batch_index = blockIdx.x;
xyz += batch_index * n * 3;
Expand All @@ -24,15 +25,15 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius,
float radius2 = radius * radius;
for (int j = index; j < m; j += stride)
{
float new_x = new_xyz[j * 3 + 0];
float new_y = new_xyz[j * 3 + 1];
float new_z = new_xyz[j * 3 + 2];
scalar_t new_x = new_xyz[j * 3 + 0];
scalar_t new_y = new_xyz[j * 3 + 1];
scalar_t new_z = new_xyz[j * 3 + 2];
for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k)
{
float x = xyz[k * 3 + 0];
float y = xyz[k * 3 + 1];
float z = xyz[k * 3 + 2];
float d2 =
scalar_t x = xyz[k * 3 + 0];
scalar_t y = xyz[k * 3 + 1];
scalar_t z = xyz[k * 3 + 2];
scalar_t d2 =
(new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
if (d2 < radius2)
{
Expand All @@ -51,13 +52,14 @@ __global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius,
}
}

template <typename scalar_t>
__global__ void query_ball_point_kernel_partial_dense(int size_x, int size_y, float radius,
int nsample, const float* __restrict__ x,
const float* __restrict__ y,
int nsample, const scalar_t* __restrict__ x,
const scalar_t* __restrict__ y,
const int64_t* __restrict__ batch_x,
const int64_t* __restrict__ batch_y,
int64_t* __restrict__ idx_out,
float* __restrict__ dist_out)
scalar_t* __restrict__ dist_out)
{
// taken from
// https://github.com/rusty1s/pytorch_cluster/blob/master/cuda/radius_kernel.cu
Expand All @@ -75,7 +77,7 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x, int size_y, fl
int64_t count = 0;
for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++)
{
float dist = 0;
scalar_t dist = 0;
for (ptrdiff_t d = 0; d < 3; d++)
{
dist += (x[n_x * 3 + d] - y[n_y * 3 + d]) * (x[n_x * 3 + d] - y[n_y * 3 + d]);
Expand All @@ -94,25 +96,77 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x, int size_y, fl
}
}

void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
const float* new_xyz, const float* xyz, int64_t* idx,
float* dist_out)
// Batched dense ball query.
// For each of the m query points in new_xyz (b, m, 3), gathers up to `nsample`
// neighbour indices from xyz (b, n, 3) lying within `radius`.
// Returns {idx (b, m, nsample) int64, dist (b, m, nsample) same dtype as xyz}.
// idx is zero-initialised (the kernel replicates the first hit); dist slots
// that receive no neighbour keep their -1 fill value.
std::pair<torch::Tensor, torch::Tensor> query_ball_point_kernel_dense_wrapper(float radius,
                                                                              int nsample,
                                                                              torch::Tensor new_xyz,
                                                                              torch::Tensor xyz)
{
    const int b = xyz.size(0);
    const int n = xyz.size(1);
    const int m = new_xyz.size(1);

    auto idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
                            torch::CUDA(torch::ScalarType::Long));
    auto dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1,
                            torch::CUDA(xyz.scalar_type()));

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // Dispatch on the point dtype so Half/Float/Double inputs are all supported
    // (the point of this PR: mixed-precision training).
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        new_xyz.scalar_type(), "query_ball_point_kernel_dense_cuda",
        (
            [&]
            {
                query_ball_point_kernel_dense<scalar_t><<<b, opt_n_threads(m), 0, stream>>>(
                    b, n, m, radius, nsample, new_xyz.data_ptr<scalar_t>(),
                    xyz.data_ptr<scalar_t>(), idx.data_ptr<int64_t>(), dist.data_ptr<scalar_t>());
            }));

    CUDA_CHECK_ERRORS();
    return std::make_pair(idx, dist);
}

void query_ball_point_kernel_partial_wrapper(int64_t batch_size, int size_x, int size_y,
float radius, int nsample, const float* x,
const float* y, const int64_t* batch_x,
const int64_t* batch_y, int64_t* idx_out,
float* dist_out)
// Node-degree histogram: out[i] = number of entries in `row` equal to i,
// for i in [0, num_nodes). Output shares `row`'s dtype/device options.
torch::Tensor degree(torch::Tensor row, int64_t num_nodes)
{
    auto zero = torch::zeros(num_nodes, row.options());
    auto one = torch::ones(row.size(0), row.options());
    // scatter_add_ accumulates a 1 at each index that appears in `row`.
    return zero.scatter_add_(0, row, one);
}

std::pair<torch::Tensor, torch::Tensor>
query_ball_point_kernel_partial_wrapper(float radius, int nsample, torch::Tensor x, torch::Tensor y,
torch::Tensor batch_x, torch::Tensor batch_y)
{

int size_x = x.size(0);
int size_y = y.size(0);
auto idx = torch::full({y.size(0), nsample}, -1, torch::CUDA(torch::ScalarType::Long));

auto dist = torch::full({y.size(0), nsample}, -1, torch::CUDA(y.scalar_type()));

cudaSetDevice(x.get_device());
auto batch_sizes = (int64_t*)malloc(sizeof(int64_t));
cudaMemcpy(batch_sizes, batch_x[-1].data_ptr<int64_t>(), sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto batch_size = batch_sizes[0] + 1;

batch_x = degree(batch_x, batch_size);
batch_x = torch::cat({torch::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
batch_y = degree(batch_y, batch_size);
batch_y = torch::cat({torch::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "query_ball_point_kernel_dense_cuda",
(
[&]
{
query_ball_point_kernel_partial_dense<scalar_t>
<<<batch_size, TOTAL_THREADS_SPARSE>>>(
size_x, size_y, radius, nsample, x.data_ptr<scalar_t>(),
y.data_ptr<scalar_t>(), batch_x.data_ptr<int64_t>(),
batch_y.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
dist.data_ptr<scalar_t>());
}));

CUDA_CHECK_ERRORS();

return std::make_pair(idx, dist);
}
30 changes: 15 additions & 15 deletions cuda/src/chamfer_dist_gpu.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include <THC/THCAtomics.cuh>
#include "cuda_utils.h"
#include <vector>

Expand Down Expand Up @@ -159,12 +159,12 @@ std::vector<torch::Tensor> chamfer_dist_kernel_wrapper(torch::Tensor xyz1, torch
const int batch_size = xyz1.size(0);
const int n = xyz1.size(1); // num_points point cloud A
const int m = xyz2.size(1); // num_points point cloud B
torch::Tensor dist1 = torch::zeros({batch_size, n}, torch::CUDA(xyz1.scalar_type()));
torch::Tensor dist2 = torch::zeros({batch_size, m}, torch::CUDA(xyz1.scalar_type()));
torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));
torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));
auto dist1 = torch::zeros({batch_size, n}, torch::CUDA(xyz1.scalar_type()));
auto dist2 = torch::zeros({batch_size, m}, torch::CUDA(xyz1.scalar_type()));
auto idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));
auto idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));

AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
xyz1.scalar_type(), "chamfer_dist_cuda", ([&] {
chamfer_dist_kernel<scalar_t><<<dim3(32, 16, 1), 512>>>(
batch_size, n, xyz1.data_ptr<scalar_t>(), m, xyz2.data_ptr<scalar_t>(),
Expand Down Expand Up @@ -202,12 +202,12 @@ __global__ void chamfer_dist_grad_kernel(int b, int n, const scalar_t* __restric
scalar_t y2 = xyz2[(i * m + j2) * 3 + 1];
scalar_t z2 = xyz2[(i * m + j2) * 3 + 2];
scalar_t g = grad_dist1[i * n + j] * 2;
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));
gpuAtomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
gpuAtomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
gpuAtomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
gpuAtomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));
gpuAtomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));
gpuAtomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));
}
}
}
Expand All @@ -220,10 +220,10 @@ std::vector<torch::Tensor> chamfer_dist_grad_kernel_wrapper(torch::Tensor xyz1,
const int batch_size = xyz1.size(0);
const int n = xyz1.size(1); // num_points point cloud A
const int m = xyz2.size(1); // num_points point cloud B
torch::Tensor grad_xyz1 = torch::zeros_like(xyz1);
torch::Tensor grad_xyz2 = torch::zeros_like(xyz2);
auto grad_xyz1 = torch::zeros_like(xyz1);
auto grad_xyz2 = torch::zeros_like(xyz2);

AT_DISPATCH_FLOATING_TYPES(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
xyz1.scalar_type(), "chamfer_dist_grad_cuda", ([&] {
chamfer_dist_grad_kernel<scalar_t><<<dim3(1, 16, 1), 256>>>(
batch_size, n, xyz1.data_ptr<scalar_t>(), m, xyz2.data_ptr<scalar_t>(),
Expand Down
Loading