From a0fd3f6aa7d9eeb5030614c8e09b3ccba4e66a39 Mon Sep 17 00:00:00 2001
From: Deleter-D <867909454@qq.com>
Date: Fri, 11 Jul 2025 19:56:40 +0800
Subject: [PATCH] Add DeepGEMM pre-compile tools

---
 .../deep_gemm_pre-compile/generate_config.py  | 151 ++++++++++++++
 tools/deep_gemm_pre-compile/pre_compile.py    | 184 ++++++++++++++++++
 tools/deep_gemm_pre-compile/pre_compile.sh    |  31 +++
 3 files changed, 366 insertions(+)
 create mode 100644 tools/deep_gemm_pre-compile/generate_config.py
 create mode 100644 tools/deep_gemm_pre-compile/pre_compile.py
 create mode 100644 tools/deep_gemm_pre-compile/pre_compile.sh

diff --git a/tools/deep_gemm_pre-compile/generate_config.py b/tools/deep_gemm_pre-compile/generate_config.py
new file mode 100644
index 0000000000..ef746c4252
--- /dev/null
+++ b/tools/deep_gemm_pre-compile/generate_config.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import logging
+import math
+import os
+from typing import Tuple
+
+from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
+    get_smem_config
+
+logger = logging.getLogger(__name__)
+console_handler = logging.StreamHandler()
+logger.addHandler(console_handler)
+logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))
+
+
+def generate_kn_pairs(model_cfg: dict) -> Tuple[list, list, list]:
+    hidden_size = model_cfg["hidden_size"]
+    intermediate_size = model_cfg["intermediate_size"]
+    moe_intermediate_size = model_cfg["moe_intermediate_size"]
+    num_attention_heads = model_cfg["num_attention_heads"]
+    num_key_value_heads = model_cfg["num_key_value_heads"]
+    head_dim = int(hidden_size / num_attention_heads)
+    gemm_kn_pairs = [
+        # Dense normal gemm
+        [hidden_size, intermediate_size * 2],
+        [intermediate_size, hidden_size],
+        [hidden_size, hidden_size],
+        [hidden_size, (num_attention_heads + num_key_value_heads * 2) * head_dim],
+    ]
+    grouped_gemm_contiguous_kn_pairs = [
+        # Moe grouped gemm contiguous
+        [hidden_size, moe_intermediate_size * 2],
+        [moe_intermediate_size, hidden_size],
+    ]
+    grouped_gemm_masked_kn_pairs = [
+        # Moe grouped gemm masked
+        [hidden_size, moe_intermediate_size * 2],
+        [moe_intermediate_size, hidden_size],
+    ]
+
+    return gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs
+
+
+def generate_json(
+    kn_pairs: list,
+    moe_num_experts: int,
+    output_path: str,
+    is_grouped_contiguous: bool = False,
+    is_grouped_masked: bool = False,
+):
+    if not is_grouped_contiguous:
+        BLOCK_MS = [64, 128, 256]
+    else:
+        BLOCK_MS = [128]
+    BLOCK_NS = list(range(16, 129, 8)) + [144, 160]
+    TMA_MULTICAST_CONFIGS = [(1, True), (1, False), (2, True), (2, False)]
+    counter = 0
+    with open(output_path, "a+", encoding="utf-8") as f:
+        for block_m in BLOCK_MS:
+            for block_n in BLOCK_NS:
+                if 128 % block_n != 0 and 128 // math.gcd(128, block_n) <= 4:
+                    NUM_STAGES = [4, 3]
+                else:
+                    NUM_STAGES = [8, 7, 6, 5, 4, 3]
+                for num_stages in NUM_STAGES:
+                    for kn_pair in kn_pairs:
+                        smem_config = get_smem_config(
+                            num_stages, kn_pair[0], block_m, block_n
+                        )
+                        for tma_multicast_config in TMA_MULTICAST_CONFIGS:
+                            cfg = {
+                                "N": kn_pair[1],
+                                "K": kn_pair[0],
+                                "BLOCK_M": block_m,
+                                "BLOCK_N": block_n,
+                                "SWIZZLE_D_MODE": smem_config[1],
+                                "BLOCK_N_PADDING": smem_config[2],
+                                "NUM_STAGES": num_stages,
+                                "NUM_TMA_MULTICAST": tma_multicast_config[0],
+                                "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
+                                "IS_GROUPED_CONTIGUOUS": is_grouped_contiguous,
+                                "IS_GROUPED_MASKED": is_grouped_masked,
+                                "MOE_NUM_EXPERTS": moe_num_experts,
+                            }
+                            f.write(json.dumps(cfg) + "\n")
+                            counter += 1
+
+    return counter
+
+
+def main(args):
+    with open(os.path.join(args.model, "config.json"), "r") as f:
+        model_cfg = json.load(f)
+
+    gemm_kn_pairs, grouped_gemm_contiguous_kn_pairs, grouped_gemm_masked_kn_pairs = (
+        generate_kn_pairs(model_cfg)
+    )
+    num_gemm = generate_json(
+        gemm_kn_pairs,
+        model_cfg["moe_num_experts"],
+        args.output,
+    )
+    num_grouped_contiguous = generate_json(
+        grouped_gemm_contiguous_kn_pairs,
+        model_cfg["moe_num_experts"],
+        args.output,
+        is_grouped_contiguous=True,
+    )
+    num_grouped_masked = generate_json(
+        grouped_gemm_masked_kn_pairs,
+        model_cfg["moe_num_experts"],
+        args.output,
+        is_grouped_masked=True,
+    )
+    logger.info(f"Configurations generated and saved to {args.output}")
+    logger.info(f"Generated {num_gemm} gemm configurations.")
+    logger.info(
+        f"Generated {num_grouped_contiguous} grouped_gemm_contiguous configurations."
+    )
+    logger.info(f"Generated {num_grouped_masked} grouped_gemm_masked configurations.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="./deep_gemm_pre_compile_config.jsonl",
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/tools/deep_gemm_pre-compile/pre_compile.py b/tools/deep_gemm_pre-compile/pre_compile.py
new file mode 100644
index 0000000000..38571f5cde
--- /dev/null
+++ b/tools/deep_gemm_pre-compile/pre_compile.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import logging
+import os
+import threading
+from queue import Queue
+from time import time
+
+import paddle
+from tqdm import tqdm
+
+from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.compiler import build
+from fastdeploy.model_executor.ops.gpu.deep_gemm.jit.template import (
+    cpp_format, generate)
+from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
+    includes as gemm_includes
+from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.gemm import \
+    template as gemm_template
+from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.m_grouped_gemm import \
+    includes as m_grouped_includes
+from fastdeploy.model_executor.ops.gpu.deep_gemm.jit_kernels.m_grouped_gemm import \
+    template as m_grouped_template
+
+logger = logging.getLogger(__name__)
+console_handler = logging.StreamHandler()
+logger.addHandler(console_handler)
+logger.setLevel(os.getenv("PRE_COMPILE_LOG_LEVEL", "INFO"))
+
+
+class CompileWorker(threading.Thread):
+    def __init__(self, queue, pbar):
+        super().__init__()
+        self.queue = queue
+        self.pbar = pbar
+
+    def run(self):
+        while True:
+            cfg = self.queue.get()
+            if cfg is None:
+                break
+
+            try:
+                logger.debug(f"Compiling for config: {cfg}")
+                keys = {
+                    "N": cfg["N"],
+                    "K": cfg["K"],
+                    "BLOCK_M": cfg["BLOCK_M"],
+                    "BLOCK_N": cfg["BLOCK_N"],
+                    "SWIZZLE_D_MODE": cfg["SWIZZLE_D_MODE"],
+                    "BLOCK_N_PADDING": cfg["BLOCK_N_PADDING"],
+                    "NUM_STAGES": cfg["NUM_STAGES"],
+                    "NUM_TMA_MULTICAST": cfg["NUM_TMA_MULTICAST"],
+                    "IS_TMA_MULTICAST_ON_A": cfg["IS_TMA_MULTICAST_ON_A"],
+                }
+                arg_defs = (
+                    ("lhs", paddle.float8_e4m3fn),
+                    ("lhs_scales", paddle.float32),
+                    ("rhs", paddle.float8_e4m3fn),
+                    ("rhs_scales", paddle.float32),
+                    ("out", paddle.bfloat16),
+                    ("m", int),
+                    ("stream", paddle.device.cuda.Stream),
+                    ("num_sms", int),
+                    ("smem_size", int),
+                )
+                name = "gemm_fp8_fp8_bf16_nt"
+                includes = gemm_includes
+                template = gemm_template
+                if cfg["IS_GROUPED_CONTIGUOUS"]:
+                    keys["GEMM_TYPE"] = "GroupedContiguous"
+                    arg_defs = (
+                        ("lhs", paddle.float8_e4m3fn),
+                        ("lhs_scales", paddle.float32),
+                        ("rhs", paddle.float8_e4m3fn),
+                        ("rhs_scales", paddle.float32),
+                        ("out", paddle.bfloat16),
+                        ("grouped_layout", paddle.int32),
+                        ("m", int),
+                        ("num_groups", int),
+                        ("stream", paddle.device.cuda.Stream),
+                        ("num_sms", int),
+                        ("smem_size", int),
+                    )
+                if cfg["IS_GROUPED_MASKED"]:
+                    keys["GEMM_TYPE"] = "GroupedMasked"
+                    arg_defs = (
+                        ("lhs", paddle.float8_e4m3fn),
+                        ("lhs_scales", paddle.float32),
+                        ("rhs", paddle.float8_e4m3fn),
+                        ("rhs_scales", paddle.float32),
+                        ("out", paddle.bfloat16),
+                        ("grouped_layout", paddle.int32),
+                        ("m", int),
+                        ("stream", paddle.device.cuda.Stream),
+                        ("num_sms", int),
+                        ("smem_size", int),
+                    )
+                if cfg["IS_GROUPED_CONTIGUOUS"] or cfg["IS_GROUPED_MASKED"]:
+                    keys["NUM_GROUPS"] = int(
+                        cfg["MOE_NUM_EXPERTS"] / cfg["EXPERT_PARALLEL"]
+                    )
+                    includes = m_grouped_includes
+                    template = m_grouped_template
+                    name = "m_grouped_gemm_fp8_fp8_bf16_nt"
+
+                code = generate(includes, arg_defs, cpp_format(template, keys))
+                build(name, arg_defs, code)
+            except Exception as e:
+                logger.error(f"Failed to compile config {cfg}: {str(e)}")
+                raise RuntimeError(e)
+            finally:
+                self.pbar.update(1)
+                self.queue.task_done()
+
+
+def pre_compile_from_config(config_file: str, num_threads: int, expert_parallel: int):
+    with open(config_file, "r") as f:
+        start_time = time()
+        lines = f.readlines()
+
+        queue = Queue()
+        pbar = tqdm(total=len(lines), desc="Compiling")
+        workers = []
+        for _ in range(num_threads):
+            worker = CompileWorker(queue, pbar)
+            worker.start()
+            workers.append(worker)
+
+        for line in lines:
+            cfg = json.loads(line)
+            cfg["EXPERT_PARALLEL"] = expert_parallel
+            queue.put(cfg)
+
+        queue.join()
+
+        for _ in range(num_threads):
+            queue.put(None)
+        for worker in workers:
+            worker.join()
+
+        pbar.close()
+
+        logger.info(f"Total compilation time: {time() - start_time:.2f} seconds")
+
+
+def main(args):
+    pre_compile_from_config(args.config_file, args.num_threads, args.expert_parallel)
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config_file",
+        type=str,
+        default="./deep_gemm_pre_compile_config.jsonl",
+    )
+    parser.add_argument(
+        "--expert_parallel",
+        "--ep",
+        type=int,
+        default=8,
+    )
+    parser.add_argument(
+        "--num_threads",
+        type=int,
+        default=16,
+    )
+    args = parser.parse_args()
+    main(args)
diff --git a/tools/deep_gemm_pre-compile/pre_compile.sh b/tools/deep_gemm_pre-compile/pre_compile.sh
new file mode 100644
index 0000000000..8b609dfeb1
--- /dev/null
+++ b/tools/deep_gemm_pre-compile/pre_compile.sh
@@ -0,0 +1,31 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+export PRE_COMPILE_LOG_LEVEL="INFO"
+export DG_CACHE_DIR="$(pwd)/deep_gemm_cache"
+
+echo "DeepGEMM Cache Dir: $DG_CACHE_DIR"
+
+MODEL_PATH=${1:-"/path/to/model"}
+EXPERT_PARALLEL=${2:-"8"}
+nproc=$(nproc)
+
+python generate_config.py \
+    --model "$MODEL_PATH" \
+    --output=./deep_gemm_pre_compile_config.jsonl
+
+python pre_compile.py \
+    --config_file=./deep_gemm_pre_compile_config.jsonl \
+    --expert_parallel=$EXPERT_PARALLEL \
+    --num_threads=$nproc
\ No newline at end of file
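
Illustrative usage, as a sketch: assuming the patch is applied to a FastDeploy checkout with the DeepGEMM GPU ops available, and the scripts are run from tools/deep_gemm_pre-compile/, an invocation would look like the following (the model path and expert-parallel degree are placeholders):

    # generate kernel configs for the model, then pre-compile them with $(nproc) threads
    bash pre_compile.sh /path/to/model 8

pre_compile.sh writes the candidate configurations to ./deep_gemm_pre_compile_config.jsonl and caches the compiled kernels under ./deep_gemm_cache (via DG_CACHE_DIR).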