Skip to content

Commit 70870d9

Browse files
committed
[Attention] MLA - cutlass decode with unresticted num_heads
1 parent 9907fc4 commit 70870d9

File tree

13 files changed

+3301
-1
lines changed

13 files changed

+3301
-1
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
563563
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
564564
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
565565
set(SRCS
566-
"csrc/attention/mla/cutlass_mla_kernels.cu")
566+
"csrc/attention/mla/cutlass_mla_kernels.cu"
567+
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
567568
set_gencode_flags_for_srcs(
568569
SRCS "${SRCS}"
569570
CUDA_ARCHS "${MLA_ARCHS}")
Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice,
9+
*this list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22+
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23+
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24+
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25+
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26+
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27+
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28+
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29+
*POSSIBILITY OF SUCH DAMAGE.
30+
*
31+
**************************************************************************************************/
32+
/*
33+
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
34+
* by Alcanderian JieXin Liang
35+
*/
36+
37+
/*!
38+
\file
39+
\brief An universal device layer for cutlass 3.x-style kernels.
40+
*/
41+
42+
// clang-format off
43+
#pragma once
44+
45+
// common
46+
#include "cutlass/cutlass.h"
47+
#include "cutlass/device_kernel.h"
48+
49+
#if !defined(__CUDACC_RTC__)
50+
#include "cutlass/cluster_launch.hpp"
51+
#include "cutlass/trace.h"
52+
#endif // !defined(__CUDACC_RTC__)
53+
54+
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
55+
#include "../kernel/sm100_fmha_mla_reduction.hpp"
56+
57+
////////////////////////////////////////////////////////////////////////////////
58+
59+
namespace cutlass::fmha::device {
60+
61+
using namespace cute;
62+
using namespace cutlass::fmha::kernel;
63+
64+
65+
////////////////////////////////////////////////////////////////////////////////
66+
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
67+
////////////////////////////////////////////////////////////////////////////////
68+
69+
template<
70+
class Kernel_
71+
>
72+
class MLA {
73+
public:
74+
75+
using Kernel = Kernel_;
76+
77+
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
78+
typename Kernel::ElementOut,
79+
typename Kernel::ElementAcc,
80+
typename Kernel::ElementAcc,
81+
Kernel::TileShapeH::value,
82+
Kernel::TileShapeL::value,
83+
256 /*Max split*/
84+
>;
85+
86+
/// Argument structure: User API
87+
using KernelArguments = typename Kernel::Arguments;
88+
using ReductionArguments = typename ReductionKernel::Arguments;
89+
90+
using Arguments = KernelArguments;
91+
92+
/// Argument structure: Kernel API
93+
using KernelParams = typename Kernel::Params;
94+
using ReductionParams = typename ReductionKernel::Params;
95+
struct Params {
96+
KernelParams fmha_params;
97+
ReductionParams reduction_params;
98+
};
99+
100+
private:
101+
102+
/// Kernel API parameters object
103+
Params params_;
104+
105+
bool is_initialized(bool set = false) {
106+
static bool initialized = false;
107+
if (set) initialized = true;
108+
return initialized;
109+
}
110+
111+
static ReductionArguments to_reduction_args(Arguments const& args) {
112+
auto [H, K, D, B] = args.problem_shape;
113+
return ReductionArguments{
114+
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
115+
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
116+
args.ptr_split_kv, Kernel::TileShapeS::value
117+
};
118+
}
119+
120+
public:
121+
122+
/// Access the Params structure
123+
Params const& params() const {
124+
return params_;
125+
}
126+
127+
static void set_split_kv (KernelArguments& args) {
128+
if (args.split_kv >= 1) return;
129+
auto [H, K, D, B] = args.problem_shape;
130+
int sm_count = args.hw_info.sm_count;
131+
int max_splits = ceil_div(K, 128);
132+
int sms_per_batch = max(1, sm_count / B);
133+
int split_heur = min(max_splits, sms_per_batch);
134+
int waves = ceil_div(B * split_heur, sm_count);
135+
int k_waves = ceil_div(max_splits, split_heur);
136+
int split_wave_aware = ceil_div(max_splits, k_waves);
137+
args.split_kv = split_wave_aware;
138+
}
139+
140+
/// Determines whether the GEMM can execute the given problem.
141+
static Status
142+
can_implement(Arguments const& args) {
143+
if (! Kernel::can_implement(args)) {
144+
return Status::kInvalid;
145+
}
146+
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
147+
return Status::kInvalid;
148+
}
149+
return Status::kSuccess;
150+
}
151+
152+
/// Gets the workspace size
153+
static size_t
154+
get_workspace_size(Arguments const& args) {
155+
size_t workspace_bytes = 0;
156+
workspace_bytes += Kernel::get_workspace_size(args);
157+
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
158+
return workspace_bytes;
159+
}
160+
161+
/// Computes the maximum number of active blocks per multiprocessor
162+
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
163+
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
164+
int max_active_blocks = -1;
165+
int smem_size = Kernel::SharedStorageSize;
166+
167+
// first, account for dynamic smem capacity if needed
168+
cudaError_t result;
169+
if (smem_size >= (48 << 10)) {
170+
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
171+
result = cudaFuncSetAttribute(
172+
device_kernel<Kernel>,
173+
cudaFuncAttributeMaxDynamicSharedMemorySize,
174+
smem_size);
175+
if (cudaSuccess != result) {
176+
result = cudaGetLastError(); // to clear the error bit
177+
CUTLASS_TRACE_HOST(
178+
" cudaFuncSetAttribute() returned error: "
179+
<< cudaGetErrorString(result));
180+
return -1;
181+
}
182+
}
183+
184+
// query occupancy after setting smem size
185+
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
186+
&max_active_blocks,
187+
device_kernel<Kernel>,
188+
Kernel::MaxThreadsPerBlock,
189+
smem_size);
190+
191+
if (cudaSuccess != result) {
192+
result = cudaGetLastError(); // to clear the error bit
193+
CUTLASS_TRACE_HOST(
194+
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
195+
<< cudaGetErrorString(result));
196+
return -1;
197+
}
198+
199+
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
200+
return max_active_blocks;
201+
}
202+
203+
/// Initializes GEMM state from arguments.
204+
Status
205+
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
206+
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
207+
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
208+
209+
// Initialize the workspace
210+
Status status = Kernel::initialize_workspace(args, workspace, stream);
211+
if (status != Status::kSuccess) {
212+
return status;
213+
}
214+
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
215+
if (status != Status::kSuccess) {
216+
return status;
217+
}
218+
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
219+
220+
ReductionArguments reduction_args = to_reduction_args(args);
221+
if (reduction_args.split_kv > 1) {
222+
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
223+
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
224+
}
225+
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
226+
// Initialize the Params structure
227+
params_ = Params {kernel_params, reduction_params};
228+
229+
if (is_initialized()) return Status::kSuccess;
230+
231+
// account for dynamic smem capacity if needed
232+
// no dynamic smem is needed for reduction kernel
233+
int smem_size = Kernel::SharedStorageSize;
234+
if (smem_size >= (48 << 10)) {
235+
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
236+
cudaError_t result = cudaFuncSetAttribute(
237+
device_kernel<Kernel>,
238+
cudaFuncAttributeMaxDynamicSharedMemorySize,
239+
smem_size);
240+
if (cudaSuccess != result) {
241+
result = cudaGetLastError(); // to clear the error bit
242+
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
243+
return Status::kErrorInternal;
244+
}
245+
}
246+
247+
is_initialized(true);
248+
249+
return Status::kSuccess;
250+
}
251+
252+
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
253+
Status
254+
update(Arguments const& args, void* workspace = nullptr) {
255+
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
256+
257+
size_t workspace_bytes = get_workspace_size(args);
258+
if (workspace_bytes > 0 && nullptr == workspace) {
259+
return Status::kErrorWorkspaceNull;
260+
}
261+
262+
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
263+
264+
ReductionArguments reduction_args = to_reduction_args(args);
265+
if (reduction_args.split_kv > 1) {
266+
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
267+
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
268+
}
269+
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
270+
// Initialize the Params structure
271+
params_ = Params {fmha_params, reduction_params};
272+
273+
return Status::kSuccess;
274+
}
275+
276+
/// Primary run() entry point API that is static allowing users to create and manage their own params.
277+
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
278+
static Status
279+
run(Params& params, cudaStream_t stream = nullptr) {
280+
CUTLASS_TRACE_HOST("MLA::run()");
281+
dim3 const block = Kernel::get_block_shape();
282+
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
283+
284+
// configure smem size and carveout
285+
int smem_size = Kernel::SharedStorageSize;
286+
287+
Status launch_result;
288+
// Use extended launch API only for mainloops that use it
289+
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
290+
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
291+
cute::size<1>(typename Kernel::ClusterShape{}),
292+
cute::size<2>(typename Kernel::ClusterShape{}));
293+
void const* kernel = (void const*) device_kernel<Kernel>;
294+
void* kernel_params[] = {&params.fmha_params};
295+
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
296+
}
297+
else {
298+
launch_result = Status::kSuccess;
299+
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
300+
}
301+
302+
cudaError_t result = cudaGetLastError();
303+
if (cudaSuccess != result or Status::kSuccess != launch_result) {
304+
//return Status::kSuccess;
305+
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
306+
return Status::kErrorInternal;
307+
}
308+
if (params.reduction_params.split_kv > 1) {
309+
// launch reduction kernel
310+
dim3 const block = ReductionKernel::get_block_shape();
311+
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
312+
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
313+
cudaError_t result = cudaGetLastError();
314+
if (cudaSuccess == result) {
315+
return Status::kSuccess;
316+
}
317+
else {
318+
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
319+
return Status::kErrorInternal;
320+
}
321+
}
322+
else {
323+
return Status::kSuccess;
324+
}
325+
}
326+
327+
//
328+
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
329+
//
330+
331+
/// Launches the kernel after first constructing Params internal state from supplied arguments.
332+
Status
333+
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
334+
Status status = initialize(args, workspace, stream);
335+
if (Status::kSuccess == status) {
336+
status = run(params_, stream);
337+
}
338+
return status;
339+
}
340+
341+
/// Launches the kernel after first constructing Params internal state from supplied arguments.
342+
Status
343+
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
344+
return run(args, workspace, stream);
345+
}
346+
347+
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
348+
Status
349+
run(cudaStream_t stream = nullptr) {
350+
return run(params_, stream);
351+
}
352+
353+
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
354+
Status
355+
operator()(cudaStream_t stream = nullptr) {
356+
return run(params_, stream);
357+
}
358+
};
359+
360+
////////////////////////////////////////////////////////////////////////////////
361+
362+
} // namespace cutlass::fmha::device
363+
364+
////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)