
Commit 00a6cf3

lessw2020 authored and H-Huang committed
[deepseek][blackwell] add manual looping group gemm to enable base working inference on Blackwell (#1272)
This PR enables DeepSeek inference to run on Blackwell (B200). Currently, torch._grouped_mm is specific to Hopper, so trying to run on B200 via TorchBF16GroupGEMM yields:

~~~
Error using torch strategy: torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0
~~~

This PR therefore adds a manual looping group GEMM to get DeepSeek inference working on Blackwell.

*Note that you must use Symmetric Memory for the all-to-all; dist.all_to_all_single does not yet work on Blackwell.*

With this PR:

[Screenshot (2025-06-07): https://github.com/user-attachments/assets/0a1b77d7-6423-4c2a-91aa-2f8587cae78a]

1.21 tokens per second is not great, but we have moved from 'not working' to working inference on B200.
1 parent fb8b01e commit 00a6cf3
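As a minimal sketch (not part of this commit), the strategy choice could be driven by the device's compute capability, since torch._grouped_mm is restricted to compute capability 9.0; the pick_group_gemm_strategy helper name is hypothetical, and the returned keys mirror the group_mm options in model.py below:

~~~python
import torch


def pick_group_gemm_strategy() -> str:
    """Hypothetical helper: pick a group GEMM strategy key by compute capability."""
    if torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0):
        return "torch"  # Hopper: the torch._grouped_mm path is supported
    return "manual"  # e.g. Blackwell / B200: fall back to the looping torch.mm strategy
~~~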

File tree

3 files changed: 67 additions & 2 deletions

- torchtitan/experiments/deepseek_v3/generate.py
- torchtitan/experiments/deepseek_v3/group_gemms.py
- torchtitan/experiments/deepseek_v3/model.py


torchtitan/experiments/deepseek_v3/generate.py

Lines changed: 1 addition & 1 deletion
~~~
@@ -224,7 +224,7 @@ def generate(
     tokenizer,
     dist_config,
     messages: list[dict],
-    n_tokens: int = 200,
+    n_tokens: int = 50,
 ):
     rank = dist.get_rank()
     device = dist_config.device
~~~

torchtitan/experiments/deepseek_v3/group_gemms.py

Lines changed: 61 additions & 0 deletions
~~~
@@ -97,9 +97,70 @@ def is_available() -> bool:
     "TorchBF16GroupGEMM",
     "TorchAOBF16GroupGEMM",
     "TritonCGBF16GroupGEMM",
+    "ManualLoopGroupGEMM",
 ]


+class ManualLoopGroupGEMM(GroupGEMMStrategy):
+    """Manual looping baseline implementation for any arch (esp Blackwell) support"""
+
+    def arrange_expert_weights(self, all_weights, submod_name, module):
+        """Store weights in a stacked format"""
+        return torch.stack(all_weights)
+
+    def execute(self, contig_tokens, m_sizes, m_offsets, module):
+        """Execute using manual loops over experts"""
+        # Get weights
+
+        w_gate = module.get_parameter("gate_proj_weight")
+        w_up = module.get_parameter("up_proj_weight")
+        w_down = module.get_parameter("down_proj_weight")
+
+        # Prepare output tensor
+        hidden_size = w_gate.shape[
+            2
+        ]  # stacked weights shape [num_experts, out_dim, in_dim]
+        output = torch.zeros(
+            contig_tokens.shape[0],
+            hidden_size,
+            dtype=contig_tokens.dtype,
+            device=contig_tokens.device,
+        )
+
+        # Process each expert sequentially
+        offset = 0
+        for expert_idx, size in enumerate(m_sizes):
+            if size > 0:
+                # Get tokens for this expert
+                expert_tokens = contig_tokens[offset : offset + size]
+
+                # Get weights for this expert
+                gate_weight = w_gate[expert_idx]  # [out_dim, in_dim]
+                up_weight = w_up[expert_idx]
+                down_weight = w_down[expert_idx]
+
+                # Forward pass: gate and up projections
+                gate_out = torch.mm(expert_tokens, gate_weight.t())
+                up_out = torch.mm(expert_tokens, up_weight.t())
+
+                # Apply activation and combine
+                hidden = self.activation_function(gate_out) * up_out
+
+                # Down projection
+                expert_output = torch.mm(hidden, down_weight.t())
+
+                # Store results
+                output[offset : offset + size] = expert_output
+
+            offset += size
+
+        return output
+
+    @staticmethod
+    def is_available() -> bool:
+        return True
+
+
 class TritonCGBF16GroupGEMM(GroupGEMMStrategy):
     """Implementation of Triton Contiguous group Gemm"""
~~~

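For illustration, a self-contained sketch (not part of the diff) of the same per-expert looping pattern on dummy data; the tensor sizes and the SiLU activation are assumptions chosen to mirror ManualLoopGroupGEMM.execute above:

~~~python
import torch
import torch.nn.functional as F

# Dummy setup: 4 experts, tokens already permuted so each expert's tokens are contiguous.
num_experts, d_model, d_ff = 4, 16, 32
m_sizes = [3, 0, 5, 2]  # tokens routed to each expert (one expert gets none)
tokens = torch.randn(sum(m_sizes), d_model)  # plays the role of contig_tokens
w_gate = torch.randn(num_experts, d_ff, d_model)  # stacked [num_experts, out_dim, in_dim]
w_up = torch.randn(num_experts, d_ff, d_model)
w_down = torch.randn(num_experts, d_model, d_ff)

output = torch.zeros_like(tokens)
offset = 0
for expert_idx, size in enumerate(m_sizes):
    if size > 0:
        x = tokens[offset : offset + size]
        # Gated MLP per expert: activation(gate) * up, then down projection
        hidden = F.silu(x @ w_gate[expert_idx].t()) * (x @ w_up[expert_idx].t())
        output[offset : offset + size] = hidden @ w_down[expert_idx].t()
    offset += size

print(output.shape)  # torch.Size([10, 16])
~~~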
torchtitan/experiments/deepseek_v3/model.py

Lines changed: 5 additions & 1 deletion
~~~
@@ -46,6 +46,7 @@

 from group_gemms import (
     DSGroupGEMM,
+    ManualLoopGroupGEMM,
     TorchAOBF16GroupGEMM,
     TorchBF16GroupGEMM,
     TorchFP8GroupGEMM,
@@ -474,7 +475,7 @@ class MoE(nn.Module):
     # Group GEMM strategies
     group_gemm_strategies = None
     # which group gemm to use?
-    group_mm = "torch"  # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", "torchao", "tritoncg"]
+    group_mm = "manual"  # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", "torchao", "tritoncg", "manual"]

     def __init__(self, config):
         super().__init__()
@@ -527,7 +528,10 @@ def __init__(self, config):
     def _initialize_group_gemm_strategies(cls):
         """Initialize available group GEMM strategies"""
         cls.group_gemm_strategies = {
+            # torch._grouped_mm
             "torch": TorchBF16GroupGEMM(MLP.act_fn),
+            # torch.mm with looping
+            "manual": ManualLoopGroupGEMM(MLP.act_fn),
             "torchao": (
                 TorchAOBF16GroupGEMM(MLP.act_fn)
                 if TorchAOBF16GroupGEMM.is_available()
~~~