/***************************************************************************************************
 * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*
 * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
 * by Alcanderian JieXin Liang
 */

/*!
  \file
  \brief A universal device layer for CUTLASS 3.x-style kernels.
*/

// clang-format off
#pragma once

// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"

#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)

#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass::fmha::device {

using namespace cute;
using namespace cutlass::fmha::kernel;


////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

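/*
 * Device-level handle for the two-kernel MLA pipeline: a warp-specialized
 * FMHA kernel computes the attention output (as per-split partial results
 * when split_kv > 1), and a reduction kernel then combines those partials.
 * An illustrative host-side usage sketch follows the class definition.
 */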
template<
  class Kernel_
>
class MLA {
public:

  using Kernel = Kernel_;

  using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
    typename Kernel::ElementOut,
    typename Kernel::ElementAcc,
    typename Kernel::ElementAcc,
    Kernel::TileShapeH::value,
    Kernel::TileShapeL::value,
    256 /*Max split*/
  >;

  /// Argument structure: User API
  using KernelArguments = typename Kernel::Arguments;
  using ReductionArguments = typename ReductionKernel::Arguments;

  using Arguments = KernelArguments;

  /// Argument structure: Kernel API
  using KernelParams = typename Kernel::Params;
  using ReductionParams = typename ReductionKernel::Params;
  struct Params {
    KernelParams fmha_params;
    ReductionParams reduction_params;
  };

private:

  /// Kernel API parameters object
  Params params_;

  /// Tracks whether the dynamic smem attribute has already been configured.
  /// Note: the function-local static is shared by all MLA instances (and is
  /// not thread-safe), so the attribute is set at most once per process.
  bool is_initialized(bool set = false) {
    static bool initialized = false;
    if (set) initialized = true;
    return initialized;
  }

  /// Builds the reduction kernel's arguments from the FMHA arguments. The
  /// accumulator pointers are left null here; initialize()/update() fill them
  /// in from the FMHA kernel's params once the workspace is known.
  static ReductionArguments to_reduction_args(Arguments const& args) {
    auto [H, K, D, B] = args.problem_shape;
    return ReductionArguments{
      nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
      args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
      args.ptr_split_kv, Kernel::TileShapeS::value
    };
  }

public:

  /// Access the Params structure
  Params const& params() const {
    return params_;
  }

  /// Picks a split-KV factor when the caller has not set one (split_kv < 1):
  /// start from the largest useful split (one split per 128 KV elements),
  /// cap it by the SMs available per batch, then round so the splits divide
  /// the K work into even waves.
  static void set_split_kv(KernelArguments& args) {
    if (args.split_kv >= 1) return;
    auto [H, K, D, B] = args.problem_shape;
    int sm_count = args.hw_info.sm_count;
    int max_splits = ceil_div(K, 128);
    int sms_per_batch = max(1, sm_count / B);
    int split_heur = min(max_splits, sms_per_batch);
    int waves = ceil_div(B * split_heur, sm_count);
    (void) waves; // informational only; not used by the rounding below
    int k_waves = ceil_div(max_splits, split_heur);
    int split_wave_aware = ceil_div(max_splits, k_waves);
    args.split_kv = split_wave_aware;
  }
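
  // Worked example of the heuristic above, with illustrative numbers (not
  // taken from the source): K = 4096, B = 2, sm_count = 132.
  //   max_splits       = ceil_div(4096, 128) = 32
  //   sms_per_batch    = max(1, 132 / 2)     = 66
  //   split_heur       = min(32, 66)         = 32
  //   k_waves          = ceil_div(32, 32)    = 1
  //   split_wave_aware = ceil_div(32, 1)     = 32  ->  args.split_kv = 32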

  /// Determines whether the GEMM can execute the given problem.
  static Status
  can_implement(Arguments const& args) {
    if (!Kernel::can_implement(args)) {
      return Status::kInvalid;
    }
    if (!ReductionKernel::can_implement(to_reduction_args(args))) {
      return Status::kInvalid;
    }
    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t
  get_workspace_size(Arguments const& args) {
    size_t workspace_bytes = 0;
    workspace_bytes += Kernel::get_workspace_size(args);
    workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
    return workspace_bytes;
  }

  /// Computes the maximum number of active blocks per multiprocessor
  static int maximum_active_blocks(int /* smem_capacity */ = -1) {
    CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
    int max_active_blocks = -1;
    int smem_size = Kernel::SharedStorageSize;

    // first, account for dynamic smem capacity if needed
    cudaError_t result;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      result = cudaFuncSetAttribute(
        device_kernel<Kernel>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(
          " cudaFuncSetAttribute() returned error: "
          << cudaGetErrorString(result));
        return -1;
      }
    }

    // query occupancy after setting smem size
    result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks,
      device_kernel<Kernel>,
      Kernel::MaxThreadsPerBlock,
      smem_size);

    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      CUTLASS_TRACE_HOST(
        " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
        << cudaGetErrorString(result));
      return -1;
    }

    CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
    return max_active_blocks;
  }
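
  // Illustrative use of the occupancy query above (everything except the CUDA
  // runtime call is an assumption, not defined in this header): combine the
  // per-SM value with the SM count to bound the number of resident CTAs.
  //
  //   int blocks_per_sm = MLA<SomeMlaKernel>::maximum_active_blocks();
  //   int sm_count = 0;
  //   cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, /*device=*/0);
  //   int max_resident_ctas = blocks_per_sm * sm_count;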

  /// Initializes GEMM state from arguments.
  Status
  initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
      << workspace << ", stream: " << (stream ? "non-null" : "null"));

    // Initialize the workspace
    Status status = Kernel::initialize_workspace(args, workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
    if (status != Status::kSuccess) {
      return status;
    }
    KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params {kernel_params, reduction_params};

    if (is_initialized()) return Status::kSuccess;

    // account for dynamic smem capacity if needed;
    // no dynamic smem is needed for the reduction kernel
    int smem_size = Kernel::SharedStorageSize;
    if (smem_size >= (48 << 10)) {
      CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
      cudaError_t result = cudaFuncSetAttribute(
        device_kernel<Kernel>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);
      if (cudaSuccess != result) {
        result = cudaGetLastError(); // to clear the error bit
        CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }

    is_initialized(true);

    return Status::kSuccess;
  }

  /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
  Status
  update(Arguments const& args, void* workspace = nullptr) {
    CUTLASS_TRACE_HOST("MLA::update() - workspace: " << workspace);

    size_t workspace_bytes = get_workspace_size(args);
    if (workspace_bytes > 0 && nullptr == workspace) {
      return Status::kErrorWorkspaceNull;
    }

    auto fmha_params = Kernel::to_underlying_arguments(args, workspace);

    ReductionArguments reduction_args = to_reduction_args(args);
    if (reduction_args.split_kv > 1) {
      reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
      reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
    }
    ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
    // Initialize the Params structure
    params_ = Params {fmha_params, reduction_params};

    return Status::kSuccess;
  }

  /// Primary run() entry point API that is static, allowing users to create and manage their own params.
  /// The supplied params struct must be constructed by calling Kernel::to_underlying_arguments().
  static Status
  run(Params& params, cudaStream_t stream = nullptr) {
    CUTLASS_TRACE_HOST("MLA::run()");
    dim3 const block = Kernel::get_block_shape();
    dim3 const grid = Kernel::get_grid_shape(params.fmha_params);

    // configure smem size and carveout
    int smem_size = Kernel::SharedStorageSize;

    Status launch_result;
    // Use extended launch API only for mainloops that use it
    if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
      dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
                   cute::size<1>(typename Kernel::ClusterShape{}),
                   cute::size<2>(typename Kernel::ClusterShape{}));
      void const* kernel = (void const*) device_kernel<Kernel>;
      void* kernel_params[] = {&params.fmha_params};
      launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
    }
    else {
      launch_result = Status::kSuccess;
      device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
    }

    cudaError_t result = cudaGetLastError();
    if (cudaSuccess != result || Status::kSuccess != launch_result) {
      CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result));
      return Status::kErrorInternal;
    }
    if (params.reduction_params.split_kv > 1) {
      // launch reduction kernel
      dim3 const block = ReductionKernel::get_block_shape();
      dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
      device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
      cudaError_t result = cudaGetLastError();
      if (cudaSuccess == result) {
        return Status::kSuccess;
      }
      else {
        CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result));
        return Status::kErrorInternal;
      }
    }
    else {
      return Status::kSuccess;
    }
  }

  //
  // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
  //

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    Status status = initialize(args, workspace, stream);
    if (Status::kSuccess == status) {
      status = run(params_, stream);
    }
    return status;
  }

  /// Launches the kernel after first constructing Params internal state from supplied arguments.
  Status
  operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return run(args, workspace, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  run(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }

  /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
  Status
  operator()(cudaStream_t stream = nullptr) {
    return run(params_, stream);
  }
};
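
////////////////////////////////////////////////////////////////////////////////

// Host-side usage sketch for the handle above. Illustrative only: `MlaKernel`
// stands in for a concrete kernel type from sm100_fmha_mla_tma_warpspecialized.hpp,
// `make_mla_arguments()` / `CUTLASS_CHECK` are assumed application helpers, and
// cutlass::device_memory::allocation comes from the cutlass/util headers; none
// of these are defined in this header.
//
//   using Operation = cutlass::fmha::device::MLA<MlaKernel>;
//
//   typename Operation::Arguments args = make_mla_arguments(/*...*/);
//   Operation::set_split_kv(args);  // fills args.split_kv if unset
//
//   if (Operation::can_implement(args) != cutlass::Status::kSuccess) {
//     // fall back to another kernel
//   }
//
//   cutlass::device_memory::allocation<uint8_t> workspace(
//       Operation::get_workspace_size(args));
//
//   Operation op;
//   CUTLASS_CHECK(op.initialize(args, workspace.get(), stream));
//   CUTLASS_CHECK(op.run(stream));  // FMHA kernel, then reduction if split_kv > 1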

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::fmha::device

////////////////////////////////////////////////////////////////////////////////