feat: LPC CUDA kernel #24

Merged
merged 4 commits into from May 29, 2025
26 changes: 19 additions & 7 deletions tests/test_extension.py
@@ -64,20 +64,32 @@ def test_scan_equiv(samples: int, cmplx: bool, device: str):
).item()


@pytest.mark.parametrize(
"samples",
[1024],
)
@pytest.mark.parametrize("samples", [1021, 4097])
@pytest.mark.parametrize(
"cmplx",
[True, False],
)
def test_lpc_equiv(samples: int, cmplx: bool):
@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
],
)
def test_lpc_equiv(samples: int, cmplx: bool, device: str):
batch_size = 4
x, A, zi = tuple(
x.to("cpu") for x in create_test_inputs(batch_size, samples, cmplx)
x.to(device) for x in create_test_inputs(batch_size, samples, cmplx)
)
numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
if device == "cuda":
numba_y = lpc_cuda(x, A, zi)
else:
numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy()))
ext_y = torch.ops.torchlpc.lpc(x, A, zi)

assert torch.allclose(numba_y, ext_y)
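For reviewers who want to exercise the new CUDA branch outside of pytest, here is a minimal standalone sketch along the lines of the parametrized test above. It assumes that `create_test_inputs` is importable from this test module and that importing `torchlpc` registers the extension op; adjust imports to your layout.

```python
import torch
import torchlpc  # assumed to register torch.ops.torchlpc.lpc when the extension is built
from torchlpc.core import lpc_cuda  # Numba reference used by the test's CUDA branch

if torch.cuda.is_available():
    from tests.test_extension import create_test_inputs  # assumed location of the helper

    # Mirrors the "cuda" case of test_lpc_equiv: batch of 4, 4097 samples, real-valued
    x, A, zi = (t.to("cuda") for t in create_test_inputs(4, 4097, False))
    numba_y = lpc_cuda(x, A, zi)              # existing Numba CUDA implementation
    ext_y = torch.ops.torchlpc.lpc(x, A, zi)  # new lpc.cu kernel via the extension
    assert torch.allclose(numba_y, ext_y)
```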
19 changes: 10 additions & 9 deletions torchlpc/core.py
@@ -159,20 +159,21 @@ def lpc_np(x: np.ndarray, A: np.ndarray, zi: np.ndarray) -> np.ndarray:
class LPC(Function):
@staticmethod
def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor:
if x.is_cuda:
y = lpc_cuda(x.detach(), A.detach(), zi.detach())
elif EXTENSION_LOADED:
if EXTENSION_LOADED:
y = torch.ops.torchlpc.lpc(x, A, zi)
else:
warnings.warn(
"Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0."
)
y = lpc_np(
x.detach().cpu().numpy(),
A.detach().cpu().numpy(),
zi.detach().cpu().numpy(),
)
y = torch.from_numpy(y).to(x.device, x.dtype)
if x.is_cuda:
y = lpc_cuda(x.detach(), A.detach(), zi.detach())
else:
y = lpc_np(
x.detach().cpu().numpy(),
A.detach().cpu().numpy(),
zi.detach().cpu().numpy(),
)
y = torch.from_numpy(y).to(x.device, x.dtype)
return y

@staticmethod
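Note on the new dispatch order in `LPC.forward`: the compiled extension op is now tried first on every device, and the Numba paths have become fallbacks — `lpc_cuda` for CUDA tensors and `lpc_np` (with the host round-trip through NumPy) for CPU tensors — used only when the extension failed to load.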
177 changes: 177 additions & 0 deletions torchlpc/csrc/cuda/lpc.cu
@@ -0,0 +1,177 @@
#include <assert.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include <torch/script.h>
#include <torch/torch.h>

// CUDA kernel for LPC computation
template <typename scalar_t>
__global__ void lpc_cuda_kernel(scalar_t* padded_y, // [B, T + order]
const scalar_t* A, // [B, T, order]
int64_t B, int64_t T, int64_t order) {
extern __shared__ char smem[];
scalar_t* sm = reinterpret_cast<scalar_t*>(smem);

int b = blockIdx.x;
int i = threadIdx.x;

if (b >= B || i >= order) return;

// Initialize shared memory with the first 'order' elements
sm[i] = padded_y[b * (T + order) + i];
__syncthreads();

int circular_idx = 0;
for (int t = 0; t < T; ++t) {
circular_idx = t % order;
scalar_t a = -A[((b * T + t) * order) + i];

// Compute s as in the Python code
int idx_offset = circular_idx - i - 1;
if (i > circular_idx - 1) {
idx_offset += order;
}
scalar_t s = sm[(idx_offset + order) % order];

scalar_t v = a * s;

if (i == order - 1) {
sm[circular_idx] = v;
v = padded_y[b * (T + order) + t + order];
}
__syncthreads();

// Atomic add to shared memory
atomicAdd(&sm[circular_idx], v);
__syncthreads();

if (i == order - 1) {
padded_y[b * (T + order) + t + order] = sm[circular_idx];
}
__syncthreads();
}
}
// CUDA kernel for complex LPC computation
template <typename scalar_t>
__global__ void lpc_cuda_kernel_complex(
scalar_t* padded_y_real, // [B, T + order]
scalar_t* padded_y_imag, // [B, T + order]
const scalar_t* A_real, // [B, T, order]
const scalar_t* A_imag, // [B, T, order]
int64_t B, int64_t T, int64_t order) {
extern __shared__ char smem[];
scalar_t* sm_real = reinterpret_cast<scalar_t*>(smem);
scalar_t* sm_imag = sm_real + order;

int b = blockIdx.x;
int i = threadIdx.x;

if (b >= B || i >= order) return;

// Initialize shared memory with the first 'order' elements
sm_real[i] = padded_y_real[b * (T + order) + i];
sm_imag[i] = padded_y_imag[b * (T + order) + i];
__syncthreads();

int circular_idx = 0;
for (int t = 0; t < T; ++t) {
circular_idx = t % order;
scalar_t a_real = -A_real[((b * T + t) * order) + i];
scalar_t a_imag = -A_imag[((b * T + t) * order) + i];

int idx_offset = circular_idx - i - 1;
if (i > circular_idx - 1) {
idx_offset += order;
}
int s_idx = (idx_offset + order) % order;
scalar_t s_real = sm_real[s_idx];
scalar_t s_imag = sm_imag[s_idx];

// Complex multiply: v = a * s
scalar_t v_real = a_real * s_real - a_imag * s_imag;
scalar_t v_imag = a_real * s_imag + a_imag * s_real;

if (i == order - 1) {
sm_real[circular_idx] = v_real;
sm_imag[circular_idx] = v_imag;
v_real = padded_y_real[b * (T + order) + t + order];
v_imag = padded_y_imag[b * (T + order) + t + order];
}
__syncthreads();

atomicAdd(&sm_real[circular_idx], v_real);
atomicAdd(&sm_imag[circular_idx], v_imag);
__syncthreads();

if (i == order - 1) {
padded_y_real[b * (T + order) + t + order] = sm_real[circular_idx];
padded_y_imag[b * (T + order) + t + order] = sm_imag[circular_idx];
}
__syncthreads();
}
}

at::Tensor lpc_cuda_wrapper(const at::Tensor& x, const at::Tensor& a,
const at::Tensor& zi) {
TORCH_CHECK(x.is_floating_point() || x.is_complex(),
"Input must be floating point or complex");
TORCH_CHECK(a.scalar_type() == x.scalar_type(),
"Coefficients must have the same scalar type as input");
TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
"Initial conditions must have the same scalar type as input");

TORCH_CHECK(x.dim() == 2, "Input must be 2D");
TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
TORCH_CHECK(x.size(0) == zi.size(0),
"Batch size of input and initial conditions must match");

const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

auto a_contiguous = a.contiguous();

at::Tensor out;
auto order = a_contiguous.size(2);
assert(order <= 1024 && "LPC order must be less than or equal to 1024");
auto threads_per_block = order;

if (x.is_floating_point()) {
out = at::cat({zi.flip(1), x}, 1).contiguous();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "lpc_cuda", [&] {
auto padded_y = out.mutable_data_ptr<scalar_t>();
auto A = a_contiguous.const_data_ptr<scalar_t>();
auto B = x.size(0);
auto T = x.size(1);

lpc_cuda_kernel<scalar_t><<<B, threads_per_block,
threads_per_block * sizeof(scalar_t)>>>(
padded_y, A, B, T, order);
});
} else {
auto out_real =
at::cat({at::real(zi).flip(1), at::real(x)}, 1).contiguous();
auto out_imag =
at::cat({at::imag(zi).flip(1), at::imag(x)}, 1).contiguous();
auto a_real = at::real(a_contiguous).contiguous();
auto a_imag = at::imag(a_contiguous).contiguous();
AT_DISPATCH_FLOATING_TYPES(
out_real.scalar_type(), "lpc_cuda_complex", [&] {
auto padded_y_real = out_real.mutable_data_ptr<scalar_t>();
auto padded_y_imag = out_imag.mutable_data_ptr<scalar_t>();
auto A_real = a_real.const_data_ptr<scalar_t>();
auto A_imag = a_imag.const_data_ptr<scalar_t>();
auto B = x.size(0);
auto T = x.size(1);

lpc_cuda_kernel_complex<scalar_t>
<<<B, threads_per_block,
2 * threads_per_block * sizeof(scalar_t)>>>(
padded_y_real, padded_y_imag, A_real, A_imag, B, T,
order);
});
out = at::view_as_complex(at::stack({out_real, out_imag}, -1));
}
return out.slice(1, order, out.size(1)).contiguous();
}

TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("lpc", &lpc_cuda_wrapper); }
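Both kernels parallelize the same time-varying all-pole recurrence over the coefficient taps: each block handles one batch element, each of its `order` threads owns one tap, the last `order` outputs sit in a circular shared-memory buffer, and the per-tap products are accumulated with `atomicAdd` before the designated thread writes the finished sample back to `padded_y`. As I read the kernel, the recurrence it computes is the one below; this is an illustrative NumPy sketch, not code from the PR.

```python
import numpy as np

def lpc_reference(x, A, zi):
    """y[b, t] = x[b, t] - sum_i A[b, t, i] * y[b, t - 1 - i]  (sketch of the kernel's math)."""
    B, T = x.shape
    order = A.shape[2]
    # Same layout as lpc_cuda_wrapper: flipped initial conditions prepended to the signal.
    padded_y = np.concatenate([zi[:, ::-1], x], axis=1)
    for b in range(B):
        for t in range(T):
            acc = padded_y[b, t + order]  # starts as x[b, t]
            for i in range(order):
                acc -= A[b, t, i] * padded_y[b, t + order - 1 - i]
            padded_y[b, t + order] = acc
    return padded_y[:, order:]
```

One thread per tap also explains the `order <= 1024` assert in the wrapper: the tap count is used directly as the block size, and CUDA caps a block at 1024 threads.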
62 changes: 38 additions & 24 deletions torchlpc/recurrence.py
@@ -8,30 +8,48 @@
from .core import lpc_cuda, lpc_np
from . import EXTENSION_LOADED

if EXTENSION_LOADED:
lpc_cuda_runner = torch.ops.torchlpc.lpc
lpc_cpu_runner = torch.ops.torchlpc.lpc

scan_cuda_runner = torch.ops.torchlpc.scan
scan_cpu_runner = torch.ops.torchlpc.scan
else:
lpc_cuda_runner = lpc_cuda
lpc_cpu_runner = lambda x, A, zi: torch.from_numpy(
lpc_np(x.detach().numpy(), A.detach().numpy(), zi.detach().numpy())
)

scan_cuda_runner = lambda impulse, decay, initial_state: (
lambda out: (
out,
compute_linear_recurrence(
cuda.as_cuda_array(decay.detach()),
cuda.as_cuda_array(impulse.detach()),
cuda.as_cuda_array(initial_state.detach()),
cuda.as_cuda_array(out),
decay.shape[0],
decay.shape[1],
),
)
)(torch.empty_like(impulse))[0]
scan_cpu_runner = lambda impulse, decay, initial_state: torch.from_numpy(
lpc_np(
impulse.detach().numpy(),
-decay.unsqueeze(2).detach().numpy(),
initial_state.unsqueeze(1).detach().numpy(),
)
)


def _cuda_recurrence(
impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
) -> torch.Tensor:
n_dims, n_steps = decay.shape
if n_dims * WARPSIZE < n_steps:
if EXTENSION_LOADED:
runner = torch.ops.torchlpc.scan
else:

def runner(impulse, decay, initial_state):
out = torch.empty_like(impulse)
compute_linear_recurrence(
cuda.as_cuda_array(decay.detach()),
cuda.as_cuda_array(impulse.detach()),
cuda.as_cuda_array(initial_state.detach()),
cuda.as_cuda_array(out),
n_dims,
n_steps,
)
return out

runner = scan_cuda_runner
else:
runner = lambda impulse, decay, initial_state: lpc_cuda(
runner = lambda impulse, decay, initial_state: lpc_cuda_runner(
impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)
)
return runner(impulse, decay, initial_state)
@@ -44,14 +62,10 @@ def _cpu_recurrence(
n_dims, _ = decay.shape
# This is just a rough estimation of the computational cost
if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
runner = torch.ops.torchlpc.scan
runner = scan_cpu_runner
else:
runner = lambda impulse, decay, initial_state: torch.from_numpy(
lpc_np(
impulse.detach().numpy(),
-decay.unsqueeze(2).detach().numpy(),
initial_state.unsqueeze(1).detach().numpy(),
)
runner = lambda impulse, decay, initial_state: lpc_cpu_runner(
impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)
)
return runner(impulse, decay, initial_state)
