Skip to content

Commit 528d3c8

Browse files
committed
cutlass mla decode from sglang
1 parent fcd8856 commit 528d3c8

File tree

14 files changed

+3269
-2
lines changed

14 files changed

+3269
-2
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
563563
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
564564
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
565565
set(SRCS
566-
"csrc/attention/mla/cutlass_mla_kernels.cu")
566+
"csrc/attention/mla/cutlass_mla_kernels.cu"
567+
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
567568
set_gencode_flags_for_srcs(
568569
SRCS "${SRCS}"
569570
CUDA_ARCHS "${MLA_ARCHS}")
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice,
9+
*this list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22+
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23+
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24+
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25+
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26+
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27+
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28+
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29+
*POSSIBILITY OF SUCH DAMAGE.
30+
*
31+
**************************************************************************************************/
32+
/*!
  \file
  \brief A universal device layer for CUTLASS 3.x-style MLA kernels.
*/
36+
37+
// clang-format off
38+
#pragma once
39+
40+
// common
41+
#include "cutlass/cutlass.h"
42+
#include "cutlass/device_kernel.h"
43+
44+
#if !defined(__CUDACC_RTC__)
45+
#include "cutlass/cluster_launch.hpp"
46+
#include "cutlass/trace.h"
47+
#endif // !defined(__CUDACC_RTC__)
48+
49+
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
50+
#include "../kernel/sm100_fmha_mla_reduction.hpp"
51+
52+
////////////////////////////////////////////////////////////////////////////////
53+
54+
namespace cutlass::fmha::device {
55+
56+
using namespace cute;
57+
using namespace cutlass::fmha::kernel;
58+
59+
60+
////////////////////////////////////////////////////////////////////////////////
61+
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
62+
////////////////////////////////////////////////////////////////////////////////
63+
64+
template<
  class Kernel_
>
class MLA {
public:

  using Kernel = Kernel_;

  /// Reduction kernel that merges the per-split partial outputs (accumulator
  /// and LSE) produced by the main kernel into the final result when
  /// split-KV decoding is enabled (split_kv > 1).
  using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
    typename Kernel::ElementOut,
    typename Kernel::ElementAcc,
    typename Kernel::ElementAcc,
    Kernel::TileShapeH::value,
    Kernel::TileShapeL::value,
    256 /*Max split*/
  >;

  /// Argument structure: User API
  using KernelArguments = typename Kernel::Arguments;
  using ReductionArguments = typename ReductionKernel::Arguments;

  using Arguments = KernelArguments;

  /// Argument structure: Kernel API
  using KernelParams = typename Kernel::Params;
  using ReductionParams = typename ReductionKernel::Params;

  /// Combined kernel-API parameters for the main FMHA kernel and the
  /// split-KV reduction kernel.
  struct Params {
    KernelParams fmha_params;
    ReductionParams reduction_params;
  };

private:

  /// Kernel API parameters object
  Params params_;

  /// One-shot flag guarding the cudaFuncSetAttribute() call in initialize().
  /// NOTE(review): the function-local `static` is shared by every MLA object
  /// of the same template instantiation and is not thread-safe. This is
  /// harmless here because the guarded call is idempotent per kernel
  /// function, but confirm before relying on per-instance semantics.
  bool is_initialized(bool set = false) {
    static bool initialized = false;
    if (set) initialized = true;
    return initialized;
  }

  /// Translates the user-facing MLA arguments into reduction-kernel
  /// arguments. The accumulator pointers are intentionally null here;
  /// initialize()/update() patch them from the kernel params whenever
  /// split_kv > 1.
  static ReductionArguments to_reduction_args(Arguments const& args) {
    auto [H, K, D, B] = args.problem_shape;
    return ReductionArguments{
      nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
      args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
      args.ptr_split_kv, Kernel::TileShapeS::value
    };
  }

public:

  /// Access the Params structure
  Params const& params() const {
    return params_;
  }

