
[kernels][blackwell] add cutlass/cute group gemm forward for blackwell #1327

Open: wants to merge 2 commits into base: main
19 changes: 16 additions & 3 deletions torchtitan/experiments/deepseek_v3/generate.py
@@ -8,6 +8,7 @@

# use inference.sh "Your Question Here?" to run inference with a single prompt.

+import os
import sys
from dataclasses import dataclass

@@ -127,7 +128,7 @@ def create_model(dist_config: DistConfig):
model_args.ep_size = dist_config.ep_size
model_args.num_stages = dist_config.pp_size
model_args.stage_idx = dist_config.pp_rank
-model_args.max_seq_len = 4096 # 16384
+model_args.max_seq_len = 1024 # 4096 # 16384

with dist_config.device, dist_config.mesh:
model = DeepseekForCausalLM(model_args)
@@ -224,7 +225,7 @@ def generate(
tokenizer,
dist_config,
messages: list[dict],
-n_tokens: int = 200,
+n_tokens: int = 80,
):
rank = dist.get_rank()
device = dist_config.device
@@ -353,6 +354,12 @@ def generate_with_cuda_graph(


if __name__ == "__main__":
+# set device
+torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+
+run_with_cuda_graph = False
+run_two_times = True
+
# Get user prompt from command line arguments
user_prompt = "What is 2+2?" # Default prompt
if len(sys.argv) > 1:
@@ -375,7 +382,13 @@ def generate_with_cuda_graph(
]

generate(model, pp_schedule, tokenizer, dist_config, messages)
-generate_with_cuda_graph(model, tokenizer, dist_config, messages)
+
+# we run a second time to compare the performance (i.e. compilation overhead)
+if run_two_times:
+    generate(model, pp_schedule, tokenizer, dist_config, messages)
+
+if run_with_cuda_graph:
+    generate_with_cuda_graph(model, tokenizer, dist_config, messages)

if rank == 0:
print(f"\n{color.yellow}Closing inference mesh...{color.reset}")
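Two aspects of the generate.py change are worth noting: each rank is now pinned to its local GPU via the LOCAL_RANK environment variable (set by torchrun) before any CUDA work, and a run_two_times flag repeats generation so a second pass can be compared against the first, which carries one-time compilation and warm-up overhead. Below is a minimal, self-contained sketch of that pattern; the toy Linear model and the timing helper are illustrative stand-ins, not code from this PR.

```python
# Sketch only (not code from this PR): the rank-pinning and "run twice to
# separate warm-up overhead" pattern used by generate.py above.
# Launch with torchrun so LOCAL_RANK / RANK / WORLD_SIZE are set.
import os
import time

import torch
import torch.distributed as dist


def timed_forward(model: torch.nn.Module, x: torch.Tensor) -> float:
    """Return wall-clock seconds for one forward pass, synchronizing the GPU."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        model(x)
    torch.cuda.synchronize()
    return time.perf_counter() - start


if __name__ == "__main__":
    # torchrun sets LOCAL_RANK; pin this process to its own GPU before any
    # CUDA work so kernels and collectives land on the right device.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")

    # Toy stand-in for the real model; illustrative only.
    model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)

    # The first pass pays one-time costs (kernel compilation, autotuning,
    # allocator warm-up); the second approximates steady-state latency.
    first, second = timed_forward(model, x), timed_forward(model, x)
    if dist.get_rank() == 0:
        print(f"first run: {first * 1e3:.2f} ms, second run: {second * 1e3:.2f} ms")

    dist.destroy_process_group()
```

Calling generate() twice under the run_two_times flag gives the same before/after comparison inside the actual script.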
13 changes: 11 additions & 2 deletions torchtitan/experiments/deepseek_v3/model.py
@@ -51,12 +51,16 @@
TorchFP8GroupGEMM,
TritonCGBF16GroupGEMM,
)

from model_config import ModelArgs
from symm_mem_recipes import OnDeviceAllToAllV
from torch import nn
from torch.distributed._functional_collectives import all_to_all_single_autograd

+# blackwell specific
+from torchtitan.experiments.kernels.blackwell.cute_grouped_gemm_fwd import (
+    CUTLASSGroupedGemmStrategy,
+)

from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import ALIGN_SIZE_M

@@ -474,7 +478,7 @@ class MoE(nn.Module):
# Group GEMM strategies
group_gemm_strategies = None
# which group gemm to use?
group_mm = "torch" # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", , "torchao", "tritoncg"]
group_mm = "cute" # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch","torchao", "tritoncg"], blackwell = ["cute"]

def __init__(self, config):
super().__init__()
@@ -550,6 +554,11 @@ def _initialize_group_gemm_strategies(cls):
if TritonCGBF16GroupGEMM.is_available()
else None
),
"cute": (
CUTLASSGroupedGemmStrategy(MLP.act_fn)
if CUTLASSGroupedGemmStrategy.is_available()
else None
),
}

def combine_experts(self, submod_name: str):
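The model.py change follows the existing pattern in _initialize_group_gemm_strategies: each backend is registered under a string key and guarded by an is_available() check, and group_mm then names the entry to use (now "cute" for Blackwell). Below is a minimal sketch of that registry-with-guard pattern; the strategy classes are hypothetical stand-ins, not torchtitan's real ones.

```python
# Sketch only (not from this PR): the registry-with-availability-guard pattern
# that MoE._initialize_group_gemm_strategies uses, with hypothetical strategy
# classes standing in for the real torch / Triton / CUTLASS backends.
from typing import Dict, Optional

import torch


class GroupGemmStrategy:
    """Minimal stand-in interface for a grouped-GEMM backend."""

    @classmethod
    def is_available(cls) -> bool:
        return True


class TorchBF16Strategy(GroupGemmStrategy):
    """Always-available reference backend (hypothetical)."""


class BlackwellCuteStrategy(GroupGemmStrategy):
    """Backend that only works on Blackwell-class GPUs (hypothetical)."""

    @classmethod
    def is_available(cls) -> bool:
        # A real check would also verify that the CUTLASS/CuTe DSL is importable;
        # compute capability 10.x corresponds to SM100 (Blackwell) parts.
        return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10


def build_strategies() -> Dict[str, Optional[GroupGemmStrategy]]:
    # Mirror the dict in model.py: an entry is None when its backend is
    # unavailable, so a bad `group_mm` choice fails fast with a clear error.
    return {
        "torch": TorchBF16Strategy() if TorchBF16Strategy.is_available() else None,
        "cute": BlackwellCuteStrategy() if BlackwellCuteStrategy.is_available() else None,
    }


def select_strategy(group_mm: str) -> GroupGemmStrategy:
    strategy = build_strategies().get(group_mm)
    if strategy is None:
        raise RuntimeError(f"group_mm='{group_mm}' requested but that backend is unavailable")
    return strategy


if __name__ == "__main__":
    print(type(select_strategy("torch")).__name__)
```

In model.py the same effect comes from CUTLASSGroupedGemmStrategy.is_available() guarding the "cute" entry, so on pre-Blackwell hardware that entry stays None and selecting it surfaces an error instead of a silent fallback.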