Skip to content

Commit 8a55a13

Browse files
committed
cutlass mla decode from sglang
1 parent fcd8856 commit 8a55a13

File tree

14 files changed

+3288
-2
lines changed

14 files changed

+3288
-2
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
563563
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
564564
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
565565
set(SRCS
566-
"csrc/attention/mla/cutlass_mla_kernels.cu")
566+
"csrc/attention/mla/cutlass_mla_kernels.cu"
567+
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
567568
set_gencode_flags_for_srcs(
568569
SRCS "${SRCS}"
569570
CUDA_ARCHS "${MLA_ARCHS}")
Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
/***************************************************************************************************
 * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*
 * Taken from SGLANG PR by Alcanderian JieXin Liang
 */

/*!
  \file
  \brief An universal device layer for cutlass 3.x-style kernels.
*/

// clang-format off
#pragma once

// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"

#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)

#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass::fmha::device {

using namespace cute;
using namespace cutlass::fmha::kernel;


////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

/// Device-layer handle for the SM100 FMHA MLA decode kernel.
///
/// Wraps a main attention kernel (Kernel_) plus a split-KV reduction kernel:
/// when `split_kv > 1` the main kernel writes per-split partial outputs /
/// log-sum-exp accumulators, and the reduction kernel combines them into the
/// final output. Mirrors the CUTLASS 3.x device-adapter API
/// (can_implement / get_workspace_size / initialize / run).
template<
  class Kernel_
>
class MLA {
public:

  using Kernel = Kernel_;

  // Reduction kernel that merges per-split partial results. The last template
  // argument caps the number of KV splits supported (256).
  using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
    typename Kernel::ElementOut,
    typename Kernel::ElementAcc,
    typename Kernel::ElementAcc,
    Kernel::TileShapeH::value,
    Kernel::TileShapeL::value,
    256 /*Max split*/
  >;

  /// Argument structure: User API
  using KernelArguments = typename Kernel::Arguments;
  using ReductionArguments = typename ReductionKernel::Arguments;

  using Arguments = KernelArguments;

  /// Argument structure: Kernel API
  using KernelParams = typename Kernel::Params;
  using ReductionParams = typename ReductionKernel::Params;
  struct Params {
    KernelParams fmha_params;
    ReductionParams reduction_params;
  };

private:

  /// Kernel API parameters object
  Params params_;

  /// One-shot latch guarding the cudaFuncSetAttribute call in initialize().
  /// NOTE(review): the function-local `static` is shared by ALL instances of a
  /// given template instantiation (and is not thread-safe) — the guarded call
  /// is a per-kernel attribute set, so repeated calls would be redundant but
  /// harmless; confirm this sharing is intended before relying on it.
  bool is_initialized(bool set = false) {
    static bool initialized = false;
    if (set) initialized = true;
    return initialized;
  }

  /// Builds the reduction kernel's user arguments from the main kernel's.
  /// The accumulator pointers are left null here; initialize()/update() patch
  /// them in from the main kernel's params when split_kv > 1.
  static ReductionArguments to_reduction_args(Arguments const& args) {
    auto [H, K, D, B] = args.problem_shape;
    return ReductionArguments{
      nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
      args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
      args.ptr_split_kv, Kernel::TileShapeS::value
    };
  }

public:

  /// Access the Params structure
  Params const& params() const {
    return params_;
  }

  /// Heuristically chooses args.split_kv when the caller left it unset (< 1).
  /// Balances two goals: use at most the SMs available per batch entry, and
  /// even out the per-split work (ceil-div rounding) so no split is much
  /// larger than the rest.
  static void set_split_kv (KernelArguments& args) {
    if (args.split_kv >= 1) return;
    auto [H, K, D, B] = args.problem_shape;
    int sm_count = args.hw_info.sm_count;
    // At most one split per 128 KV positions.
    int max_splits = ceil_div(K, 128);
    // SMs we can dedicate to each batch entry (at least 1).
    int sms_per_batch = max(1, sm_count / B);
    int split_heur = min(max_splits, sms_per_batch);
    // Re-balance: with k_waves passes over the splits, shrink the split count
    // so the last pass is as full as the others.
    int k_waves = ceil_div(max_splits, split_heur);
    int split_wave_aware = ceil_div(max_splits, k_waves);
    args.split_kv = split_wave_aware;
  }

