[Feature] Add DeepGEMM pre-compile tools. #2819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 1 commit · Jul 14, 2025
151 changes: 151 additions & 0 deletions tools/deep_gemm_pre-compile/generate_config.py
@@ -0,0 +1,151 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import math
import os
from typing import Tuple

from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
get_smem_config

logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler()
logger.addHandler(console_handler)
logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))


def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
    hidden_size = model_cfg["hidden_size"]
    intermediate_size = model_cfg["intermediate_size"]
    moe_intermediate_size = model_cfg["moe_intermediate_size"]
    num_attention_heads = model_cfg["num_attention_heads"]
    num_key_value_heads = model_cfg["num_key_value_heads"]
    head_dim = hidden_size // num_attention_heads
    gemm_kn_pairs = [
        # Dense normal gemm
        [hidden_size, intermediate_size * 2],
        [intermediate_size, hidden_size],
        [hidden_size, hidden_size],
        [hidden_size, (num_attention_heads + num_key_value_heads * 2) * head_dim],
    ]
    grouped_gemm_contiguous_kn_pairs = [
        # Moe grouped gemm contiguous
        [hidden_size, moe_intermediate_size * 2],
        [moe_intermediate_size, hidden_size],
    ]
    grouped_gemm_masked_kn_pairs = [
        # Moe grouped gemm masked
        [hidden_size, moe_intermediate_size * 2],
        [moe_intermediate_size, hidden_size],
    ]

    return gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs


def generate_json(
    kn_pairs: list,
    moe_num_experts: int,
    output_path: str,
    is_grouped_contiguous: bool = False,
    is_grouped_masked: bool = False,
):
    # Search space of kernel tile sizes; grouped-contiguous GEMMs only use BLOCK_M = 128.
    if not is_grouped_contiguous:
        BLOCK_MS = [64, 128, 256]
    else:
        BLOCK_MS = [128]
    BLOCK_NS = list(range(16, 129, 8)) + [144, 160]
    TMA_MULTICAST_CONFIGS = [(1, True), (1, False), (2, True), (2, False)]
    counter = 0
    with open(output_path, "a+", encoding="utf-8") as f:
        for block_m in BLOCK_MS:
            for block_n in BLOCK_NS:
                # Some block_n values only admit a reduced set of pipeline stages.
                if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
                    NUM_STAGES = [4, 3]
                else:
                    NUM_STAGES = [8, 7, 6, 5, 4, 3]
                for num_stages in NUM_STAGES:
                    for kn_pair in kn_pairs:
                        smem_config = get_smem_config(
                            num_stages, kn_pair[0], block_m, block_n
                        )
                        for tma_multicast_config in TMA_MULTICAST_CONFIGS:
                            cfg = {
                                "N": kn_pair[1],
                                "K": kn_pair[0],
                                "BLOCK_M": block_m,
                                "BLOCK_N": block_n,
                                "SWIZZLE_D_MODE": smem_config[1],
                                "BLOCK_N_PADDING": smem_config[2],
                                "NUM_STAGES": num_stages,
                                "NUM_TMA_MULTICAST": tma_multicast_config[0],
                                "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
                                "IS_GROUPED_CONTIGUOUS": is_grouped_contiguous,
                                "IS_GROUPED_MASKED": is_grouped_masked,
                                "MOE_NUM_EXPERTS": moe_num_experts,
                            }
                            f.write(json.dumps(cfg) + "\n")
                            counter += 1

    return counter


def main(args):
    with open(os.path.join(args.model, "config.json"), "r") as f:
        model_cfg = json.load(f)

    gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs = (
        generate_kn_pairs(model_cfg)
    )
    num_gemm = generate_json(
        gemm_kn_pairs,
        model_cfg["moe_num_experts"],
        args.output,
    )
    num_grouped_contiguous = generate_json(
        grouped_gemm_contiguous_kn_pairs,
        model_cfg["moe_num_experts"],
        args.output,
        is_grouped_contiguous=True,
    )
    num_grouped_masked = generate_json(
        grouped_gemm_masked_kn_pairs,
        model_cfg["moe_num_experts"],
        args.output,
        is_grouped_masked=True,
    )
    logger.info(f"Configurations generated and saved to {args.output}")
    logger.info(f"Generated {num_gemm} gemm configurations.")
    logger.info(
        f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configurations."
    )
    logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configurations.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./deep_gemm_pre_compile_config.jsonl",
    )
    args = parser.parse_args()
    main(args)
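
Each generate_json call appends one JSON object per kernel variant, so all three GEMM families land in the same JSONL file. As a quick sanity check (illustrative only, not part of this PR; the path below is the script's default output), the file can be tallied per family:

import json
from collections import Counter

counts = Counter()
with open("./deep_gemm_pre_compile_config.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        cfg = json.loads(line)
        if cfg["IS_GROUPED_MASKED"]:
            counts["grouped_masked"] += 1
        elif cfg["IS_GROUPED_CONTIGUOUS"]:
            counts["grouped_contiguous"] += 1
        else:
            counts["dense"] += 1

# The totals should match the per-family counts logged by generate_config.py.
print(dict(counts))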
184 changes: 184 additions & 0 deletions tools/deep_gemm_pre-compile/pre_compile.py
@@ -0,0 +1,184 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import threading
from queue import Queue
from time import time

import paddle
from tqdm import tqdm

from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.compiler import build
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.template import (
cpp_format, generate)
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
includes as gemm_includes
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
template as gemm_template
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.m_grouped_gemm import \
includes as m_grouped_includes
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.m_grouped_gemm import \
template as m_grouped_template

logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler()
logger.addHandler(console_handler)
logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))


class CompileWorker(threading.Thread):
    """Worker thread that pulls kernel configs from the queue and JIT-builds them."""

    def __init__(self, queue, pbar):
        super().__init__()
        self.queue = queue
        self.pbar = pbar

    def run(self):
        while True:
            cfg = self.queue.get()
            if cfg is None:
                break

            try:
                logger.debug(f"Compiling for config: {cfg}")
                keys = {
                    "N": cfg["N"],
                    "K": cfg["K"],
                    "BLOCK_M": cfg["BLOCK_M"],
                    "BLOCK_N": cfg["BLOCK_N"],
                    "SWIZZLE_D_MODE": cfg["SWIZZLE_D_MODE"],
                    "BLOCK_N_PADDING": cfg["BLOCK_N_PADDING"],
                    "NUM_STAGES": cfg["NUM_STAGES"],
                    "NUM_TMA_MULTICAST": cfg["NUM_TMA_MULTICAST"],
                    "IS_TMA_MULTICAST_ON_A": cfg["IS_TMA_MULTICAST_ON_A"],
                }
                arg_defs = (
                    ("lhs", paddle.float8_e4m3fn),
                    ("lhs_scales", paddle.float32),
                    ("rhs", paddle.float8_e4m3fn),
                    ("rhs_scales", paddle.float32),
                    ("out", paddle.bfloat16),
                    ("m", int),
                    ("stream", paddle.device.cuda.Stream),
                    ("num_sms", int),
                    ("smem_size", int),
                )
                name = "gemm_fp8_fp8_bf16_nt"
                includes = gemm_includes
                template = gemm_template
                if cfg["IS_GROUPED_CONTIGUOUS"]:
                    keys["GEMM_TYPE"] = "GroupedContiguous"
                    arg_defs = (
                        ("lhs", paddle.float8_e4m3fn),
                        ("lhs_scales", paddle.float32),
                        ("rhs", paddle.float8_e4m3fn),
                        ("rhs_scales", paddle.float32),
                        ("out", paddle.bfloat16),
                        ("grouped_layout", paddle.int32),
                        ("m", int),
                        ("num_groups", int),
                        ("stream", paddle.device.cuda.Stream),
                        ("num_sms", int),
                        ("smem_size", int),
                    )
                if cfg["IS_GROUPED_MASKED"]:
                    keys["GEMM_TYPE"] = "GroupedMasked"
                    arg_defs = (
                        ("lhs", paddle.float8_e4m3fn),
                        ("lhs_scales", paddle.float32),
                        ("rhs", paddle.float8_e4m3fn),
                        ("rhs_scales", paddle.float32),
                        ("out", paddle.bfloat16),
                        ("grouped_layout", paddle.int32),
                        ("m", int),
                        ("stream", paddle.device.cuda.Stream),
                        ("num_sms", int),
                        ("smem_size", int),
                    )
                if cfg["IS_GROUPED_CONTIGUOUS"] or cfg["IS_GROUPED_MASKED"]:
                    # NUM_GROUPS is the number of local experts per rank under expert parallelism.
                    keys["NUM_GROUPS"] = int(
                        cfg["MOE_NUM_EXPERTS"] / cfg["EXPERT_PARALLEL"]
                    )
                    includes = m_grouped_includes
                    template = m_grouped_template
                    name = "m_grouped_gemm_fp8_fp8_bf16_nt"

                code = generate(includes, arg_defs, cpp_format(template, keys))
                build(name, arg_defs, code)
            except Exception as e:
                logger.error(f"Failed to compile config {cfg}: {str(e)}")
                raise RuntimeError(e)
            finally:
                self.pbar.update(1)
                self.queue.task_done()


def pre_compile_from_config(config_file: str, num_threads: int, expert_parallel: int):
    with open(config_file, "r") as f:
        start_time = time()
        lines = f.readlines()

    queue = Queue()
    pbar = tqdm(total=len(lines), desc="Compiling")
    workers = []
    for _ in range(num_threads):
        worker = CompileWorker(queue, pbar)
        worker.start()
        workers.append(worker)

    for line in lines:
        cfg = json.loads(line)
        cfg["EXPERT_PARALLEL"] = expert_parallel
        queue.put(cfg)

    queue.join()

    # One sentinel per worker so every thread exits its loop.
    for _ in range(num_threads):
        queue.put(None)
    for worker in workers:
        worker.join()

    pbar.close()

    logger.info(f"Total compilation time: {time() - start_time:.2f} seconds")


def main(args):
    pre_compile_from_config(args.config_file, args.num_threads, args.expert_parallel)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        type=str,
        default="./deep_gemm_pre_compile_config.jsonl",
    )
    parser.add_argument(
        "--expert_parallel",
        "--ep",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        default=16,
    )
    args = parser.parse_args()
    main(args)
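
For reference, each worker's compile step reduces to the three DeepGEMM JIT calls that appear above: cpp_format fills the kernel template with the config keys, generate assembles the final source, and build compiles it. A minimal single-config sketch of the dense path (the numeric values below are illustrative only; in practice they come from the generated JSONL, where SWIZZLE_D_MODE and BLOCK_N_PADDING are derived via get_smem_config):

import paddle

from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.compiler import build
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.template import cpp_format, generate
from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import (
    includes as gemm_includes, template as gemm_template)

# Illustrative dense-GEMM configuration (one line of the JSONL, minus the grouping fields).
keys = {
    "N": 4096, "K": 7168, "BLOCK_M": 128, "BLOCK_N": 128,
    "SWIZZLE_D_MODE": 32, "BLOCK_N_PADDING": 0, "NUM_STAGES": 4,
    "NUM_TMA_MULTICAST": 1, "IS_TMA_MULTICAST_ON_A": True,
}
# Same dense argument signature the worker uses.
arg_defs = (
    ("lhs", paddle.float8_e4m3fn), ("lhs_scales", paddle.float32),
    ("rhs", paddle.float8_e4m3fn), ("rhs_scales", paddle.float32),
    ("out", paddle.bfloat16), ("m", int),
    ("stream", paddle.device.cuda.Stream), ("num_sms", int), ("smem_size", int),
)
code = generate(gemm_includes, arg_defs, cpp_format(gemm_template, keys))
build("gemm_fp8_fp8_bf16_nt", arg_defs, code)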
31 changes: 31 additions & 0 deletions tools/deep_gemm_pre-compile/pre_compile.sh
@@ -0,0 +1,31 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export PRE_COMPILE_LOG_LEVEL="INFO"
export DG_CACHE_DIR="$(pwd)/deep_gemm_cache"

echo "DeepGEMM Cache Dir: $DG_CACHE_DIR"

MODEL_PATH=${1:-"/path/to/model"}
EXPERT_PARALLEL=${2:-"8"}
nproc=$(nproc)

python generate_config.py \
    --model "$MODEL_PATH" \
    --output=./deep_gemm_pre_compile_config.jsonl

python pre_compile.py \
    --config_file=./deep_gemm_pre_compile_config.jsonl \
    --expert_parallel="$EXPERT_PARALLEL" \
    --num_threads="$nproc"
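
Typical invocation passes the model directory and expert-parallel degree as the two positional arguments (the path is a placeholder, matching the script's default): bash pre_compile.sh /path/to/model 8. Compiled kernels are cached under the DG_CACHE_DIR exported at the top of the script ($(pwd)/deep_gemm_cache); presumably the same cache directory should be visible to the serving process so the prebuilt kernels are reused rather than recompiled.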