
Commit 09bbac6

Add DeepGEMM pre-compile tools (#2819)
This tool compiles all possible kernels in advance, based on the model's config.json, and avoids the situation where a request hits an uncompiled kernel and triggers JIT compilation at serving time (see the usage sketch below).
1 parent 7f64d40 commit 09bbac6
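
As a rough usage sketch (illustration only, mirroring the shell script added in this commit): first generate the kernel configuration list from the model's config.json, then compile every configuration in it. The model path and expert-parallel degree below are placeholders; the script names and flags come from the files in this commit.

# Sketch only: /path/to/model and the expert-parallel degree are placeholders.
export DG_CACHE_DIR=$(pwd)/deep_gemm_cache   # cache dir used by DeepGEMM, as in the included script

python generate_config.py \
    --model /path/to/model \
    --output ./deep_gemm_pre_compile_config.jsonl

python pre_compile.py \
    --config_file ./deep_gemm_pre_compile_config.jsonl \
    --expert_parallel 8 \
    --num_threads $(nproc)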

3 files changed: +366, −0 lines

Lines changed: 151 additions & 0 deletions
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import math
import os
from typing import Tuple

from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
    get_smem_config

logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler()
logger.addHandler(console_handler)
logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))


def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
    hidden_size = model_cfg["hidden_size"]
    intermediate_size = model_cfg["intermediate_size"]
    moe_intermediate_size = model_cfg["moe_intermediate_size"]
    num_attention_heads = model_cfg["num_attention_heads"]
    num_key_value_heads = model_cfg["num_key_value_heads"]
    head_dim = int(hidden_size / num_attention_heads)
    gemm_kn_pairs = [
        # Dense normal gemm
        [hidden_size, intermediate_size * 2],
        [intermediate_size, hidden_size],
        [hidden_size, hidden_size],
        [hidden_size, (num_attention_heads + num_key_value_heads * 2) * head_dim],
    ]
    grouped_gemm_contiguous_kn_pairs = [
        # Moe grouped gemm contiguous
        [hidden_size, moe_intermediate_size * 2],
        [moe_intermediate_size, hidden_size],
    ]
    grouped_gemm_masked_kn_pairs = [
        # Moe grouped gemm masked
        [hidden_size, moe_intermediate_size * 2],
        [moe_intermediate_size, hidden_size],
    ]

    return gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs


def generate_json(
    kn_pairs: list,
    moe_num_experts: int,
    output_path: str,
    is_grouped_contiguous: bool = False,
    is_grouped_masked: bool = False,
):
    if not is_grouped_contiguous:
        BLOCK_MS = [64, 128, 256]
    else:
        BLOCK_MS = [128]
    BLOCK_NS = list(range(16, 129, 8)) + [144, 160]
    TMA_MULTICAST_CONFIGS = [(1, True), (1, False), (2, True), (2, False)]
    counter = 0
    with open(output_path, "a+", encoding="utf-8") as f:
        for block_m in BLOCK_MS:
            for block_n in BLOCK_NS:
                if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
                    NUM_STAGES = [4, 3]
                else:
                    NUM_STAGES = [8, 7, 6, 5, 4, 3]
                for num_stages in NUM_STAGES:
                    for kn_pair in kn_pairs:
                        smem_config = get_smem_config(
                            num_stages, kn_pair[0], block_m, block_n
                        )
                        for tma_multicast_config in TMA_MULTICAST_CONFIGS:
                            cfg = {
                                "N": kn_pair[1],
                                "K": kn_pair[0],
                                "BLOCK_M": block_m,
                                "BLOCK_N": block_n,
                                "SWIZZLE_D_MODE": smem_config[1],
                                "BLOCK_N_PADDING": smem_config[2],
                                "NUM_STAGES": num_stages,
                                "NUM_TMA_MULTICAST": tma_multicast_config[0],
                                "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
                                "IS_GROUPED_CONTIGUOUS": is_grouped_contiguous,
                                "IS_GROUPED_MASKED": is_grouped_masked,
                                "MOE_NUM_EXPERTS": moe_num_experts,
                            }
                            f.write(json.dumps(cfg) + "\n")
                            counter += 1

    return counter


def main(args):
    with open(os.path.join(args.model, "config.json"), "r") as f:
        model_cfg = json.load(f)

    gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs = (
        generate_kn_pairs(model_cfg)
    )
    num_gemm = generate_json(
        gemm_kn_pairs,
        model_cfg["moe_num_experts"],
        args.output,
    )
    num_grouped_contiguous = generate_json(
        grouped_gemm_contiguous_kn_pairs,
        model_cfg["moe_num_experts"],
        args.output,
        is_grouped_contiguous=True,
    )
    num_grouped_masked = generate_json(
        grouped_gemm_masked_kn_pairs,
        model_cfg["moe_num_experts"],
        args.output,
        is_grouped_masked=True,
    )
    logger.info(f"Configurations generated and saved to {args.output}")
    logger.info(f"Generated {num_gemm} gemm configurations.")
    logger.info(
        f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configurations."
    )
    logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configurations.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./deep_gemm_pre_compile_config.jsonl",
    )
    args = parser.parse_args()
    main(args)
Lines changed: 184 additions & 0 deletions
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import threading
from queue import Queue
from time import time

import paddle
from tqdm import tqdm

from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.compiler import build
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.template import (
    cpp_format, generate)
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
    includes as gemm_includes
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
    template as gemm_template
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.m_grouped_gemm import \
    includes as m_grouped_includes
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.m_grouped_gemm import \
    template as m_grouped_template

logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler()
logger.addHandler(console_handler)
logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))


class CompileWorker(threading.Thread):
    def __init__(self, queue, pbar):
        super().__init__()
        self.queue = queue
        self.pbar = pbar

    def run(self):
        while True:
            cfg = self.queue.get()
            if cfg is None:
                break

            try:
                logger.debug(f"Compiling for config: {cfg}")
                keys = {
                    "N": cfg["N"],
                    "K": cfg["K"],
                    "BLOCK_M": cfg["BLOCK_M"],
                    "BLOCK_N": cfg["BLOCK_N"],
                    "SWIZZLE_D_MODE": cfg["SWIZZLE_D_MODE"],
                    "BLOCK_N_PADDING": cfg["BLOCK_N_PADDING"],
                    "NUM_STAGES": cfg["NUM_STAGES"],
                    "NUM_TMA_MULTICAST": cfg["NUM_TMA_MULTICAST"],
                    "IS_TMA_MULTICAST_ON_A": cfg["IS_TMA_MULTICAST_ON_A"],
                }
                arg_defs = (
                    ("lhs", paddle.float8_e4m3fn),
                    ("lhs_scales", paddle.float32),
                    ("rhs", paddle.float8_e4m3fn),
                    ("rhs_scales", paddle.float32),
                    ("out", paddle.bfloat16),
                    ("m", int),
                    ("stream", paddle.device.cuda.Stream),
                    ("num_sms", int),
                    ("smem_size", int),
                )
                name = "gemm_fp8_fp8_bf16_nt"
                includes = gemm_includes
                template = gemm_template
                if cfg["IS_GROUPED_CONTIGUOUS"]:
                    keys["GEMM_TYPE"] = "GroupedContiguous"
                    arg_defs = (
                        ("lhs", paddle.float8_e4m3fn),
                        ("lhs_scales", paddle.float32),
                        ("rhs", paddle.float8_e4m3fn),
                        ("rhs_scales", paddle.float32),
                        ("out", paddle.bfloat16),
                        ("grouped_layout", paddle.int32),
                        ("m", int),
                        ("num_groups", int),
                        ("stream", paddle.device.cuda.Stream),
                        ("num_sms", int),
                        ("smem_size", int),
                    )
                if cfg["IS_GROUPED_MASKED"]:
                    keys["GEMM_TYPE"] = "GroupedMasked"
                    arg_defs = (
                        ("lhs", paddle.float8_e4m3fn),
                        ("lhs_scales", paddle.float32),
                        ("rhs", paddle.float8_e4m3fn),
                        ("rhs_scales", paddle.float32),
                        ("out", paddle.bfloat16),
                        ("grouped_layout", paddle.int32),
                        ("m", int),
                        ("stream", paddle.device.cuda.Stream),
                        ("num_sms", int),
                        ("smem_size", int),
                    )
                if cfg["IS_GROUPED_CONTIGUOUS"] or cfg["IS_GROUPED_MASKED"]:
                    keys["NUM_GROUPS"] = int(
                        cfg["MOE_NUM_EXPERTS"] / cfg["EXPERT_PARALLEL"]
                    )
                    includes = m_grouped_includes
                    template = m_grouped_template
                    name = "m_grouped_gemm_fp8_fp8_bf16_nt"

                code = generate(includes, arg_defs, cpp_format(template, keys))
                build(name, arg_defs, code)
            except Exception as e:
                logger.error(f"Failed to compile config {cfg}: {str(e)}")
                raise RuntimeError(e)
            finally:
                self.pbar.update(1)
                self.queue.task_done()


def pre_compile_from_config(config_file: str, num_threads: int, expert_parallel: int):
    with open(config_file, "r") as f:
        start_time = time()
        lines = f.readlines()

    queue = Queue()
    pbar = tqdm(total=len(lines), desc="Compiling")
    workers = []
    for _ in range(num_threads):
        worker = CompileWorker(queue, pbar)
        worker.start()
        workers.append(worker)

    for line in lines:
        cfg = json.loads(line)
        cfg["EXPERT_PARALLEL"] = expert_parallel
        queue.put(cfg)

    queue.join()

    for _ in range(num_threads):
        queue.put(None)
    for worker in workers:
        worker.join()

    pbar.close()

    logger.info(f"Total compilation time: {time() - start_time:.2f} seconds")


def main(args):
    pre_compile_from_config(args.config_file, args.num_threads, args.expert_parallel)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        default="./deep_gemm_pre_compile_config.jsonl",
    )
    parser.add_argument(
        "--expert_parallel",
        "--ep",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        default=16,
    )
    args = parser.parse_args()
    main(args)
Lines changed: 31 additions & 0 deletions
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export PRE_COMPILE_LOG_LEVEL="INFO"
export DG_CACHE_DIR=$(pwd)/deep_gemm_cache

echo DeepGEMM Cache Dir: $DG_CACHE_DIR

MODEL_PATH=${1:-"/path/to/model"}
EXPERT_PARALLEL=${2:-"8"}
nproc=$(nproc)

python generate_config.py \
    --model $MODEL_PATH \
    --output=./deep_gemm_pre_compile_config.jsonl

python pre_compile.py \
    --config_file=./deep_gemm_pre_compile_config.jsonl \
    --expert_parallel=$EXPERT_PARALLEL \
    --num_threads=$nproc
