From f489caa72d4b263771969a503f3deaab5a588bf0 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Jul 2025 01:32:17 +0000 Subject: [PATCH 1/4] fix linking error on non-blackwell devices Signed-off-by: Lucas Wilkinson --- csrc/attention/mla/sm100_cutlass_mla_kernel.cu | 6 ++++++ csrc/ops.h | 13 ------------- csrc/torch_bindings.cpp | 5 ++--- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 0d57ff4cc7c..8d2caa62ef7 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -270,4 +270,10 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba } #endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); + m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); +} + // clang-format on diff --git a/csrc/ops.h b/csrc/ops.h index 20ad163dc0d..7f3e6b6923a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -167,19 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& seq_lens, torch::Tensor const& page_table, double scale); -void sm100_cutlass_mla_decode( - torch::Tensor const& out, torch::Tensor const& q_nope, - torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, torch::Tensor const& page_table, - torch::Tensor const& workspace, double sm_scale, - int64_t num_kv_splits = - 1 /* Set to 1 to avoid cuda_graph issue by default. */); - -int64_t sm100_cutlass_mla_get_workspace_size( - int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, - int64_t num_kv_splits = - 1 /* Set to 1 to avoid cuda_graph issue by default. */); - torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor); #ifndef USE_ROCM diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 370edc20149..e2cb0012d0d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor page_table, Tensor workspace, float " "scale," " int num_kv_splits) -> ()"); - ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode); + // conditionally compiled so impl in source file // SM100 CUTLASS MLA workspace ops.def( "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches," " int sm_count, int num_kv_splits) " "-> int"); - ops.impl("sm100_cutlass_mla_get_workspace_size", - &sm100_cutlass_mla_get_workspace_size); + // conditionally compiled so impl in source file // Compute NVFP4 block quantized tensor. ops.def( From 4bb6e6e71254f340aedc87ead74f71e5bbf9a980 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Jul 2025 01:32:45 +0000 Subject: [PATCH 2/4] format Signed-off-by: Lucas Wilkinson --- csrc/torch_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e2cb0012d0d..23e9212a2f1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -521,14 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor page_table, Tensor workspace, float " "scale," " int num_kv_splits) -> ()"); - // conditionally compiled so impl in source file + // conditionally compiled so impl in source file // SM100 CUTLASS MLA workspace ops.def( "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches," " int sm_count, int num_kv_splits) " "-> int"); - // conditionally compiled so impl in source file + // conditionally compiled so impl in source file // Compute NVFP4 block quantized tensor. ops.def( From 7bfe03202cc830c7e94faffda11053a8f503f2db Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 15 Jul 2025 21:54:21 -0400 Subject: [PATCH 3/4] fix build Signed-off-by: Lucas Wilkinson --- csrc/attention/mla/sm100_cutlass_mla_kernel.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 8d2caa62ef7..80e2677f6d1 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -18,6 +18,7 @@ limitations under the License. * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 * by Alcanderian JieXin Liang */ +#include "core/registration.h" #include #include @@ -273,6 +274,9 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CPU, m) { m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); } From a41432f14c5ca7a5c6a12eb1d4239d165c58c2d3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 15 Jul 2025 21:59:02 -0400 Subject: [PATCH 4/4] no tensor args so use catchall Signed-off-by: Lucas Wilkinson --- csrc/attention/mla/sm100_cutlass_mla_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 80e2677f6d1..e0e95d06290 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -276,7 +276,7 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CPU, m) { +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) { m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); }