  /// Determines whether the GEMM can execute the given problem.
  /// Both the main kernel and the reduction kernel must accept it.
  static Status
  can_implement(Arguments const& args) {
    if (! Kernel::can_implement(args)) {
      return Status::kInvalid;
    }
    if (! ReductionKernel::can_implement(to_reduction_args(args))) {
      return Status::kInvalid;
    }
    return Status::kSuccess;
  }

  /// Gets the workspace size: sum of both kernels' requirements.
  static size_t
  get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    workspace_bytes += Kernel::get_workspace_size(args);
    workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
    return workspace_bytes;
  }

  /// Computes the maximum number of active blocks per multiprocessor for the
  /// main kernel. Returns -1 on any CUDA API failure (error bit is cleared).
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = Kernel::SharedStorageSize;

    // first, account for dynamic smem capacity if needed: >48 KB of dynamic
    // shared memory requires an explicit opt-in per kernel.
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
          "  cudaFuncSetAttribute() returned error: "
          << cudaGetErrorString(result));
        return -1;
      }
    }

    // query occupancy after setting smem size
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        device_kernel<Kernel>,
        Kernel::MaxThreadsPerBlock,
        smem_size);

    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      CUTLASS_TRACE_HOST(
        "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
        << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }

  /// Initializes GEMM state from arguments: sets up both kernels' workspaces,
  /// lowers Arguments to Params (patching the reduction kernel's accumulator
  /// pointers when split_kv > 1), and — once per process per instantiation —
  /// raises the dynamic shared-memory limit for the main kernel if needed.
  /// NOTE(review): the same `workspace` pointer is handed to both kernels'
  /// initialize_workspace/to_underlying_arguments even though
  /// get_workspace_size() sums both — confirm the kernels partition it
  /// internally rather than aliasing.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = Kernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      // Point the reduction at the main kernel's partial accumulators.
      reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params {kernel_params, reduction_params};

    if (is_initialized()) return Status::kSuccess;

    // account for dynamic smem capacity if needed
    // no dynamic smem is needed for reduction kernel
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size);
      cudaError_t result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    is_initialized(true);

    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  Status
  update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    auto fmha_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      // Point the reduction at the main kernel's partial accumulators.
      reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params {fmha_params, reduction_params};

    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static allowing users to create and manage their own params.
  /// Supplied params struct must be construct by calling Kernel::to_underling_arguments().
  /// Launches the main kernel, then (if split_kv > 1) the reduction kernel on
  /// the same stream; stream order makes the reduction see the partials.
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);

    // configure smem size and carveout
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
                   cute::size<1>(typename Kernel::ClusterShape{}),
                   cute::size<2>(typename Kernel::ClusterShape{}));
      void const* kernel = (void const*) device_kernel<Kernel>;
      void* kernel_params[] = {&params.fmha_params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    }
    else {
      launch_result = Status::kSuccess;
      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
    }

    cudaError_t result = cudaGetLastError();
    if (cudaSuccess != result || Status::kSuccess != launch_result) {
      CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
    if (params.reduction_params.split_kv > 1) {
      // launch reduction kernel (no dynamic shared memory needed)
      dim3 const reduction_block = ReductionKernel::get_block_shape();
      dim3 const reduction_grid = ReductionKernel::get_grid_shape(params.reduction_params);
      device_kernel<ReductionKernel><<<reduction_grid, reduction_block, 0, stream>>>(params.reduction_params);
      cudaError_t reduction_result = cudaGetLastError();
      if (cudaSuccess == reduction_result) {
        return Status::kSuccess;
      }
      else {
        CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << reduction_result);
        return Status::kErrorInternal;
      }
    }
    else {
      return Status::kSuccess;
    }
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::device

////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)