Commit cdad4f4
[Attention] MLA - cutlass decode with unrestricted num_heads
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
1 parent 9907fc4 commit cdad4f4

File tree

12 files changed: +3283 −2 lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion

The build change registers the new SM100 kernel source so it is compiled
alongside the existing CUTLASS MLA kernels whenever Blackwell (10.0a) is among
the target architectures and the CUDA toolchain is 12.8 or newer:

@@ -563,7 +563,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
     set(SRCS
-      "csrc/attention/mla/cutlass_mla_kernels.cu")
+      "csrc/attention/mla/cutlass_mla_kernels.cu"
+      "csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${MLA_ARCHS}")
New file (one of the 12 in this commit): the CUTLASS 3.x device-layer handle
for the SM100 MLA decode kernel, adapted from SGLang (see the provenance
comment in the file header).

Lines changed: 372 additions & 0 deletions

@@ -0,0 +1,372 @@
/***************************************************************************************************
 * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*
 * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
 * by Alcanderian JieXin Liang
 */

/*!
  \file
  \brief A universal device layer for CUTLASS 3.x-style kernels.
*/

// clang-format off
#pragma once

// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"

#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)

#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass::fmha::device {

using namespace cute;
using namespace cutlass::fmha::kernel;

////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

template<
  class Kernel_
>
class MLA {
public:

  using Kernel = Kernel_;

  // Reduction kernel that merges the per-split partial results whenever the
  // KV cache is processed in more than one split (split-KV decode).
  using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
      typename Kernel::ElementOut,
      typename Kernel::ElementAcc,
      typename Kernel::ElementAcc,
      Kernel::TileShapeH::value,
      Kernel::TileShapeL::value,
      256 /*Max split*/
  >;

  /// Argument structure: User API
  using KernelArguments = typename Kernel::Arguments;
  using ReductionArguments = typename ReductionKernel::Arguments;

  using Arguments = KernelArguments;

  /// Argument structure: Kernel API
  using KernelParams = typename Kernel::Params;
  using ReductionParams = typename ReductionKernel::Params;
  struct Params {
    KernelParams fmha_params;
    ReductionParams reduction_params;
  };

private:

  /// Kernel API parameters object
  Params params_;

  // Function-local static: the dynamic-smem attribute below only needs to be
  // set once per process, not once per handle instance.
  bool is_initialized(bool set = false) {
    static bool initialized = false;
    if (set) initialized = true;
    return initialized;
  }

  static ReductionArguments to_reduction_args(Arguments const& args) {
    auto [H, K, D, B] = args.problem_shape;
    return ReductionArguments{
        nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
        args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
        args.ptr_split_kv, Kernel::TileShapeS::value
    };
  }
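
  // Editor's note (not part of the commit): the two nullptr fields above are
  // the split accumulator pointers (ptr_oaccum / ptr_lseaccum). They are only
  // patched in by initialize()/update() below, from the main kernel's
  // workspace-backed ptr_o_acc / ptr_lse_acc, when split_kv > 1; for
  // split_kv == 1 the reduction kernel is never launched (see run()).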

public:

  /// Access the Params structure
  Params const& params() const {
    return params_;
  }

  /// Heuristic that picks a split-KV factor when the caller has not set one
  /// (args.split_kv < 1): split each batch's KV cache into at most 16 chunks
  /// (roughly one per 128 KV tokens), capped by the number of SMs available
  /// per batch, then rebalance so the chunks cover max_splits evenly.
  static void set_split_kv(KernelArguments& args) {
    if (args.split_kv >= 1) return;
    auto [H, K, D, B] = args.problem_shape;
    int sm_count = args.hw_info.sm_count;
    int max_splits = ceil_div(K, 128);
    max_splits = min(16, max_splits);
    int sms_per_batch = max(1, sm_count / B);
    int split_heur = min(max_splits, sms_per_batch);
    int waves = ceil_div(B * split_heur, sm_count);  // currently unused
    int k_waves = ceil_div(max_splits, split_heur);
    int split_wave_aware = ceil_div(max_splits, k_waves);
    args.split_kv = split_wave_aware;
  }
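
  // Worked example of the heuristic above (editor's sketch, not part of the
  // commit), for a hypothetical decode batch of B = 32 requests with K = 1024
  // cached KV tokens each, on a hypothetical GPU with sm_count = 128:
  //
  //   max_splits       = min(16, ceil_div(1024, 128)) = min(16, 8) = 8
  //   sms_per_batch    = max(1, 128 / 32)             = 4
  //   split_heur       = min(8, 4)                    = 4
  //   waves            = ceil_div(32 * 4, 128)        = 1  (unused)
  //   k_waves          = ceil_div(8, 4)               = 2
  //   split_wave_aware = ceil_div(8, 2)               = 4
  //
  // so args.split_kv = 4: each request's KV cache is processed in 4 splits,
  // whose partial results the reduction kernel then merges.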

  /// Determines whether the kernels can execute the given problem.
  static Status
  can_implement(Arguments const& args) {
    if (!Kernel::can_implement(args)) {
      return Status::kInvalid;
    }
    if (!ReductionKernel::can_implement(to_reduction_args(args))) {
      return Status::kInvalid;
    }
    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t
  get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    workspace_bytes += Kernel::get_workspace_size(args);
    workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
    return workspace_bytes;
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = Kernel::SharedStorageSize;

    // first, account for dynamic smem capacity if needed
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
            " cudaFuncSetAttribute() returned error: "
            << cudaGetErrorString(result));
        return -1;
      }
    }

    // query occupancy after setting smem size
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks,
        device_kernel<Kernel>,
        Kernel::MaxThreadsPerBlock,
        smem_size);

    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      CUTLASS_TRACE_HOST(
          " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
          << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }

  /// Initializes kernel state from arguments.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
                       << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = Kernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params{kernel_params, reduction_params};

    if (is_initialized()) return Status::kSuccess;

    // account for dynamic smem capacity if needed
    // (no dynamic smem is needed for the reduction kernel)
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      cudaError_t result = cudaFuncSetAttribute(
          device_kernel<Kernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    is_initialized(true);

    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  Status
  update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    auto fmha_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params{fmha_params, reduction_params};

    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static, allowing users to create and manage their own params.
  /// The supplied params struct must be constructed by calling Kernel::to_underlying_arguments().
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);

    // configure smem size and carveout
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
                   cute::size<1>(typename Kernel::ClusterShape{}),
                   cute::size<2>(typename Kernel::ClusterShape{}));
      void const* kernel = (void const*) device_kernel<Kernel>;
      void* kernel_params[] = {&params.fmha_params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    }
    else {
      launch_result = Status::kSuccess;
      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
    }

    cudaError_t result = cudaGetLastError();
    if (cudaSuccess != result or Status::kSuccess != launch_result) {
      CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
      return Status::kErrorInternal;
    }
    if (params.reduction_params.split_kv > 1) {
      // launch reduction kernel
      dim3 const block = ReductionKernel::get_block_shape();
      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
      device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
      cudaError_t result = cudaGetLastError();
      if (cudaSuccess == result) {
        return Status::kSuccess;
      }
      else {
        CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
        return Status::kErrorInternal;
      }
    }
    else {
      return Status::kSuccess;
    }
  }
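
  // Editor's note (an assumption, not stated in this file): for the SM100 MLA
  // kernels this commit targets, ArchTag::kMinComputeCapability is expected to
  // be >= 90, so the ClusterLauncher branch above is the path actually taken;
  // the plain <<<...>>> fallback exists only for pre-SM90 arch tags.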

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from the supplied arguments.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from the supplied arguments.
  Status
  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating the internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating the internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::device

////////////////////////////////////////////////////////////////////////////////
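
For orientation, here is a minimal host-side usage sketch of the device handle
defined above (an editor's addition, not part of the commit). MyMlaKernel is a
hypothetical placeholder for a concrete SM100 MLA kernel instantiation from the
included sm100_fmha_mla_tma_warpspecialized.hpp, and the caller is assumed to
have already filled in the Arguments fields referenced in this file
(problem_shape, hw_info.sm_count, the mainloop/epilogue pointers, etc.):

  using Operation = cutlass::fmha::device::MLA<MyMlaKernel>;  // MyMlaKernel: hypothetical

  cutlass::Status run_mla_decode(Operation::Arguments args, cudaStream_t stream) {
    // Pick a split-KV factor; a no-op if the caller already set args.split_kv >= 1.
    Operation::set_split_kv(args);

    if (Operation::can_implement(args) != cutlass::Status::kSuccess) {
      return cutlass::Status::kInvalid;
    }

    // The workspace backs the per-split accumulators read by the reduction kernel.
    void* workspace = nullptr;
    size_t workspace_bytes = Operation::get_workspace_size(args);
    if (workspace_bytes > 0) {
      cudaMallocAsync(&workspace, workspace_bytes, stream);
    }

    Operation op;
    cutlass::Status status = op.initialize(args, workspace, stream);
    if (status == cutlass::Status::kSuccess) {
      status = op.run(stream);  // MLA kernel, then reduction kernel if split_kv > 1
    }
    if (workspace != nullptr) {
      cudaFreeAsync(workspace, stream);  // stream-ordered free; safe after launch
    }
    return status;
  }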
