Skip to content

Commit d31a647

Browse files
[BugFix] Fix import error on non-blackwell machines (#21020)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 85431bd commit d31a647

File tree

3 files changed

+12
-16
lines changed

3 files changed

+12
-16
lines changed

csrc/attention/mla/sm100_cutlass_mla_kernel.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
1919
* by Alcanderian JieXin Liang
2020
*/
21+
#include "core/registration.h"
2122

2223
#include <ATen/cuda/CUDAContext.h>
2324
#include <c10/cuda/CUDAGuard.h>
@@ -270,4 +271,13 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
270271
}
271272

272273
#endif
274+
275+
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
276+
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
277+
}
278+
279+
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
280+
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
281+
}
282+
273283
// clang-format on

csrc/ops.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -167,19 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
167167
torch::Tensor const& seq_lens,
168168
torch::Tensor const& page_table, double scale);
169169

170-
void sm100_cutlass_mla_decode(
171-
torch::Tensor const& out, torch::Tensor const& q_nope,
172-
torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
173-
torch::Tensor const& seq_lens, torch::Tensor const& page_table,
174-
torch::Tensor const& workspace, double sm_scale,
175-
int64_t num_kv_splits =
176-
1 /* Set to 1 to avoid cuda_graph issue by default. */);
177-
178-
int64_t sm100_cutlass_mla_get_workspace_size(
179-
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
180-
int64_t num_kv_splits =
181-
1 /* Set to 1 to avoid cuda_graph issue by default. */);
182-
183170
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
184171

185172
#ifndef USE_ROCM

csrc/torch_bindings.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
521521
" Tensor page_table, Tensor workspace, float "
522522
"scale,"
523523
" int num_kv_splits) -> ()");
524-
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
524+
// conditionally compiled so impl in source file
525525

526526
// SM100 CUTLASS MLA workspace
527527
ops.def(
528528
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
529529
" int sm_count, int num_kv_splits) "
530530
"-> int");
531-
ops.impl("sm100_cutlass_mla_get_workspace_size",
532-
&sm100_cutlass_mla_get_workspace_size);
531+
// conditionally compiled so impl in source file
533532

534533
// Compute NVFP4 block quantized tensor.
535534
ops.def(

0 commit comments

Comments
 (0)