[BugFix] Fix import error on non-blackwell machines #21020


Merged: 4 commits, Jul 16, 2025
Changes from 2 commits
6 changes: 6 additions & 0 deletions csrc/attention/mla/sm100_cutlass_mla_kernel.cu
@@ -270,4 +270,10 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
}

#endif

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}
Contributor review comment (severity: high):

The function sm100_cutlass_mla_get_workspace_size is a host-side function that calculates a workspace size and does not involve any GPU operations. Registering it only for the CUDA dispatch key is incorrect and can lead to runtime errors if called in a context where the PyTorch dispatcher selects a different backend. Host-only functions like this should be registered for the CPU dispatch key to ensure they can be called correctly regardless of the context.

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CPU, m) {
  m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}


// clang-format on
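
For reference, here is a minimal standalone sketch of the reviewer's suggested pattern, using the stock TORCH_LIBRARY / TORCH_LIBRARY_IMPL macros from torch/library.h (the demo namespace, function name, and size formula are illustrative, not from this PR). Because the workspace-size op takes no Tensor arguments, there is nothing for the dispatcher to route to CUDA, which is why a CPU registration is the one that lets a plain call succeed:

#include <torch/library.h>

// Pure host-side arithmetic; no GPU work is involved.
int64_t demo_workspace_size(int64_t max_seq_len, int64_t num_batches) {
  return max_seq_len * num_batches * 128;  // illustrative formula
}

TORCH_LIBRARY(demo, m) {
  // Declare the schema; "int" in the schema maps to int64_t in C++.
  m.def("demo_workspace_size(int max_seq_len, int num_batches) -> int");
}

// Register under the CPU dispatch key so the op is callable from any
// context, e.g. torch.ops.demo.demo_workspace_size(4096, 8) in Python.
TORCH_LIBRARY_IMPL(demo, CPU, m) {
  m.impl("demo_workspace_size", &demo_workspace_size);
}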
13 changes: 0 additions & 13 deletions csrc/ops.h
@@ -167,19 +167,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table, double scale);

void sm100_cutlass_mla_decode(
torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens, torch::Tensor const& page_table,
torch::Tensor const& workspace, double sm_scale,
int64_t num_kv_splits =
1 /* Set to 1 to avoid cuda_graph issue by default. */);

int64_t sm100_cutlass_mla_get_workspace_size(
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
int64_t num_kv_splits =
1 /* Set to 1 to avoid cuda_graph issue by default. */);

torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

#ifndef USE_ROCM
5 changes: 2 additions & 3 deletions csrc/torch_bindings.cpp
@@ -521,15 +521,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor page_table, Tensor workspace, float "
"scale,"
" int num_kv_splits) -> ()");
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
// conditionally compiled so impl in source file

// SM100 CUTLASS MLA workspace
ops.def(
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
" int sm_count, int num_kv_splits) "
"-> int");
ops.impl("sm100_cutlass_mla_get_workspace_size",
&sm100_cutlass_mla_get_workspace_size);
// conditionally compiled so impl in source file

// Compute NVFP4 block quantized tensor.
ops.def(
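
The two "conditionally compiled so impl in source file" comments capture the shape of the fix: the schema (ops.def) is always registered in torch_bindings.cpp, while the implementation registration now lives in sm100_cutlass_mla_kernel.cu, which is only built for SM100 (Blackwell) targets. On other machines the op still exists but has no kernel, so the extension imports cleanly and a call raises NotImplementedError instead of the whole import failing. A minimal sketch of that def/impl split (the demo namespace, maybe_op, and the ENABLE_SM100 guard are illustrative; in the PR the condition is applied by the build system to the whole source file rather than by a preprocessor guard):

#include <ATen/ATen.h>
#include <torch/library.h>

// Always compiled (cf. torch_bindings.cpp): declare the schema so the
// op name and signature exist in every build.
TORCH_LIBRARY(demo, m) {
  m.def("maybe_op(Tensor x) -> Tensor");
}

// Only compiled when the kernel is available (cf. the .cu file):
// attach the actual implementation to the already-declared schema.
#ifdef ENABLE_SM100
at::Tensor maybe_op_impl(const at::Tensor& x) { return x * 2; }

TORCH_LIBRARY_IMPL(demo, CUDA, m) {
  m.impl("maybe_op", &maybe_op_impl);
}
#endif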