  /// Heuristically chooses a split-KV factor when the caller has not set one
  /// (args.split_kv < 1). The heuristic balances the maximum number of
  /// 128-token splits against the SMs available per batch, then rounds so
  /// the per-batch work divides into whole waves.
  static void set_split_kv(KernelArguments& args) {
    // Respect an explicit user-provided split factor.
    if (args.split_kv >= 1) return;
    auto [H, K, D, B] = args.problem_shape;
    int sm_count = args.hw_info.sm_count;
    // Upper bound: one split per 128 KV tokens.
    int max_splits = ceil_div(K, 128);
    int sms_per_batch = max(1, sm_count / B);
    int split_heur = min(max_splits, sms_per_batch);
    // Round the split count so that each batch's splits form whole waves.
    int k_waves = ceil_div(max_splits, split_heur);
    int split_wave_aware = ceil_div(max_splits, k_waves);
    args.split_kv = split_wave_aware;
  }

  /// Determines whether both the main kernel and the reduction kernel can
  /// execute the given problem.
  static Status
  can_implement(Arguments const& args) {
    if (!Kernel::can_implement(args)) {
      return Status::kInvalid;
    }
    if (!ReductionKernel::can_implement(to_reduction_args(args))) {
      return Status::kInvalid;
    }
    return Status::kSuccess;
  }

  /// Gets the total workspace size: main kernel plus reduction kernel.
  static size_t
  get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    workspace_bytes += Kernel::get_workspace_size(args);
    workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
    return workspace_bytes;
  }

  /// Computes the maximum number of active blocks per multiprocessor for the
  /// main kernel. Returns -1 on any CUDA API failure (after clearing the
  /// sticky error).
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = Kernel::SharedStorageSize;

    // Kernels needing more than 48 KB of dynamic shared memory must opt in
    // explicitly before occupancy can be queried.
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
            "  cudaFuncSetAttribute() returned error: "
            << cudaGetErrorString(result));
        return -1;
      }
    }

    // Query occupancy after setting smem size.
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        device_kernel<Kernel>,
        Kernel::MaxThreadsPerBlock,
        smem_size);

    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      CUTLASS_TRACE_HOST(
          "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
          << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }

  /// Initializes the internal Params from user arguments: initializes both
  /// kernels' workspaces, builds the kernel-API params, patches the
  /// reduction kernel's accumulator pointers when split-KV is used, and
  /// (once per process, per instantiation) raises the dynamic smem limit.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
        << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace.
    // NOTE(review): the same base pointer is handed to both kernels even
    // though get_workspace_size() sums their requirements — presumably one
    // of them needs no workspace of its own; confirm against the kernel
    // implementations.
    Status status = Kernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);

    // With split-KV, the reduction reads the main kernel's accumulator and
    // LSE scratch buffers, whose addresses only exist in the kernel params.
    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure.
    params_ = Params{kernel_params, reduction_params};

    if (is_initialized()) return Status::kSuccess;

    // Account for dynamic smem capacity if needed.
    // No dynamic smem is needed for the reduction kernel.
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      cudaError_t result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    is_initialized(true);

    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight
  /// update of params. Rebuilds both kernels' params from fresh arguments
  /// without re-initializing workspaces.
  Status
  update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    auto fmha_params = Kernel::to_underlying_arguments(args, workspace);

    // Same accumulator-pointer patching as in initialize().
    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure.
    params_ = Params{fmha_params, reduction_params};

    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static, allowing users to create
  /// and manage their own params. Supplied params struct must be constructed
  /// by calling Kernel::to_underlying_arguments(). Launches the main kernel,
  /// then, if split-KV is active, the reduction kernel on the same stream.
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);

    // Configure smem size and carveout.
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    // Use the extended (cluster) launch API only for architectures that
    // support thread-block clusters.
    if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
                   cute::size<1>(typename Kernel::ClusterShape{}),
                   cute::size<2>(typename Kernel::ClusterShape{}));
      void const* kernel = (void const*) device_kernel<Kernel>;
      void* kernel_params[] = {&params.fmha_params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    }
    else {
      launch_result = Status::kSuccess;
      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
    }

    cudaError_t result = cudaGetLastError();
    if (cudaSuccess != result || Status::kSuccess != launch_result) {
      CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
    if (params.reduction_params.split_kv > 1) {
      // Launch the reduction kernel to merge per-split partial results.
      dim3 const block = ReductionKernel::get_block_shape();
      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
      device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
      cudaError_t result = cudaGetLastError();
      if (cudaSuccess != result) {
        CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
        return Status::kErrorInternal;
      }
    }
    return Status::kSuccess;
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};
354+
355+
////////////////////////////////////////////////////////////////////////////////
356+
357+
} // namespace cutlass::fmha::device
358+
359+
////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)