diff --git a/torchtitan/experiments/deepseek_v3/generate.py b/torchtitan/experiments/deepseek_v3/generate.py index 67b551a2f..07d173fe7 100644 --- a/torchtitan/experiments/deepseek_v3/generate.py +++ b/torchtitan/experiments/deepseek_v3/generate.py @@ -224,7 +224,7 @@ def generate( tokenizer, dist_config, messages: list[dict], - n_tokens: int = 200, + n_tokens: int = 50, ): rank = dist.get_rank() device = dist_config.device diff --git a/torchtitan/experiments/deepseek_v3/group_gemms.py b/torchtitan/experiments/deepseek_v3/group_gemms.py index f4020dee2..ffb767d9e 100644 --- a/torchtitan/experiments/deepseek_v3/group_gemms.py +++ b/torchtitan/experiments/deepseek_v3/group_gemms.py @@ -47,6 +47,24 @@ except ImportError: TRITON_CONTIGUOUS_GROUP_GEMM_AVAILABLE = False +# Cutlass Cute DSL +try: + import cuda.bindings.driver as cuda + import cutlass + import cutlass.cute as cute + + # import cutlass.torch as cutlass_torch + # import cutlass.utils as utils + from cutlass.cute.runtime import from_dlpack + + from torchtitan.experiments.kernels.blackwell.cute_dense_gemm import DenseGemmKernel + + CUTLASS_AVAILABLE = True +except ImportError as e: + CUTLASS_AVAILABLE = False + print(f"Cutlass imports not available: {e}`") + print("Please run `pip install nvidia-cutlass-dsl`") + # Strategy base class for GroupGEMM implementations class GroupGEMMStrategy: @@ -97,9 +115,410 @@ def is_available() -> bool: "TorchBF16GroupGEMM", "TorchAOBF16GroupGEMM", "TritonCGBF16GroupGEMM", + "CuteDenseLoopingGroupGEMM", ] +# requires pip install nvidia-cutlass-dsl +class CuteDenseLoopingGroupGEMM(GroupGEMMStrategy): + """ + Implementation of grouped GEMM using Blackwell Dense GEMM kernel with manual looping. + + High level overview: + - Compiled kernels via Kernel caching: Compiled kernels are cached and reused + - Expert token tensor reuse: For MoE forward pass, expert_tokens are converted to CUTE + format once and reused for both gate and up projections + - Backup: Falls back to PyTorch implementation if CUTE kernels fail + """ + + def __init__(self, custom_activation): + + super().__init__(custom_activation) + + # Kernel configuration + self.alignment = 16 + self.dtype = torch.bfloat16 + self.cutlass_dtype = cutlass.BFloat16 + + # Initialize Cute Dense GEMM kernel + try: + self.gemm_kernel = DenseGemmKernel( + acc_dtype=cutlass.Float32, + use_2cta_instrs=False, + mma_tiler_mn=(128, 128), + cluster_shape_mn=(1, 1), + use_tma_store=False, + ) + except Exception as e: + raise RuntimeError(f"Failed to initialize GEMM kernel: {e}") from e + + # Setup CUDA stream + torch_stream = torch.cuda.Stream() + self.stream = cuda.CUstream(torch_stream.cuda_stream) + + # Cache for compiled kernels + self._compiled_kernels = {} + + # debug monitoring + self.debug_mode = True + + def arrange_expert_weights(self, all_weights, submod_name, module): + """Store weights in a simple list format""" + return torch.stack(all_weights) + + def _create_cute_tensor(self, tensor: torch.Tensor) -> cute.Tensor: + """ + Convert a PyTorch tensor to a CUTE tensor with proper formatting. 
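A minimal sketch of what this conversion amounts to, using hypothetical shapes (the helper body below does exactly this for the strategy's bf16 / 16-byte-alignment settings):

.. code-block:: python

    import torch
    import cutlass
    from cutlass.cute.runtime import from_dlpack

    x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)  # [M, K]
    x_mnkl = x.unsqueeze(-1).contiguous()                           # [M, K, 1] (MNKL)
    x_cute = from_dlpack(x_mnkl, assumed_align=16)                  # 16-byte alignment
    x_cute.element_type = cutlass.BFloat16
    x_cute = x_cute.mark_layout_dynamic(leading_dim=1)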
+ + Args: + tensor: PyTorch tensor to convert + + Returns: + CUTE tensor ready for kernel execution + """ + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + # Convert to MNKL format + tensor_mnkl = tensor.unsqueeze(-1).contiguous().detach() + + # Create CUTE tensor + cute_tensor = from_dlpack(tensor_mnkl, assumed_align=self.alignment) + cute_tensor.element_type = self.cutlass_dtype + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=1) + + return cute_tensor + + def _get_or_compile_kernel(self, a_cute, b_cute, c_cute, operation_name: str): + """ + Get a compiled kernel from cache or compile a new one. + + Args: + a_cute: Input tensor A in CUTE format + b_cute: Input tensor B in CUTE format + c_cute: Output tensor C in CUTE format + operation_name: Name of the operation for caching + + Returns: + Compiled CUTE kernel + """ + cache_key = f"{operation_name}_{a_cute.shape}_{b_cute.shape}_{c_cute.shape}" + + if cache_key not in self._compiled_kernels: + try: + self._compiled_kernels[cache_key] = cute.compile( + self.gemm_kernel, a_cute, b_cute, c_cute, self.stream + ) + if self.debug_mode: + print(f"✓ Compiled kernel for {operation_name}") + except Exception as e: + raise RuntimeError( + f"Failed to compile {operation_name} kernel: {e}" + ) from e + + return self._compiled_kernels[cache_key] + + def _execute_gemm_operation( + self, input_tensor: torch.Tensor, weight: torch.Tensor, operation_name: str + ) -> torch.Tensor: + """ + Execute a single GEMM operation using cute dense kernel. + + Args: + input_tensor: Input tensor [M, K] + weight: Weight tensor [N, K] + operation_name: Name of the operation for debugging + + Returns: + Output tensor [M, N] + """ + batch_size, input_dim = input_tensor.shape + output_dim = weight.shape[0] + + # Create output tensor + output = torch.zeros( + (batch_size, output_dim), + device=input_tensor.device, + dtype=self.dtype, + requires_grad=False, + ) + + # Convert tensors to cute format + try: + a_cute = self._create_cute_tensor(input_tensor) + b_cute = self._create_cute_tensor(weight) + c_cute = self._create_cute_tensor(output) + except Exception as e: + raise RuntimeError( + f"Failed to create CUTE tensors for {operation_name}: {e}" + ) from e + + # Get or compile kernel + compiled_kernel = self._get_or_compile_kernel( + a_cute, b_cute, c_cute, operation_name + ) + + # Execute kernel + try: + compiled_kernel(a_cute, b_cute, c_cute, self.stream) + if self.debug_mode: + print(f"✓ Executed {operation_name} kernel successfully") + except Exception as e: + raise RuntimeError(f"Failed to execute {operation_name} kernel: {e}") from e + + return output.squeeze(-1) if output.dim() > 2 else output + + def _execute_gemm_with_cute_input( + self, + input_cute: cute.Tensor, + weight: torch.Tensor, + operation_name: str, + output_shape: tuple, + ) -> torch.Tensor: + """ + Execute a GEMM operation with pre-converted cute input tensor. 
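For reference, the shape contract matches `_execute_gemm_operation`: an input of shape [M, K] and a weight of shape [N, K] produce an output of shape [M, N], i.e. the equivalent of `x @ w.t()` in plain PyTorch (which is also what the fallback path computes), and repeated calls with the same A/B/C shapes reuse the cached compiled kernel. A sketch with hypothetical sizes:

.. code-block:: python

    import torch

    x = torch.randn(64, 2048, dtype=torch.bfloat16)     # [M, K] input
    w = torch.randn(4096, 2048, dtype=torch.bfloat16)   # [N, K] weight
    y_ref = x @ w.t()                                    # [M, N] reference result
    # Same-shape calls hit the kernel cache; a new (A, B, C) shape combination
    # triggers one additional cute.compile.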
+ + Args: + input_cute: Input cute tensor (already in cute format) + weight: Weight tensor [N, K] + operation_name: Name of the operation for debugging + output_shape: Shape of output tensor (batch_size, output_dim) + + Returns: + Output tensor [M, N] + """ + batch_size, output_dim = output_shape + + # Create output tensor + output = torch.zeros( + (batch_size, output_dim), + device=weight.device, + dtype=self.dtype, + requires_grad=False, + ) + + # Convert weight and output tensors to CUTE format + try: + b_cute = self._create_cute_tensor(weight) + c_cute = self._create_cute_tensor(output) + except Exception as e: + raise RuntimeError( + f"Failed to create CUTE tensors for {operation_name}: {e}" + ) from e + + # Get or compile kernel + compiled_kernel = self._get_or_compile_kernel( + input_cute, b_cute, c_cute, operation_name + ) + + # Execute kernel + try: + compiled_kernel(input_cute, b_cute, c_cute, self.stream) + if self.debug_mode: + print(f"✓ Executed {operation_name} kernel successfully") + except Exception as e: + raise RuntimeError(f"Failed to execute {operation_name} kernel: {e}") from e + + return output.squeeze(-1) if output.dim() > 2 else output + + def _process_expert( + self, + expert_tokens: torch.Tensor, + expert_idx: int, + w_gate: torch.Tensor, + w_up: torch.Tensor, + w_down: torch.Tensor, + ) -> torch.Tensor: + """ + Process tokens through a single expert using cute dense kernels. + + Args: + expert_tokens: Tokens for this expert [num_tokens, hidden_size] + expert_idx: Index of the expert + w_gate: Gate projection weights [intermediate_size, hidden_size] + w_up: Up projection weights [intermediate_size, hidden_size] + w_down: Down projection weights [hidden_size, intermediate_size] + + Returns: + Expert output [num_tokens, hidden_size] + """ + num_tokens = expert_tokens.shape[0] + intermediate_size = w_gate.shape[0] + hidden_size = w_down.shape[0] + + # Convert expert_tokens to CUTE format once for reuse + # OPTIMIZATION: Gate and up projections share the same input tensor, + # so we convert to CUTE format once and reuse to avoid redundant overhead + try: + expert_tokens_cute = self._create_cute_tensor(expert_tokens) + except BaseException as e: + raise RuntimeError( + f"Failed to create CUTE tensor for expert {expert_idx} input: {e}" + ) from e + + # Gate projection - reuse the CUTE input tensor + gate_out = self._execute_gemm_with_cute_input( + expert_tokens_cute, + w_gate, + f"gate_expert_{expert_idx}", + (num_tokens, intermediate_size), + ) + + # Up projection - reuse the same CUTE input tensor + up_out = self._execute_gemm_with_cute_input( + expert_tokens_cute, + w_up, + f"up_expert_{expert_idx}", + (num_tokens, intermediate_size), + ) + + # Apply activation and combine + hidden = self.activation_function(gate_out) * up_out + + # Down projection - create new CUTE tensor for hidden state + expert_output = self._execute_gemm_operation( + hidden, w_down, f"down_expert_{expert_idx}" + ) + + return expert_output + + def execute(self, contig_tokens, m_sizes, m_offsets, module): + """ + Execute the complete grouped GEMM operation via looping. 
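Per active expert the computation is the standard gated MLP that the PyTorch fallback also implements; the CUTE path only swaps the three matmuls for dense-kernel launches. A self-contained reference sketch (sizes hypothetical, silu assumed as the activation):

.. code-block:: python

    import torch
    import torch.nn.functional as F

    tokens, hidden, inter = 8, 64, 128       # hypothetical sizes
    x  = torch.randn(tokens, hidden)         # expert_tokens for one expert
    wg = torch.randn(inter, hidden)          # gate_proj weight [inter, hidden]
    wu = torch.randn(inter, hidden)          # up_proj weight   [inter, hidden]
    wd = torch.randn(hidden, inter)          # down_proj weight [hidden, inter]

    h = F.silu(x @ wg.t()) * (x @ wu.t())    # activation is configurable; silu assumed
    y = h @ wd.t()                           # [tokens, hidden], as in _process_expert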
+ + Args: + contig_tokens: Input tokens arranged contiguously by expert + m_sizes: Sizes of each group + m_offsets: Offsets of each group + module: MoE module containing weights and parameters + + Returns: + Processed tokens + """ + try: + # Get weights + device = contig_tokens.device + w_gate = module.get_parameter("gate_proj_weight") + w_up = module.get_parameter("up_proj_weight") + w_down = module.get_parameter("down_proj_weight") + + # Validate inputs + if len(m_sizes) != w_gate.shape[0]: + raise ValueError( + f"Number of experts mismatch: {len(m_sizes)} vs {w_gate.shape[0]}" + ) + + # Prepare output tensor + hidden_size = w_gate.shape[2] if len(w_gate.shape) > 2 else w_gate.shape[1] + output = torch.zeros( + contig_tokens.shape[0], + hidden_size, + dtype=contig_tokens.dtype, + device=device, + ) + + # Process each expert + offset = 0 + active_experts = 0 + + for expert_idx, size in enumerate(m_sizes): + if size > 0: + # Get tokens and weights for this expert + expert_tokens = contig_tokens[offset : offset + size] + expert_gate_weight = w_gate[expert_idx] + expert_up_weight = w_up[expert_idx] + expert_down_weight = w_down[expert_idx] + + # Process through expert + expert_output = self._process_expert( + expert_tokens, + expert_idx, + expert_gate_weight, + expert_up_weight, + expert_down_weight, + ) + + # Store results + output[offset : offset + size] = expert_output + active_experts += 1 + + offset += size + + if self.debug_mode: + print( + f"Processed {active_experts} active experts out of {len(m_sizes)} total" + ) + + return output + + except Exception as e: + # Fallback to PyTorch implementation on error + if self.debug_mode: + print(f"CUTE kernel failed, falling back to PyTorch: {e}") + return self._fallback_pytorch(contig_tokens, m_sizes, module) + + def _fallback_pytorch(self, contig_tokens, m_sizes, module): + """ + Fallback implementation using standard PyTorch operations. 
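The grouping convention here is the same one `execute` relies on: tokens arrive pre-sorted by expert, `m_sizes[i]` rows belong to expert `i`, and a running offset walks the contiguous buffer. A small sketch with made-up sizes:

.. code-block:: python

    import torch

    m_sizes = [3, 0, 2]                        # hypothetical per-expert token counts
    contig_tokens = torch.randn(sum(m_sizes), 16)

    offset = 0
    for expert_idx, size in enumerate(m_sizes):
        if size > 0:
            expert_tokens = contig_tokens[offset : offset + size]
            # rows 0:3 -> expert 0, expert 1 inactive, rows 3:5 -> expert 2
        offset += size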
+ + Args: + contig_tokens: Input tokens + m_sizes: Group sizes + module: MoE module + + Returns: + Processed tokens using PyTorch mm + """ + print("\nWARNING: Cute GEMM issue -- Falling back to PyTorch implementation\n") + device = contig_tokens.device + w_gate = module.get_parameter("gate_proj_weight") + w_up = module.get_parameter("up_proj_weight") + w_down = module.get_parameter("down_proj_weight") + + hidden_size = w_gate.shape[2] if len(w_gate.shape) > 2 else w_gate.shape[1] + output = torch.zeros( + contig_tokens.shape[0], + hidden_size, + dtype=contig_tokens.dtype, + device=device, + ) + + offset = 0 + for expert_idx, size in enumerate(m_sizes): + if size > 0: + expert_tokens = contig_tokens[offset : offset + size] + + # Standard PyTorch forward pass + gate_out = torch.mm(expert_tokens, w_gate[expert_idx].t()) + up_out = torch.mm(expert_tokens, w_up[expert_idx].t()) + hidden = self.activation_function(gate_out) * up_out + expert_output = torch.mm(hidden, w_down[expert_idx].t()) + + output[offset : offset + size] = expert_output + + offset += size + + return output + + def clear_cache(self): + """Clear the compiled kernel cache to free memory.""" + self._compiled_kernels.clear() + if self.debug_mode: + print("Cleared compiled kernel cache") + + def set_debug_mode(self, enabled: bool = False): + """Enable or disable debug mode.""" + self.debug_mode = enabled + + @staticmethod + def is_available() -> bool: + """Check if this strategy is available on the current system.""" + try: + return CUTLASS_AVAILABLE and torch.cuda.is_available() + except Exception: + return False + + class TritonCGBF16GroupGEMM(GroupGEMMStrategy): """Implementation of Triton Contiguous group Gemm""" diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py index 131b1ea2b..6e53d2924 100644 --- a/torchtitan/experiments/deepseek_v3/model.py +++ b/torchtitan/experiments/deepseek_v3/model.py @@ -45,6 +45,7 @@ from attn_mask_utils import _prepare_4d_causal_attention_mask from group_gemms import ( + CuteDenseLoopingGroupGEMM, DSGroupGEMM, TorchAOBF16GroupGEMM, TorchBF16GroupGEMM, @@ -474,7 +475,8 @@ class MoE(nn.Module): # Group GEMM strategies group_gemm_strategies = None # which group gemm to use? 
- group_mm = "torch" # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", , "torchao", "tritoncg"] + group_mm = "cute" # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", , "torchao", "tritoncg", "cute"] + print(f"Using group gemm strategy: {group_mm}") def __init__(self, config): super().__init__() @@ -550,6 +552,11 @@ def _initialize_group_gemm_strategies(cls): if TritonCGBF16GroupGEMM.is_available() else None ), + "cute": ( + CuteDenseLoopingGroupGEMM(MLP.act_fn) + if CuteDenseLoopingGroupGEMM.is_available() + else None + ), } def combine_experts(self, submod_name: str): diff --git a/torchtitan/experiments/deepseek_v3/requirements.txt b/torchtitan/experiments/deepseek_v3/requirements.txt index 2b66a52d8..738febc09 100644 --- a/torchtitan/experiments/deepseek_v3/requirements.txt +++ b/torchtitan/experiments/deepseek_v3/requirements.txt @@ -3,3 +3,4 @@ accelerate torchdata >= 0.8.0 datasets >= 2.21.0 tomli >= 1.1.0 ; python_version < "3.11" +nvidia-cutlass-dsl diff --git a/torchtitan/experiments/kernels/blackwell/cute_dense_gemm.py b/torchtitan/experiments/kernels/blackwell/cute_dense_gemm.py new file mode 100644 index 000000000..de7eca619 --- /dev/null +++ b/torchtitan/experiments/kernels/blackwell/cute_dense_gemm.py @@ -0,0 +1,1934 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
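Relating back to the model.py change above: strategies whose `is_available()` check fails are registered as `None`, so a defensive selection of the new "cute" default could look like the sketch below (not part of the patch; the registry dict is a toy stand-in for the one built by `_initialize_group_gemm_strategies`):

.. code-block:: python

    group_gemm_strategies = {"torch": object(), "cute": None}  # toy stand-in registry
    chosen = "cute"
    if group_gemm_strategies.get(chosen) is None:  # CUTLASS DSL or CUDA unavailable
        chosen = "torch"                           # plain BF16 torch.mm path
    print(f"Using group gemm strategy: {chosen}")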
+ + +# requires pip install nvidia-cutlass-dsl + +import argparse +from typing import Optional, Tuple, Type, Union + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils + +import torch +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.runtime import from_dlpack + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +4. Type convert C matrix to output type. +5. Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. +6. Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/dense_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Blackwell tcgen05 MMA tile shape used 2 cta with 256x128 +MMA tile and the cluster shape is (2,1). The input, mma accumulator and output +data type are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. 
code-block:: bash + + ncu python examples/blackwell/dense_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +Constraints: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below DenseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class DenseGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tiler (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = DenseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + + Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. 
Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Output C tensor store mode: + - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. + :type use_tma_store: bool + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + ): + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + self.threads_per_cta = 128 + self.num_smem_capacity = sm100_utils.SMEM_CAPACITY["sm100"] + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Setup A/B/C stage count in shared memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.num_smem_capacity, + self.occupancy, + self.use_tma_store, + ) + + 
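A worked example of the tiling arithmetic above, assuming `use_2cta_instrs=True` and `mma_tiler_mn=(256, 128)`: the tiled MMA spans two CTAs, so each CTA owns half of the M extent, and the K extent of the MMA tile is four MMA instructions deep (the per-instruction K depends on the input dtype).

.. code-block:: python

    # Illustrative numbers for an assumed configuration.
    mma_tiler_m, mma_tiler_n = 256, 128
    atom_thr_size = 2                          # cute.size(tiled_mma.thr_id.shape) for 2-CTA MMA
    cta_tile_m = mma_tiler_m // atom_thr_size  # 128 rows of C per CTA
    mma_inst_tile_k = 4                        # mma_tiler K = mma_inst_shape_k * 4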
# Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if cutlass.const_expr(self.use_tma_store) + else None + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tma_tile_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tma_tile_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * 
atom_thr_size + + # Setup store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tma_tile_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + # Compute grid size + grid = self._compute_grid(c, self.cta_tile_shape_mnk, self.cluster_shape_mn) + + self.buffer_align_bytes = 1024 + + c_smem_size = ( + cute.cosize(self.c_smem_layout_staged.outer) + if cutlass.const_expr(self.use_tma_store) + else 0 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ # noqa: N815 + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ # noqa: N815 + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ # noqa: N815 + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if cutlass.const_expr(self.use_tma_store) else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the batched GEMM computation. 
+ """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma descriptor + # + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coords inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = utils.CooperativeGroup( + utils.Agent.Thread, num_tma_producer + ) + ab_pipeline = utils.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + ab_producer_state = utils.make_pipeline_state( + utils.PipelineUserType.Producer, self.num_ab_stage + ) + ab_consumer_state = utils.make_pipeline_state( + utils.PipelineUserType.Consumer, self.num_ab_stage + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread) + acc_pipeline_consumer_group = utils.CooperativeGroup( + utils.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + acc_pipeline = utils.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + acc_producer_state = utils.make_pipeline_state( + utils.PipelineUserType.Producer, self.num_acc_stage + ) + acc_consumer_state = utils.make_pipeline_state( + utils.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == 0: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init_arrive_cnt( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = ( + storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if cutlass.const_expr(self.use_tma_store) + else None + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + 
a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if self.is_a_mcast or self.is_b_mcast or use_2cta_instrs: + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_block_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopN, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + + # + # Alloc tensor memory buffer + # + if warp_idx == 0: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + cute.arch.barrier() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + # (MMA, MMA_M, MMA_N) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + tidx, tCtAcc, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = 
None + tRS_sC = None + bSG_sC = None + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC = self.epilog_gmem_copy_and_partition( + tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + simt_atom, tTR_rC, tTR_gC = self.epilog_gmem_copy_and_partition( + tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), loopK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + if cutlass.const_expr(self.use_tma_store): + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] + else: + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] + + # + # Pipelining TMA load A/B and MMA mainloop + # + prefetch_k_block_cnt = cutlass.min(self.num_ab_stage - 2, k_block_cnt) + + if warp_idx == 0: + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_block_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Prefetch TMA load A/B + # + for prefetch_idx in cutlass.range_dynamic(prefetch_k_block_cnt, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_block_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # Peek (try_wait) AB buffer full for k_block = 0 + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_block_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # MMA mainloop + # + for k_block in cutlass.range_dynamic(0, k_block_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + if ab_producer_state.count < k_block_cnt: + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # tCtAcc += tCrA * tCrB + num_kphases = cute.size(tCrA, mode=[2]) + for kphase_idx in range(num_kphases): + 
kphase_coord = (None, None, kphase_idx, ab_consumer_state.index) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kphase_coord], + tCrB[kphase_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kphase + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_block_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # Peek (try_wait) AB buffer full for k_block = k_block + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_block_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # Async arrive accumulator buffer full + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # + # Epilogue + # + + # Release tensor memory allocation lock + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + # Initialize tma store c_pipeline + c_producer_group = utils.CooperativeGroup( + utils.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = utils.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range_dynamic(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.use_tma_store): + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier() + + # + # TMA store C to global memory + # + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure TMA store is completed to recollect C buffer + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier() + else: + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) + + # + # Dealloc the tensor memory buffer + # + cute.arch.barrier() + if warp_idx == 0: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + 
tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # + # Wait for C store complete + # + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + # + # Wait A/B buffer empty + # + if warp_idx == 0: + # Reverse prefetch_k_block_cnt times to next available buffer + for i in cutlass.range_dynamic(prefetch_k_block_cnt): + ab_producer_state.reverse() + ab_pipeline.producer_tail(ab_producer_state) + return + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). 
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy( + copy_atom_r2s, + layout_tv=tiled_copy_t2r.layout_dst_tv_tiled, + tiler_mn=tiled_copy_t2r.tiler_mn, + ) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if cutlass.const_expr(self.use_tma_store): + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = 
cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, tTR_rC, tTR_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + num_smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tile. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 1 + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + num_smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + num_smem_capacity + - ab_bytes_per_stage * num_ab_stage + - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ((occupancy + 1) * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + ) -> Tuple[int, int, int]: + """Compute grid shape for the output tensor C. 
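As a worked example, take the problem size from the module docstring (M=N=8192, batch L=1) with `mma_tiler_mn=(256, 128)`, `use_2cta_instrs=True` (so the CTA tile is 128x128) and cluster shape (2, 1): the grid is the per-CTA tile count in M and N, rounded up to a multiple of the cluster shape.

.. code-block:: python

    import math

    m, n, batch = 8192, 8192, 1
    cta_tile_m, cta_tile_n = 128, 128          # 256x128 MMA tile split across two CTAs
    cluster_m, cluster_n = 2, 1

    def round_up(x, multiple):
        return math.ceil(x / multiple) * multiple

    grid = (
        round_up(math.ceil(m / cta_tile_m), cluster_m),  # 64
        round_up(math.ceil(n / cta_tile_n), cluster_n),  # 64
        round_up(batch, 1),                              # 1
    )
    # grid == (64, 64, 1)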
+ + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + cluster_shape_mnl = (*cluster_shape_mn, 1) + + grid = cute.round_up( + ( + cute.ceil_div(c.layout.shape[0], cta_tile_shape_mnk[0]), + cute.ceil_div(c.layout.shape[1], cta_tile_shape_mnk[1]), + c.layout.shape[2], + ), + cluster_shape_mnl, + ) + + return grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[ + cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp + ]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int] + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + + :return: The number of tensor memory allocation columns. 
+ :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + return sm100_utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == cutlass.Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if ( + acc_dtype == cutlass.Float32 + and c_dtype + not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == cutlass.Int32 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: 
The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment( # noqa: N802 + dtype, is_mode0_major, tensor_shape + ): # noqa: N802 + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_epilog_store_option( + use_2cta_instrs: bool, + use_tma_store: bool, + m: int, + n: int, + mma_tiler_mn: Tuple[int, int], + ) -> bool: + """ + Check if the epilogue store option is valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + + :return: True if the epilogue store option is valid, False otherwise + :rtype: bool + """ + + is_valid = True + # None TMA store version does not have predication, can not support OOB tiles + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number 
of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not DenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not DenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not DenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid epilogue store option + if not DenseGemmKernel.is_valid_epilog_store_option( + use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn + ): + can_implement = False + return can_implement + + +def run_dense_gemm( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + use_tma_store: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + measure_launch_overhead=False, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """ + print("Running B100 Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True' if use_tma_store else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not DenseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" # noqa: B950 + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = 
cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a_ref, a_tensor, a_torch = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + ) + + # Configure gemm kernel + gemm = DenseGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + ) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, stream) + + # Launch GPU kernel + # Warm up + for i in range(warmup_iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + # Execution + for i in range(iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + + # Compute reference result + if not skip_ref_check: + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() + + # Copy gpu result back + gpu_c = c_torch.cpu() + + # Convert ref to c_type + if c_dtype == cutlass.Float32: + ref_c = ref + elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # k major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + # Reference checking ref_c and gpu_c + torch.testing.assert_close( + gpu_c, + ref_c, + atol=tolerance, + rtol=1e-05, + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + # or: return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." 
+ ) from None + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tiler (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument("--iterations", type=int, default=1, help="Iterations") + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run_dense_gemm( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + ) + print("PASS")
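Reviewer note (not part of the patch): the A/B stage count in _compute_stages above is plain integer arithmetic over the shared-memory budget. A minimal sketch with assumed byte counts (not real SM100 values), just to make the heuristic concrete:

# Illustrative only: the A/B stage-count heuristic from _compute_stages,
# written out with assumed numbers (not real SM100 values).
smem_capacity = 232448      # assumed usable shared memory per SM, in bytes
occupancy = 1               # target CTAs per SM
mbar_helpers_bytes = 1024   # reserved for mbarriers/helpers, as in the kernel
ab_bytes_per_stage = 65536  # assumed bytes for one A tile + one B tile
c_bytes = 0                 # no TMA store -> no C staging

num_ab_stage = (
    smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes)
) // ab_bytes_per_stage
print(num_ab_stage)  # 3 with the numbers above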
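Likewise, _compute_grid just rounds the per-dimension tile counts up to the cluster shape. The plain-Python sketch below mirrors that arithmetic with hypothetical sizes; the real code uses cute.ceil_div / cute.round_up on the CuTe layout:

import math

# Illustrative only: plain-Python mirror of the rounding done in _compute_grid,
# with hypothetical sizes (M, N, L) = (300, 200, 1).
M, N, L = 300, 200, 1
cta_tile_m, cta_tile_n = 128, 128
cluster_m, cluster_n, cluster_l = 2, 1, 1

def round_up(x: int, multiple: int) -> int:
    return math.ceil(x / multiple) * multiple

grid = (
    round_up(math.ceil(M / cta_tile_m), cluster_m),  # ceil(300/128) = 3 -> 4
    round_up(math.ceil(N / cta_tile_n), cluster_n),  # ceil(200/128) = 2 -> 2
    round_up(L, cluster_l),                          # 1
)
print(grid)  # (4, 2, 1)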
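The static checks bundled into can_implement can also serve as a cheap pre-flight test before paying the compile cost. A minimal sketch, assuming this file is importable as torchtitan.experiments.kernels.blackwell.cute_dense_gemm (the module path is an assumption, not something this diff establishes):

import cutlass
# Assumed import path; adjust to wherever this file lands in the tree.
from torchtitan.experiments.kernels.blackwell.cute_dense_gemm import DenseGemmKernel

# Pre-flight check: bf16 inputs, fp32 accumulation, a 256x256x512 problem, no TMA store.
ok = DenseGemmKernel.can_implement(
    ab_dtype=cutlass.BFloat16,
    acc_dtype=cutlass.Float32,
    c_dtype=cutlass.BFloat16,
    use_2cta_instrs=False,
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    use_tma_store=False,
    m=256,
    n=256,
    k=512,
    l=1,
    a_major="k",
    b_major="k",
    c_major="n",
)
if not ok:
    raise ValueError("Unsupported GEMM configuration for DenseGemmKernel")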
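Finally, the __main__ block drives run_dense_gemm through argparse, but the same test can be driven programmatically. A sketch using the script's own defaults (values copied from the parser above; the import path is again an assumption):

import cutlass
# Assumed import path for this file.
from torchtitan.experiments.kernels.blackwell.cute_dense_gemm import run_dense_gemm

# Equivalent to running the script with its default arguments.
run_dense_gemm(
    (256, 256, 512, 1),  # mnkl
    cutlass.TFloat32,    # ab_dtype
    cutlass.Float32,     # c_dtype
    cutlass.Float32,     # acc_dtype
    "k",                 # a_major
    "k",                 # b_major
    "n",                 # c_major
    (128, 128),          # mma_tiler_mn
    (1, 1),              # cluster_shape_mn
    False,               # use_2cta_instrs
    False,               # use_tma_store
    1e-01,               # tolerance
    warmup_iterations=0,
    iterations=1,
    skip_ref_check=False,
)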