diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index a378bc6baa5..68aff793ae6 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os import sys @@ -8,12 +9,12 @@ # Note that we have 400 MiB quota, please use it wisely. # See https://github.com/pypi/support/issues/3792 . # Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400)) +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400)) def print_top_10_largest_files(zip_file): """Print the top 10 largest files in the given zip file.""" - with zipfile.ZipFile(zip_file, 'r') as z: + with zipfile.ZipFile(zip_file, "r") as z: file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] file_sizes.sort(key=lambda x: x[1], reverse=True) for f, size in file_sizes[:10]: @@ -28,14 +29,18 @@ def check_wheel_size(directory): wheel_path = os.path.join(root, file_name) wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024) if wheel_size_mb > VLLM_MAX_SIZE_MB: - print(f"Not allowed: Wheel {wheel_path} is larger " - f"({wheel_size_mb:.2f} MB) than the limit " - f"({VLLM_MAX_SIZE_MB} MB).") + print( + f"Not allowed: Wheel {wheel_path} is larger " + f"({wheel_size_mb:.2f} MB) than the limit " + f"({VLLM_MAX_SIZE_MB} MB)." + ) print_top_10_largest_files(wheel_path) return 1 else: - print(f"Wheel {wheel_path} is within the allowed size " - f"({wheel_size_mb:.2f} MB).") + print( + f"Wheel {wheel_path} is within the allowed size " + f"({wheel_size_mb:.2f} MB)." + ) return 0 @@ -45,4 +50,4 @@ def check_wheel_size(directory): sys.exit(1) directory = sys.argv[1] - sys.exit(check_wheel_size(directory)) \ No newline at end of file + sys.exit(check_wheel_size(directory)) diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py index 36e1b6c0132..7045d881049 100644 --- a/.buildkite/generate_index.py +++ b/.buildkite/generate_index.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import os @@ -22,5 +23,5 @@ print(f"Generated index.html for {args.wheel}") # cloudfront requires escaping the '+' character f.write( - template.format(wheel=filename, - wheel_html_escaped=filename.replace("+", "%2B"))) + template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) + ) diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml new file mode 100644 index 00000000000..cca58097e8a --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Llama-3.2-1B-Instruct-FP8 -b "auto" -l 1319 -f 5 -t 1 +model_name: "RedHatAI/Llama-3.2-1B-Instruct-FP8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.335 + - name: "exact_match,flexible-extract" + value: 0.323 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml new file mode 100644 index 00000000000..54579a63a9b --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2.5-1.5B-Instruct -b auto -l 1319 -f 5 -t 1 +model_name: "Qwen/Qwen2.5-1.5B-Instruct" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.54 + - name: "exact_match,flexible-extract" + value: 0.59 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml new file mode 100644 index 00000000000..a2f235f4858 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1 +model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.47 + - name: "exact_match,flexible-extract" + value: 0.64 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 37eeac85c93..27a1a9a82bd 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,3 +3,4 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml +Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 254d01edf84..36e0543879b 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,10 +1,6 @@ -Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +Qwen2.5-1.5B-Instruct.yaml Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml -Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml Qwen1.5-MoE-W4A16-compressed-tensors.yaml -Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml -Qwen2-1.5B-Instruct-FP8W8.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/conftest.py b/.buildkite/lm-eval-harness/conftest.py new file mode 100644 index 00000000000..c0d60dd5328 --- /dev/null +++ b/.buildkite/lm-eval-harness/conftest.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from pathlib import Path + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--config-list-file", + action="store", + help="Path to the file listing model config YAMLs (one per line)", + ) + parser.addoption( + "--tp-size", + action="store", + default="1", + help="Tensor parallel size to use for evaluation", + ) + + +@pytest.fixture(scope="session") +def config_list_file(pytestconfig, config_dir): + rel_path = pytestconfig.getoption("--config-list-file") + return config_dir / rel_path + + +@pytest.fixture(scope="session") +def tp_size(pytestconfig): + return pytestconfig.getoption("--tp-size") + + +def pytest_generate_tests(metafunc): + if "config_filename" in metafunc.fixturenames: + rel_path = metafunc.config.getoption("--config-list-file") + config_list_file = Path(rel_path).resolve() + config_dir = config_list_file.parent + with open(config_list_file, encoding="utf-8") as f: + configs = [ + config_dir / line.strip() + for line in f + if line.strip() and not line.startswith("#") + ] + metafunc.parametrize("config_filename", configs) diff --git a/.buildkite/lm-eval-harness/run-tests.sh b/.buildkite/lm-eval-harness/run-tests.sh deleted file mode 100644 index 26f33b74428..00000000000 --- a/.buildkite/lm-eval-harness/run-tests.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash - -usage() { - echo`` - echo "Runs lm eval harness on GSM8k using vllm and compares to " - echo "precomputed baseline (measured by HF transformers.)" - echo - echo "usage: ${0} " - echo - echo " -c - path to the test data config (e.g. configs/small-models.txt)" - echo " -t - tensor parallel size" - echo -} - -SUCCESS=0 - -while getopts "c:t:" OPT; do - case ${OPT} in - c ) - CONFIG="$OPTARG" - ;; - t ) - TP_SIZE="$OPTARG" - ;; - \? ) - usage - exit 1 - ;; - esac -done - -# Parse list of configs. -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG" - -for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" -do - LOCAL_SUCCESS=0 - - echo "=== RUNNING MODEL: $MODEL_CONFIG WITH TP SIZE: $TP_SIZE===" - - export LM_EVAL_TEST_DATA_FILE=$PWD/configs/${MODEL_CONFIG} - export LM_EVAL_TP_SIZE=$TP_SIZE - pytest -s test_lm_eval_correctness.py || LOCAL_SUCCESS=$? - - if [[ $LOCAL_SUCCESS == 0 ]]; then - echo "=== PASSED MODEL: ${MODEL_CONFIG} ===" - else - echo "=== FAILED MODEL: ${MODEL_CONFIG} ===" - fi - - SUCCESS=$((SUCCESS + LOCAL_SUCCESS)) - -done - -if [ "${SUCCESS}" -eq "0" ]; then - exit 0 -else - exit 1 -fi diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 6015a83e829..930adfaf3e1 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -1,69 +1,55 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ LM eval harness on model to compare vs HF baseline computed offline. Configs are found in configs/$MODEL.yaml -* export LM_EVAL_TEST_DATA_FILE=configs/Meta-Llama-3-70B-Instruct.yaml -* export LM_EVAL_TP_SIZE=4 -* pytest -s test_lm_eval_correctness.py +pytest -s -v test_lm_eval_correctness.py \ + --config-list-file=configs/models-small.txt \ + --tp-size=1 """ -import os -from pathlib import Path - import lm_eval -import numpy -import pytest +import numpy as np import yaml RTOL = 0.08 -TEST_DATA_FILE = os.environ.get( - "LM_EVAL_TEST_DATA_FILE", - ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") - -TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1) - -def launch_lm_eval(eval_config): - trust_remote_code = eval_config.get('trust_remote_code', False) - - model_args = f"pretrained={eval_config['model_name']}," \ - f"tensor_parallel_size={TP_SIZE}," \ - f"add_bos_token=true," \ - f"trust_remote_code={trust_remote_code}" +def launch_lm_eval(eval_config, tp_size): + trust_remote_code = eval_config.get("trust_remote_code", False) + model_args = ( + f"pretrained={eval_config['model_name']}," + f"tensor_parallel_size={tp_size}," + f"enforce_eager=true," + f"add_bos_token=true," + f"trust_remote_code={trust_remote_code}" + ) results = lm_eval.simple_evaluate( model="vllm", model_args=model_args, tasks=[task["name"] for task in eval_config["tasks"]], num_fewshot=eval_config["num_fewshot"], limit=eval_config["limit"], - batch_size="auto") - + batch_size="auto", + ) return results -def test_lm_eval_correctness(): - eval_config = yaml.safe_load( - Path(TEST_DATA_FILE).read_text(encoding="utf-8")) - - if eval_config[ - "model_name"] == "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform": #noqa: E501 - pytest.skip("FBGEMM is currently failing on main.") +def test_lm_eval_correctness_param(config_filename, tp_size): + eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) - # Launch eval requests. - results = launch_lm_eval(eval_config) + results = launch_lm_eval(eval_config, tp_size) - # Confirm scores match ground truth. success = True for task in eval_config["tasks"]: for metric in task["metrics"]: ground_truth = metric["value"] measured_value = results["results"][task["name"]][metric["name"]] - print(f'{task["name"]} | {metric["name"]}: ' - f'ground_truth={ground_truth} | measured={measured_value}') - success = success and numpy.isclose( - ground_truth, measured_value, rtol=RTOL) + print( + f"{task['name']} | {metric['name']}: " + f"ground_truth={ground_truth} | measured={measured_value}" + ) + success = success and np.isclose(ground_truth, measured_value, rtol=RTOL) - # Assert at the end, print all scores even on failure for debugging. assert success diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index d3f5fc5cd4c..72c52d5bb5e 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -113,7 +113,7 @@ WARNING: The benchmarking script will save json results by itself, so please do ### Visualizing the results -The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results. +The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](performance-benchmarks-descriptions.md) with real benchmarking results. You can find the result presented as a table inside the `buildkite/performance-benchmark` job page. If you do not see the table, please wait till the benchmark finish running. The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 1030ec24e8d..a4f1638c1ad 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json import os @@ -65,18 +66,18 @@ def read_markdown(file): def results_to_json(latency, throughput, serving): - return json.dumps({ - 'latency': latency.to_dict(), - 'throughput': throughput.to_dict(), - 'serving': serving.to_dict() - }) + return json.dumps( + { + "latency": latency.to_dict(), + "throughput": throughput.to_dict(), + "serving": serving.to_dict(), + } + ) if __name__ == "__main__": - # collect results for test_file in results_folder.glob("*.json"): - with open(test_file) as f: raw_result = json.loads(f.read()) @@ -120,7 +121,8 @@ def results_to_json(latency, throughput, serving): for perc in [10, 25, 50, 75, 90, 99]: # Multiply 1000 to convert the time unit from s to ms raw_result.update( - {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}) + {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]} + ) raw_result["avg_latency"] = raw_result["avg_latency"] * 1000 # add the result to raw_result @@ -153,26 +155,27 @@ def results_to_json(latency, throughput, serving): serving_results = pd.DataFrame.from_dict(serving_results) throughput_results = pd.DataFrame.from_dict(throughput_results) - raw_results_json = results_to_json(latency_results, throughput_results, - serving_results) + raw_results_json = results_to_json( + latency_results, throughput_results, serving_results + ) # remapping the key, for visualization purpose if not latency_results.empty: - latency_results = latency_results[list( - latency_column_mapping.keys())].rename( - columns=latency_column_mapping) + latency_results = latency_results[list(latency_column_mapping.keys())].rename( + columns=latency_column_mapping + ) if not serving_results.empty: - serving_results = serving_results[list( - serving_column_mapping.keys())].rename( - columns=serving_column_mapping) + serving_results = serving_results[list(serving_column_mapping.keys())].rename( + columns=serving_column_mapping + ) if not throughput_results.empty: - throughput_results = throughput_results[list( - throughput_results_column_mapping.keys())].rename( - columns=throughput_results_column_mapping) + throughput_results = throughput_results[ + list(throughput_results_column_mapping.keys()) + ].rename(columns=throughput_results_column_mapping) - processed_results_json = results_to_json(latency_results, - throughput_results, - serving_results) + processed_results_json = results_to_json( + latency_results, throughput_results, serving_results + ) for df in [latency_results, serving_results, throughput_results]: if df.empty: @@ -184,38 +187,39 @@ def results_to_json(latency, throughput, serving): # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", # we want to turn it into "8xGPUTYPE" df["GPU"] = df["GPU"].apply( - lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}") + lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}" + ) # get markdown tables - latency_md_table = tabulate(latency_results, - headers='keys', - tablefmt='pipe', - showindex=False) - serving_md_table = tabulate(serving_results, - headers='keys', - tablefmt='pipe', - showindex=False) - throughput_md_table = tabulate(throughput_results, - headers='keys', - tablefmt='pipe', - showindex=False) + latency_md_table = tabulate( + latency_results, headers="keys", tablefmt="pipe", showindex=False + ) + serving_md_table = tabulate( + serving_results, headers="keys", tablefmt="pipe", showindex=False + ) + throughput_md_table = tabulate( + throughput_results, headers="keys", tablefmt="pipe", showindex=False + ) # document the result with open(results_folder / "benchmark_results.md", "w") as f: - - results = read_markdown("../.buildkite/nightly-benchmarks/" + - "performance-benchmarks-descriptions.md") + results = read_markdown( + "../.buildkite/nightly-benchmarks/" + + "performance-benchmarks-descriptions.md" + ) results = results.format( latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, serving_tests_markdown_table=serving_md_table, - benchmarking_results_in_json_string=processed_results_json) + benchmarking_results_in_json_string=processed_results_json, + ) f.write(results) # document benchmarking results in json with open(results_folder / "benchmark_results.json", "w") as f: - - results = latency_results.to_dict( - orient='records') + throughput_results.to_dict( - orient='records') + serving_results.to_dict(orient='records') + results = ( + latency_results.to_dict(orient="records") + + throughput_results.to_dict(orient="records") + + serving_results.to_dict(orient="records") + ) f.write(json.dumps(results)) diff --git a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py index 5e17b79d26a..8532ff7ef79 100644 --- a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py +++ b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse @@ -14,15 +15,12 @@ def main(model, cachedir): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Download and save Hugging Face tokenizer") - parser.add_argument("--model", - type=str, - required=True, - help="Name of the model") - parser.add_argument("--cachedir", - type=str, - required=True, - help="Directory to save the tokenizer") + description="Download and save Hugging Face tokenizer" + ) + parser.add_argument("--model", type=str, required=True, help="Name of the model") + parser.add_argument( + "--cachedir", type=str, required=True, help="Directory to save the tokenizer" + ) args = parser.parse_args() main(args.model, args.cachedir) diff --git a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py index 0ff95a0911b..053fd52c35a 100644 --- a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import json @@ -11,33 +12,33 @@ def parse_arguments(): parser = argparse.ArgumentParser( - description= - 'Parse command line arguments for summary-nightly-results script.') - parser.add_argument('--results-folder', - type=str, - required=True, - help='The folder where the results are stored.') - parser.add_argument('--description', - type=str, - required=True, - help='Description of the results.') + description="Parse command line arguments for summary-nightly-results script." + ) + parser.add_argument( + "--results-folder", + type=str, + required=True, + help="The folder where the results are stored.", + ) + parser.add_argument( + "--description", type=str, required=True, help="Description of the results." + ) args = parser.parse_args() return args def get_perf(df, method, model, metric): - means = [] for qps in [2, 4, 8, 16, "inf"]: - target = df['Test name'].str.contains(model) - target = target & df['Engine'].str.contains(method) - target = target & df['Test name'].str.contains("qps_" + str(qps)) + target = df["Test name"].str.contains(model) + target = target & df["Engine"].str.contains(method) + target = target & df["Test name"].str.contains("qps_" + str(qps)) filtered_df = df[target] if filtered_df.empty: - means.append(0.) + means.append(0.0) else: means.append(filtered_df[metric].values[0]) @@ -45,7 +46,6 @@ def get_perf(df, method, model, metric): def get_perf_w_std(df, method, model, metric): - if metric in ["TTFT", "ITL"]: mean = get_perf(df, method, model, "Mean " + metric + " (ms)") mean = mean.tolist() @@ -60,7 +60,8 @@ def get_perf_w_std(df, method, model, metric): else: assert metric == "Tput" mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf( - df, method, model, "Output Tput (tok/s)") + df, method, model, "Output Tput (tok/s)" + ) mean = mean.tolist() std = None @@ -80,18 +81,17 @@ def main(args): # generate markdown table df = pd.DataFrame.from_dict(results) - md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) + md_table = tabulate(df, headers="keys", tablefmt="pipe", showindex=False) with open(args.description) as f: description = f.read() - description = description.format( - nightly_results_benchmarking_table=md_table) + description = description.format(nightly_results_benchmarking_table=md_table) with open("nightly_results.md", "w") as f: f.write(description) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments() main(args) diff --git a/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py b/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py index e5f179a0f5b..ddea1d2b1b1 100644 --- a/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py +++ b/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from lmdeploy.serve.openai.api_client import APIClient diff --git a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py index 62ee5e10b50..fb3b9d5e34e 100644 --- a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py +++ b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import datetime import json @@ -34,10 +35,8 @@ } if __name__ == "__main__": - # collect results for test_file in results_folder.glob("*.json"): - with open(test_file) as f: raw_result = json.loads(f.read()) @@ -56,17 +55,16 @@ serving_results = pd.DataFrame.from_dict(serving_results) if not serving_results.empty: - serving_results = serving_results[list( - serving_column_mapping.keys())].rename( - columns=serving_column_mapping) + serving_results = serving_results[list(serving_column_mapping.keys())].rename( + columns=serving_column_mapping + ) - serving_md_table_with_headers = tabulate(serving_results, - headers='keys', - tablefmt='pipe', - showindex=False) + serving_md_table_with_headers = tabulate( + serving_results, headers="keys", tablefmt="pipe", showindex=False + ) # remove the first line of header - serving_md_table_lines = serving_md_table_with_headers.split('\n') - serving_md_table_without_header = '\n'.join(serving_md_table_lines[2:]) + serving_md_table_lines = serving_md_table_with_headers.split("\n") + serving_md_table_without_header = "\n".join(serving_md_table_lines[2:]) prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE") @@ -76,10 +74,9 @@ # document results with header. # for those who wants to reproduce our benchmark. f.write(serving_md_table_with_headers) - f.write('\n') + f.write("\n") # document benchmarking results in json with open(results_folder / f"{prefix}_nightly_results.json", "w") as f: - - results = serving_results.to_dict(orient='records') + results = serving_results.to_dict(orient="records") f.write(json.dumps(results)) diff --git a/.buildkite/pyproject.toml b/.buildkite/pyproject.toml new file mode 100644 index 00000000000..d5cad1c73c6 --- /dev/null +++ b/.buildkite/pyproject.toml @@ -0,0 +1,46 @@ +# This local pyproject file is part of the migration from yapf to ruff format. +# It uses the same core rules as the main pyproject.toml file, but with the +# following differences: +# - ruff line length is overridden to 88 +# - deprecated typing ignores (UP006, UP035) have been removed + +[tool.ruff] +line-length = 88 + +[tool.ruff.lint.per-file-ignores] +"vllm/third_party/**" = ["ALL"] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + +[tool.ruff.format] +docstring-code-format = true diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 4cc9c70a6ad..16b5ad0297f 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,5 +1,6 @@ steps: - label: "Build wheel - CUDA 12.8" + id: build-wheel-cuda-12-8 agents: queue: cpu_queue_postmerge commands: @@ -11,10 +12,11 @@ steps: DOCKER_BUILDKIT: "1" - label: "Build wheel - CUDA 12.6" + id: build-wheel-cuda-12-6 agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -28,10 +30,11 @@ steps: - label: "Build wheel - CUDA 11.8" # depends_on: block-build-cu118-wheel + id: build-wheel-cuda-11-8 agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -44,6 +47,7 @@ steps: - label: "Build release image" depends_on: block-release-image-build + id: build-release-image agents: queue: cpu_queue_postmerge commands: @@ -51,6 +55,18 @@ steps: - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" + - label: "Annotate release workflow" + depends_on: + - build-release-image + - build-wheel-cuda-12-8 + - build-wheel-cuda-12-6 + - build-wheel-cuda-11-8 + id: annotate-release-workflow + agents: + queue: cpu_queue_postmerge + commands: + - "bash .buildkite/scripts/annotate-release.sh" + - label: "Build and publish TPU release image" depends_on: ~ if: build.env("NIGHTLY") == "1" @@ -64,15 +80,16 @@ steps: - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" plugins: - docker-login#v3.0.0: - username: vllm + username: vllmbot password-env: DOCKERHUB_TOKEN env: DOCKER_BUILDKIT: "1" - input: "Provide Release version here" + id: input-release-version fields: - text: "What is the release version?" - key: "release-version" + key: release-version - block: "Build CPU release image" key: block-cpu-release-image-build diff --git a/.buildkite/scripts/annotate-release.sh b/.buildkite/scripts/annotate-release.sh new file mode 100755 index 00000000000..94e0ac2398f --- /dev/null +++ b/.buildkite/scripts/annotate-release.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +set -ex + +# Get release version and strip leading 'v' if present +RELEASE_VERSION=$(buildkite-agent meta-data get release-version | sed 's/^v//') + +if [ -z "$RELEASE_VERSION" ]; then + echo "Error: RELEASE_VERSION is empty. 'release-version' metadata might not be set or is invalid." + exit 1 +fi + +buildkite-agent annotate --style 'info' --context 'release-workflow' << EOF +To download the wheel: +\`\`\` +aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl . +aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu126/vllm-${RELEASE_VERSION}+cu126-cp38-abi3-manylinux1_x86_64.whl . +aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu118/vllm-${RELEASE_VERSION}+cu118-cp38-abi3-manylinux1_x86_64.whl . +\`\`\` + +To download and upload the image: + +\`\`\` +docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} +docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} vllm/vllm-openai +docker tag vllm/vllm-openai vllm/vllm-openai:latest +docker tag vllm/vllm-openai vllm/vllm-openai:v${RELEASE_VERSION} +docker push vllm/vllm-openai:latest +docker push vllm/vllm-openai:v${RELEASE_VERSION} +\`\`\` +EOF \ No newline at end of file diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index d29903bf497..6e9af1e721b 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -3,6 +3,9 @@ # This script runs test inside the corresponding ROCm docker container. set -o pipefail +# Export Python path +export PYTHONPATH=".." + # Print ROCm version echo "--- Confirming Clean Initial State" while true; do @@ -74,6 +77,27 @@ HF_MOUNT="/root/.cache/huggingface" commands=$@ echo "Commands:$commands" + +if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then + commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"} +fi + +if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then + commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} +fi + +if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then + commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"} +fi + +if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then + commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} +fi + +if [[ $commands == *"pytest -v -s lora"* ]]; then + commands=${commands//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"} +fi + #ignore certain kernels tests if [[ $commands == *" kernels/core"* ]]; then commands="${commands} \ @@ -161,6 +185,8 @@ fi PARALLEL_JOB_COUNT=8 +MYPYTHONPATH=".." + # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. if [[ $commands == *"--shard-id="* ]]; then # assign job count as the number of shards used @@ -181,6 +207,7 @@ if [[ $commands == *"--shard-id="* ]]; then -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ --name "${container_name}_${GPU}" \ "${image_name}" \ /bin/bash -c "${commands_gpu}" \ @@ -211,6 +238,7 @@ else -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ --name "${container_name}" \ "${image_name}" \ /bin/bash -c "${commands}" diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh index 5d863dd82e9..077bd991490 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh @@ -32,9 +32,12 @@ function cpu_tests() { set -e pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib pip install sentence-transformers datamodel_code_generator - pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach] - pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5] - pytest -v -s tests/models/encoder_decoder/language -m cpu_model" + pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2] + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m] + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it] + pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach] + pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]" } # All of CPU tests are expected to be finished less than 40 mins. diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 40f3df96065..61aa7df13b4 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -6,72 +6,67 @@ set -ex # allow to bind to different cores CORE_RANGE=${CORE_RANGE:-48-95} +OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95} NUMA_NODE=${NUMA_NODE:-1} +export CMAKE_BUILD_PARALLEL_LEVEL=32 + # Setup cleanup remove_docker_container() { set -e; - docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true; - docker image rm cpu-test-"$BUILDKITE_BUILD_NUMBER" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 || true; + docker rm -f cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"-avx2 || true; } trap remove_docker_container EXIT remove_docker_container # Try building the docker image -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$BUILDKITE_BUILD_NUMBER" --target vllm-test -f docker/Dockerfile.cpu . -numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 --target vllm-test -f docker/Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE" --target vllm-test -f docker/Dockerfile.cpu . +numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. -docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ - --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER" -docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ - --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 function cpu_tests() { set -e export NUMA_NODE=$2 - export BUILDKITE_BUILD_NUMBER=$3 # offline inference - docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c " set -e python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" # Run basic model test - docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -v -s tests/kernels/test_cache.py -m cpu_model - pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model - pytest -v -s tests/models/decoder_only/language -m cpu_model - pytest -v -s tests/models/embedding/language -m cpu_model - pytest -v -s tests/models/encoder_decoder/language -m cpu_model - pytest -v -s tests/models/decoder_only/audio_language -m cpu_model - pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" + pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model + pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model + pytest -v -s tests/models/language/generation -m cpu_model + pytest -v -s tests/models/language/pooling -m cpu_model + pytest -v -s tests/models/multimodal/generation --ignore=tests/models/multimodal/generation/test_mllama.py -m cpu_model" # Run compressed-tensor test - docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -s -v \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" # Run AWQ test - docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -s -v \ + VLLM_USE_V1=0 pytest -s -v \ tests/quantization/test_ipex_quant.py" # Run chunked-prefill and prefix-cache test - docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -s -v -k cpu_model \ tests/basic_correctness/test_chunked_prefill.py" # online serving - docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - export VLLM_CPU_KVCACHE_SPACE=10 - export VLLM_CPU_OMP_THREADS_BIND=$1 python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half & timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 python3 benchmarks/benchmark_serving.py \ @@ -83,7 +78,7 @@ function cpu_tests() { --tokenizer facebook/opt-125m" # Run multi-lora tests - docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -s -v \ tests/lora/test_qwen2vl.py" @@ -91,4 +86,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE $BUILDKITE_BUILD_NUMBER" +timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/scripts/hardware_ci/run-hpu-test.sh b/.buildkite/scripts/hardware_ci/run-hpu-test.sh index 95b6ac37f18..5efac3ddf46 100644 --- a/.buildkite/scripts/hardware_ci/run-hpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-hpu-test.sh @@ -10,15 +10,17 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu . # Setup cleanup # certain versions of HPU software stack have a bug that can # override the exit code of the script, so we need to use -# separate remove_docker_container and remove_docker_container_and_exit +# separate remove_docker_containers and remove_docker_containers_and_exit # functions, while other platforms only need one remove_docker_container # function. EXITCODE=1 -remove_docker_container() { docker rm -f hpu-test || true; } -remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; } -trap remove_docker_container_and_exit EXIT -remove_docker_container +remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; } +remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; } +trap remove_docker_containers_and_exit EXIT +remove_docker_containers # Run the image and launch offline inference docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m +docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2 + EXITCODE=$? diff --git a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh index ec6a080eb49..3d294ea5f8a 100644 --- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh +++ b/.buildkite/scripts/hardware_ci/run-neuron-test.sh @@ -11,13 +11,14 @@ container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" HF_CACHE="$(realpath ~)/huggingface" mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" +HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN) NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache" mkdir -p "${NEURON_COMPILE_CACHE_URL}" NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache" # Try building the docker image -aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com +aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws # prune old image and containers to save disk space, and only once a day # by using a timestamp file in tmp. @@ -47,8 +48,16 @@ trap remove_docker_container EXIT docker run --rm -it --device=/dev/neuron0 --network bridge \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "HF_TOKEN=${HF_TOKEN}" \ -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \ -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ --name "${container_name}" \ ${image_name} \ - /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys" + /bin/bash -c " + python3 /workspace/vllm/examples/offline_inference/neuron.py; + python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys; + for f in /workspace/vllm/tests/neuron/2_core/*.py; do + echo 'Running test file: '$f; + python3 -m pytest \$f -v --capture=tee-sys; + done + " \ No newline at end of file diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 21982b01b9c..a2a5c2a02cb 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -1,54 +1,185 @@ #!/bin/bash -set -xue +set -xu -# Build the docker image. -docker build -f docker/Dockerfile.tpu -t vllm-tpu . -# Set up cleanup. -remove_docker_container() { docker rm -f tpu-test || true; } +remove_docker_container() { + docker rm -f tpu-test || true; + docker rm -f vllm-tpu || true; +} + trap remove_docker_container EXIT + # Remove the container that might not be cleaned up in the previous run. remove_docker_container +# Build the docker image. +docker build -f docker/Dockerfile.tpu -t vllm-tpu . + +# Set up cleanup. +cleanup_docker() { + # Get Docker's root directory + docker_root=$(docker info -f '{{.DockerRootDir}}') + if [ -z "$docker_root" ]; then + echo "Failed to determine Docker root directory." + exit 1 + fi + echo "Docker root directory: $docker_root" + # Check disk usage of the filesystem where Docker's root directory is located + disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') + # Define the threshold + threshold=70 + if [ "$disk_usage" -gt "$threshold" ]; then + echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." + # Remove dangling images (those that are not tagged and not used by any container) + docker image prune -f + # Remove unused volumes / force the system prune for old images as well. + docker volume prune -f && docker system prune --force --filter "until=72h" --all + echo "Docker images and volumes cleanup completed." + else + echo "Disk usage is below $threshold%. No cleanup needed." + fi +} +cleanup_docker + # For HF_TOKEN. source /etc/environment -# Run a simple end-to-end example. + docker run --privileged --net host --shm-size=16G -it \ -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \ - vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \ - && python3 -m pip install pytest pytest-asyncio tpu-info \ - && python3 -m pip install lm_eval[api]==0.4.4 \ - && export VLLM_XLA_CACHE_PATH= \ - && export VLLM_USE_V1=1 \ - && export VLLM_XLA_CHECK_RECOMPILATION=1 \ - && echo HARDWARE \ - && tpu-info \ - && echo TEST_0 \ - && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \ - && echo TEST_1 \ - && pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \ - && echo TEST_2 \ - && pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \ - && echo TEST_3 \ - && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \ - && echo TEST_4 \ - && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ - && echo TEST_5 \ - && python3 /workspace/vllm/examples/offline_inference/tpu.py \ - && echo TEST_6 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \ - && echo TEST_7 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \ - && echo TEST_8 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \ - && echo TEST_9 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \ - && echo TEST_10 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \ - && echo TEST_11 \ - && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \ + vllm-tpu /bin/bash -c ' +set -e # Exit immediately if a command exits with a non-zero status. +set -u # Treat unset variables as an error. + +echo "--- Starting script inside Docker container ---" + +# Create results directory +RESULTS_DIR=$(mktemp -d) +# If mktemp fails, set -e will cause the script to exit. +echo "Results will be stored in: $RESULTS_DIR" + +# Install dependencies +echo "--- Installing Python dependencies ---" +python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ + && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ + && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 +echo "--- Python dependencies installed ---" +export VLLM_USE_V1=1 +export VLLM_XLA_CHECK_RECOMPILATION=1 +export VLLM_XLA_CACHE_PATH= +echo "Using VLLM V1" + +echo "--- Hardware Information ---" +tpu-info +echo "--- Starting Tests ---" +set +e +overall_script_exit_code=0 + +# --- Test Definitions --- +# If a test fails, this function will print logs and will not cause the main script to exit. +run_test() { + local test_num=$1 + local test_name=$2 + local test_command=$3 + local log_file="$RESULTS_DIR/test_${test_num}.log" + local actual_exit_code + + echo "--- TEST_$test_num: Running $test_name ---" + + # Execute the test command. + eval "$test_command" > >(tee -a "$log_file") 2> >(tee -a "$log_file" >&2) + actual_exit_code=$? + + echo "TEST_${test_num}_COMMAND_EXIT_CODE: $actual_exit_code" # This goes to main log + echo "TEST_${test_num}_COMMAND_EXIT_CODE: $actual_exit_code" >> "$log_file" # Also to per-test log + + if [ "$actual_exit_code" -ne 0 ]; then + echo "TEST_$test_num ($test_name) FAILED with exit code $actual_exit_code." >&2 + echo "--- Log for failed TEST_$test_num ($test_name) ---" >&2 + if [ -f "$log_file" ]; then + cat "$log_file" >&2 + else + echo "Log file $log_file not found for TEST_$test_num ($test_name)." >&2 + fi + echo "--- End of log for TEST_$test_num ($test_name) ---" >&2 + return "$actual_exit_code" # Return the failure code + else + echo "TEST_$test_num ($test_name) PASSED." + return 0 # Return success + fi +} + +# Helper function to call run_test and update the overall script exit code +run_and_track_test() { + local test_num_arg="$1" + local test_name_arg="$2" + local test_command_arg="$3" + + # Run the test + run_test "$test_num_arg" "$test_name_arg" "$test_command_arg" + local test_specific_exit_code=$? + + # If the test failed, set the overall script exit code to 1 + if [ "$test_specific_exit_code" -ne 0 ]; then + # No need for extra echo here, run_test already logged the failure. + overall_script_exit_code=1 + fi +} + +# --- Actual Test Execution --- +run_and_track_test 0 "test_perf.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_perf.py" +run_and_track_test 1 "test_compilation.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py" +run_and_track_test 2 "test_basic.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py" +run_and_track_test 3 "test_accuracy.py::test_lm_eval_accuracy_v1_engine" \ + "python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine" +run_and_track_test 4 "test_quantization_accuracy.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py" +run_and_track_test 5 "examples/offline_inference/tpu.py" \ + "python3 /workspace/vllm/examples/offline_inference/tpu.py" +run_and_track_test 6 "test_tpu_model_runner.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py" +run_and_track_test 7 "test_sampler.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" +run_and_track_test 8 "test_topk_topp_sampler.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" +run_and_track_test 9 "test_multimodal.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py" +run_and_track_test 10 "test_pallas.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" +run_and_track_test 11 "test_struct_output_generate.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\"" +run_and_track_test 12 "test_moe_pallas.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" +run_and_track_test 13 "test_lora.py" \ + "VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py" +run_and_track_test 14 "test_tpu_qkv_linear.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py" +run_and_track_test 15 "test_spmd_model_weight_loading.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" + +# After all tests have been attempted, exit with the overall status. +if [ "$overall_script_exit_code" -ne 0 ]; then + echo "--- One or more tests FAILED. Overall script exiting with failure code 1. ---" +else + echo "--- All tests have completed and PASSED. Overall script exiting with success code 0. ---" +fi +exit "$overall_script_exit_code" +' # IMPORTANT: This is the closing single quote for the bash -c "..." command. Ensure it is present and correct. +# Capture the exit code of the docker run command +DOCKER_RUN_EXIT_CODE=$? +# The trap will run for cleanup. +# Exit the main script with the Docker run command's exit code. +if [ "$DOCKER_RUN_EXIT_CODE" -ne 0 ]; then + echo "Docker run command failed with exit code $DOCKER_RUN_EXIT_CODE." + exit "$DOCKER_RUN_EXIT_CODE" +else + echo "Docker run command completed successfully." + exit 0 +fi # TODO: This test fails because it uses RANDOM_SEED sampling -# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ +# pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 75e3ef26409..037897e53db 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -75,3 +75,4 @@ else fi aws s3 cp "$wheel" "s3://vllm-wheels/$version/" +aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b3005b1b4b0..b739851cb90 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -32,16 +32,17 @@ steps: ##### fast check tests ##### - label: Documentation Build # 2min - working_dir: "/vllm-workspace/test_docs/docs" + mirror_hardwares: [amdexperimental] + working_dir: "/vllm-workspace/test_docs" fast_check: true no_gpu: True commands: - - pip install -r ../../requirements/docs.txt - - SPHINXOPTS=\"-W\" make html - # Check API reference (if it fails, you may have missing mock imports) - - grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html + - pip install -r ../requirements/docs.txt + # TODO: add `--strict` once warnings in docstrings are fixed + - mkdocs build - label: Async Engine, Inputs, Utils, Worker Test # 24min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/mq_llm_engine @@ -57,11 +58,13 @@ steps: - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py + - pytest -v -s test_outputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils - pytest -v -s worker # Worker - label: Python-only Installation Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - tests/standalone_tests/python_only_compile.sh - setup.py @@ -69,7 +72,7 @@ steps: - bash standalone_tests/python_only_compile.sh - label: Basic Correctness Test # 30min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] fast_check: true torch_nightly: true source_file_dependencies: @@ -86,6 +89,7 @@ steps: - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Chunked Prefill Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/basic_correctness/test_chunked_prefill @@ -94,7 +98,7 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] fast_check: true source_file_dependencies: - vllm/core @@ -104,10 +108,10 @@ steps: - pytest -v -s core - label: Entrypoints Test # 40min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true torch_nightly: true - #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/entrypoints/llm @@ -121,11 +125,12 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ - pytest -v -s entrypoints/test_chat_utils.py - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Distributed Tests (4 GPUs) # 10min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -133,32 +138,38 @@ steps: - vllm/core/ - tests/distributed/test_utils - tests/distributed/test_pynccl + - tests/distributed/test_events - tests/spec_decode/e2e/test_integration_dist_tp4 - tests/compile/test_basic_correctness - examples/offline_inference/rlhf.py - examples/offline_inference/rlhf_colocate.py - tests/examples/offline_inference/data_parallel.py - tests/v1/test_async_llm_dp.py + - tests/v1/engine/test_engine_core_client.py commands: # test with tp=2 and external_dp=2 - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with tp=2 and pp=2 + - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with internal dp - python3 ../examples/offline_inference/data_parallel.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py + - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py + - pytest -v -s distributed/test_events.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - pushd ../examples/offline_inference - - python3 rlhf.py - - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd - label: Metrics, Tracing Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 source_file_dependencies: - vllm/ @@ -172,7 +183,7 @@ steps: ##### 1 GPU test ##### - label: Regression Test # 5min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/test_regression @@ -182,7 +193,7 @@ steps: working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/engine @@ -190,13 +201,14 @@ steps: - tests/test_sequence - tests/test_config - tests/test_logger + - tests/test_vllm_port commands: - - pytest -v -s engine test_sequence.py test_config.py test_logger.py + - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py # OOM in the CI unless we run this separately - pytest -v -s tokenization - label: V1 Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/v1 @@ -209,10 +221,11 @@ steps: - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode + - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/test_serial_utils.py - - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py + - pytest -v -s v1/test_metrics_reader.py # TODO: accuracy does not match, whether setting # VLLM_USE_FLASHINFER_SAMPLER or not on H100. - pytest -v -s v1/e2e @@ -221,8 +234,8 @@ steps: - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine - label: Examples Test # 25min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/examples" - #mirror_hardwares: [amd] source_file_dependencies: - vllm/entrypoints - examples/ @@ -237,7 +250,7 @@ steps: - python3 offline_inference/vision_language.py --seed 0 - python3 offline_inference/vision_language_embedding.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - - VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/encoder_decoder.py - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py @@ -246,7 +259,7 @@ steps: - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/prefix_caching @@ -254,6 +267,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test # 36min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers - vllm/sampling_metadata.py @@ -263,18 +277,8 @@ steps: - pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers -- label: LogitsProcessor Test # 5min - mirror_hardwares: [amd] - source_file_dependencies: - - vllm/model_executor/layers - - vllm/model_executor/guided_decoding - - tests/test_logits_processor - - tests/model_executor/test_guided_processors - commands: - - pytest -v -s test_logits_processor.py - - pytest -v -s model_executor/test_guided_processors.py - - label: Speculative decoding tests # 40min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/spec_decode - tests/spec_decode @@ -285,7 +289,7 @@ steps: - pytest -v -s spec_decode/e2e/test_eagle_correctness.py - label: LoRA Test %N # 15min each - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/lora - tests/lora @@ -293,6 +297,7 @@ steps: parallelism: 4 - label: PyTorch Compilation Unit Tests + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -300,9 +305,12 @@ steps: commands: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py + - pytest -v -s compile/test_async_tp.py - label: PyTorch Fullgraph Smoke Test # 9min + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -312,8 +320,10 @@ steps: # these tests need to be separated, cannot combine - pytest -v -s compile/piecewise/test_simple.py - pytest -v -s compile/piecewise/test_toy_llama.py + - pytest -v -s compile/piecewise/test_full_cudagraph.py - label: PyTorch Fullgraph Test # 18min + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -322,7 +332,7 @@ steps: - pytest -v -s compile/test_full_graph.py - label: Kernels Core Operation Test - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/ - tests/kernels/core @@ -330,7 +340,7 @@ steps: - pytest -v -s kernels/core - label: Kernels Attention Test %N - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/attention/ - vllm/attention @@ -341,7 +351,7 @@ steps: parallelism: 2 - label: Kernels Quantization Test %N - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/quantization/ - vllm/model_executor/layers/quantization @@ -351,7 +361,7 @@ steps: parallelism: 2 - label: Kernels MoE Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/moe/ - tests/kernels/moe @@ -360,7 +370,7 @@ steps: - pytest -v -s kernels/moe - label: Kernels Mamba Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/mamba/ - tests/kernels/mamba @@ -368,25 +378,39 @@ steps: - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min - # mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader - tests/tensorizer_loader + - tests/entrypoints/openai/test_tensorizer_entrypoint.py commands: - apt-get update && apt-get install -y curl libsodium23 - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s tensorizer_loader + - pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py + +- label: Model Executor Test + mirror_hardwares: [amdexperimental, amdproduction] + soft_fail: true + source_file_dependencies: + - vllm/model_executor + - tests/model_executor + commands: + - apt-get update && apt-get install -y curl libsodium23 + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s model_executor - label: Benchmarks # 9min + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/.buildkite" - mirror_hardwares: [amd] source_file_dependencies: - benchmarks/ commands: - bash scripts/run-benchmarks.sh - label: Benchmarks CLI Test # 10min + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/benchmarks/ @@ -394,23 +418,29 @@ steps: - pytest -v -s benchmarks/ - label: Quantization Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization - tests/quantization commands: + # temporary install here since we need nightly, will move to requirements/test.in + # after torchao 0.12 release + - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-small.txt -t 1 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - label: OpenAI API correctness + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - vllm/entrypoints/openai/ @@ -419,6 +449,7 @@ steps: - pytest -s entrypoints/openai/correctness/ - label: Encoder Decoder tests # 5min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/encoder_decoder @@ -426,8 +457,8 @@ steps: - pytest -v -s encoder_decoder - label: OpenAI-Compatible Tool Use # 20 min + mirror_hardwares: [amdexperimental] fast_check: false - #mirror_hardwares: [ amd ] source_file_dependencies: - vllm/ - tests/tool_use @@ -439,6 +470,7 @@ steps: ##### models test ##### - label: Basic Models Test # 24min + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -448,43 +480,55 @@ steps: - pytest -v -s models/test_registry.py - pytest -v -s models/test_utils.py - pytest -v -s models/test_vision.py - # V1 Test: https://github.com/vllm-project/vllm/issues/14531 - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2' - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4' - - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2' + - pytest -v -s models/test_initialization.py - label: Language Models Test (Standard) - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] + torch_nightly: true source_file_dependencies: - vllm/ - tests/models/language commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + - pip freeze | grep -E 'torch' - pytest -v -s models/language -m core_model -- label: Language Models Test (Extended) +- label: Language Models Test (Extended Generation) # 1hr20min + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ - - tests/models/language + - tests/models/language/generation commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' - - pytest -v -s models/language -m 'not core_model' + - pytest -v -s models/language/generation -m 'not core_model' + +- label: Language Models Test (Extended Pooling) # 36min + mirror_hardwares: [amdexperimental] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/language/pooling + commands: + - pytest -v -s models/language/pooling -m 'not core_model' - label: Multi-Modal Models Test (Standard) - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] + torch_nightly: true source_file_dependencies: - vllm/ - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pip freeze | grep -E 'torch' - pytest -v -s models/multimodal/processing - pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -494,6 +538,7 @@ steps: - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' - label: Multi-Modal Models Test (Extended) 2 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -503,6 +548,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' - label: Multi-Modal Models Test (Extended) 3 + mirror_hardwares: [amdexperimental, amdproduction] optional: true source_file_dependencies: - vllm/ @@ -512,7 +558,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' - label: Quantized Models Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/model_executor/layers/quantization - tests/models/quantization @@ -521,7 +567,7 @@ steps: # This test is used only in PR development phase to test individual models and should never run on main - label: Custom Models Test - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] optional: true commands: - echo 'Testing custom models...' @@ -533,7 +579,7 @@ steps: ##### multi gpus test ##### - label: Distributed Comm Ops Test # 7min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -544,6 +590,7 @@ steps: - pytest -v -s distributed/test_shm_broadcast.py - label: 2 Node Tests (4 GPUs in total) # 16min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 num_nodes: 2 @@ -562,7 +609,7 @@ steps: - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - label: Distributed Tests (2 GPUs) # 40min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -577,9 +624,11 @@ steps: - vllm/worker/model_runner.py - entrypoints/llm/test_collective_rpc.py - tests/v1/test_async_llm_dp.py + - tests/v1/entrypoints/openai/test_multi_api_servers.py - vllm/v1/engine/ commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py + - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py @@ -599,13 +648,14 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - label: Plugin Tests (2 GPUs) # 40min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: - vllm/plugins/ - tests/plugins/ commands: - # begin platform plugin tests, all the code in-between runs on dummy platform + # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform - pip install -e ./plugins/vllm_add_dummy_platform - pytest -v -s plugins_tests/test_platform_plugins.py - pip uninstall vllm_add_dummy_platform -y @@ -616,8 +666,10 @@ steps: - pytest -v -s distributed/test_distributed_oot.py - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins - label: Multi-step Tests (4 GPUs) # 36min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -638,6 +690,7 @@ steps: - pytest -v -s multi_step/test_correctness_llm.py - label: Pipeline Parallelism Test # 45min + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -651,6 +704,7 @@ steps: - pytest -v -s distributed/test_pipeline_parallel.py - label: LoRA TP Test (Distributed) + mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 4 source_file_dependencies: - vllm/lora @@ -666,6 +720,7 @@ steps: - label: Weight Loading Multiple GPU Test # 33min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -675,6 +730,7 @@ steps: - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt - label: Weight Loading Multiple GPU Test - Large Models # optional + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 gpu: a100 @@ -713,4 +769,4 @@ steps: - vllm/model_executor/layers/quantization commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-large.txt -t 4 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml index 637d2dd1145..f05be2ba870 100644 --- a/.github/ISSUE_TEMPLATE/400-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml @@ -75,20 +75,20 @@ body: ``` ``` - The error message you got, with the full traceback. + The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present. ``` validations: required: true - type: markdown attributes: - value: > - ⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the models' output: + value: | + ⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the model's output: - Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc). - If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect. - Thanks for contributing 🎉! + Thanks for reporting 🙏! - type: checkboxes id: askllm attributes: diff --git a/.github/ISSUE_TEMPLATE/450-ci-failure.yml b/.github/ISSUE_TEMPLATE/450-ci-failure.yml new file mode 100644 index 00000000000..7af0e0673a2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/450-ci-failure.yml @@ -0,0 +1,69 @@ +name: 🧪 CI failure report +description: Report a failing test. +title: "[CI Failure]: " +labels: ["ci-failure"] + +body: +- type: markdown + attributes: + value: > + #### Include the name of the failing Buildkite step and test file in the title. +- type: input + attributes: + label: Name of failing test + description: | + Paste in the fully-qualified name of the failing test from the logs. + placeholder: | + `path/to/test_file.py::test_name[params]` + validations: + required: true +- type: checkboxes + attributes: + label: Basic information + description: Select all items that apply to the failing test. + options: + - label: Flaky test + - label: Can reproduce locally + - label: Caused by external libraries (e.g. bug in `transformers`) +- type: textarea + attributes: + label: 🧪 Describe the failing test + description: | + Please provide a clear and concise description of the failing test. + placeholder: | + A clear and concise description of the failing test. + + ``` + The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present. + ``` + validations: + required: true +- type: textarea + attributes: + label: 📝 History of failing test + description: | + Since when did the test start to fail? + You can look up its history via [Buildkite Test Suites](https://buildkite.com/organizations/vllm/analytics/suites/ci-1/tests?branch=main). + + If you have time, identify the PR that caused the test to fail on main. You can do so via the following methods: + + - Use Buildkite Test Suites to find the PR where the test failure first occurred, and reproduce the failure locally. + + - Run [`git bisect`](https://git-scm.com/docs/git-bisect) locally. + + - Manually unblock Buildkite steps for suspected PRs on main and check the results. (authorized users only) + placeholder: | + Approximate timeline and/or problematic PRs + + A link to the Buildkite analytics of the failing test (if available) + validations: + required: true +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. Usually, this includes those who worked on the PR that failed the test. +- type: markdown + attributes: + value: > + Thanks for reporting 🙏! diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 7042e81a84d..c1d1e07bf62 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,6 +1,15 @@ -FILL IN THE PR DESCRIPTION HERE +## Essential Elements of an Effective PR Description Checklist +- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)". +- [ ] The test plan, such as providing test command. +- [ ] The test results, such as pasting the results comparison before and after, or e2e results -FIX #xxxx (*link existing issues this PR will resolve*) +PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED. + +## Purpose + +## Test Plan + +## Test Result -**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) +**BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) diff --git a/.github/mergify.yml b/.github/mergify.yml index 15fa3660a87..e595060c325 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -58,7 +58,7 @@ pull_request_rules: - files~=^benchmarks/structured_schemas/ - files=benchmarks/benchmark_serving_structured_output.py - files=benchmarks/run_structured_output_benchmark.sh - - files=docs/source/features/structured_outputs.md + - files=docs/features/structured_outputs.md - files=examples/offline_inference/structured_outputs.py - files=examples/online_serving/openai_chat_completion_structured_outputs.py - files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py @@ -135,9 +135,7 @@ pull_request_rules: - files~=^tests/entrypoints/openai/tool_parsers/ - files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py - files~=^vllm/entrypoints/openai/tool_parsers/ - - files=docs/source/features/tool_calling.md - - files=docs/source/getting_started/examples/openai_chat_completion_client_with_tools.md - - files=docs/source/getting_started/examples/chat_with_tools.md + - files=docs/features/tool_calling.md - files~=^examples/tool_chat_* - files=examples/offline_inference/chat_with_tools.py - files=examples/online_serving/openai_chat_completion_client_with_tools_required.py @@ -163,6 +161,17 @@ pull_request_rules: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork +- name: assign reviewer for tensorizer changes + conditions: + - files~=^vllm/model_executor/model_loader/tensorizer.py + - files~=^vllm/model_executor/model_loader/tensorizer_loader.py + - files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py + - files~=^tests/tensorizer_loader/ + actions: + assign: + users: + - "sangstar" + - name: remove 'needs-rebase' label when conflict is resolved conditions: - -conflict diff --git a/.github/scripts/cleanup_pr_body.sh b/.github/scripts/cleanup_pr_body.sh index 3246c6f9bc4..8d65936fba1 100755 --- a/.github/scripts/cleanup_pr_body.sh +++ b/.github/scripts/cleanup_pr_body.sh @@ -26,7 +26,7 @@ sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}" # Remove HTML
section that includes text of "PR Checklist (Click to Expand)" python3 - < - - vLLM + + vLLM

@@ -18,18 +18,20 @@ Easy, fast, and cheap LLM serving for everyone > For Intel Gaudi specific setup instructions and examples, please refer [Intel® Gaudi® README](https://github.com/HabanaAI/vllm-fork/blob/habana_main/README_GAUDI.md). For jupyter notebook based quickstart tutorials refer [Getting Started with vLLM](https://github.com/HabanaAI/Gaudi-tutorials/blob/main/PyTorch/vLLM_Tutorials/Getting_Started_with_vLLM/Getting_Started_with_vLLM.ipynb) and [Understanding vLLM on Gaudi](https://github.com/HabanaAI/Gaudi-tutorials/blob/main/PyTorch/vLLM_Tutorials/Understanding_vLLM_on_Gaudi/Understanding_vLLM_on_Gaudi.ipynb). *Latest News* 🔥 +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). +- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). +- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). + +
+Previous News + - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). - [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0). - [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted. -- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). - [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing). - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! - -
-Previous News - - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! - [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! @@ -59,8 +61,8 @@ vLLM is fast with: - Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) - Continuous batching of incoming requests - Fast model execution with CUDA/HIP graph -- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8. -- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer. +- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8 +- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer - Speculative decoding - Chunked prefill @@ -73,14 +75,14 @@ vLLM is flexible and easy to use with: - Tensor parallelism and pipeline parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. +- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron - Prefix caching support -- Multi-lora support +- Multi-LoRA support vLLM seamlessly supports most popular open-source models on HuggingFace, including: - Transformer-like LLMs (e.g., Llama) - Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3) -- Embedding Models (e.g. E5-Mistral) +- Embedding Models (e.g., E5-Mistral) - Multi-modal LLMs (e.g., LLaVA) Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). @@ -101,14 +103,14 @@ Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more. ## Contributing We welcome and value any contributions and collaborations. -Please check out [Contributing to vLLM](https://docs.vllm.ai/en/stable/contributing/overview.html) for how to get involved. +Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved. ## Sponsors vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! - + Cash Donations: - a16z - Dropbox @@ -163,4 +165,4 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs ## Media Kit -- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). +- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit) diff --git a/SECURITY.md b/SECURITY.md index 47196a1f122..6053cfb41f3 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -8,4 +8,6 @@ Please report security issues privately using [the vulnerability submission form --- +Please see the [Security Guide in the vLLM documentation](https://docs.vllm.ai/en/latest/usage/security.html) for more information on vLLM's security assumptions and recommendations. + Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. diff --git a/benchmarks/README.md b/benchmarks/README.md index 4a8ab895e18..6f9fbb91cbd 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -64,6 +64,12 @@ become available. ✅ lmms-lab/LLaVA-OneVision-Data, Aeala/ShareGPT_Vicuna_unfiltered + + Custom + ✅ + ✅ + Local file: data.jsonl + @@ -124,6 +130,38 @@ P99 ITL (ms): 8.39 ================================================== ``` +### Custom Dataset +If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl + +``` +{"prompt": "What is the capital of India?"} +{"prompt": "What is the capital of Iran?"} +{"prompt": "What is the capital of China?"} +``` + +```bash +# start server +VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests +``` + +```bash +# run benchmarking script +python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detailed \ + --backend vllm \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --endpoint /v1/completions \ + --dataset-name custom \ + --dataset-path \ + --custom-skip-chat-template \ + --num-prompts 80 \ + --max-concurrency 1 \ + --temperature=0.3 \ + --top-p=0.75 \ + --result-dir "./log/" +``` + +You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. + ### VisionArena Benchmark for Vision Language Models ```bash @@ -146,10 +184,9 @@ python3 vllm/benchmarks/benchmark_serving.py \ ``` bash VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ - --speculative-model "[ngram]" \ - --ngram_prompt_lookup_min 2 \ - --ngram-prompt-lookup-max 5 \ - --num_speculative_tokens 5 + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' ``` ``` bash @@ -204,6 +241,16 @@ python3 vllm/benchmarks/benchmark_serving.py \ --seed 42 ``` +**`philschmid/mt-bench`** + +``` bash +python3 vllm/benchmarks/benchmark_serving.py \ + --model Qwen/QwQ-32B \ + --dataset-name hf \ + --dataset-path philschmid/mt-bench \ + --num-prompts 80 +``` + ### Running With Sampling Parameters When using OpenAI-compatible backends such as `vllm`, optional sampling @@ -274,10 +321,9 @@ python3 vllm/benchmarks/benchmark_throughput.py \ --output-len=100 \ --num-prompts=2048 \ --async-engine \ - --speculative-model="[ngram]" \ - --ngram_prompt_lookup_min=2 \ - --ngram-prompt-lookup-max=5 \ - --num_speculative_tokens=5 + --speculative-config $'{"method": "ngram", + "num_speculative_tokens": 5, "prompt_lookup_max": 5, + "prompt_lookup_min": 2}' ``` ``` diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index e6a67fda682..ddb38e304cd 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io import json @@ -12,8 +13,7 @@ import aiohttp import huggingface_hub.constants from tqdm.asyncio import tqdm -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast # NOTE(simon): do not import vLLM here so the benchmark script # can run without vLLM installed. @@ -43,8 +43,7 @@ class RequestFuncOutput: latency: float = 0.0 output_tokens: int = 0 ttft: float = 0.0 # Time to first token - itl: list[float] = field( - default_factory=list) # list of inter-token latencies + itl: list[float] = field(default_factory=list) # list of inter-token latencies tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" @@ -57,8 +56,9 @@ async def async_request_tgi( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: params = { "max_new_tokens": request_func_input.output_len, "do_sample": True, @@ -105,8 +105,7 @@ async def async_request_tgi( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp @@ -133,8 +132,9 @@ async def async_request_trt_llm( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { "accumulate_tokens": True, "text_input": request_func_input.prompt, @@ -159,8 +159,7 @@ async def async_request_trt_llm( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data:") + chunk = chunk_bytes.decode("utf-8").removeprefix("data:") data = json.loads(chunk) output.generated_text += data["text_output"] @@ -172,8 +171,7 @@ async def async_request_trt_llm( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp @@ -197,9 +195,14 @@ async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + api_url = request_func_input.api_url + assert api_url.endswith(("completions", "profile")), ( + "OpenAI Completions API URL must end with 'completions' or 'profile'." + ) + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { "model": request_func_input.model, "prompt": request_func_input.prompt, @@ -207,6 +210,8 @@ async def async_request_deepspeed_mii( "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. "top_p": 1.0, } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -217,19 +222,21 @@ async def async_request_deepspeed_mii( st = time.perf_counter() try: - async with session.post(url=request_func_input.api_url, - json=payload) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: parsed_resp = await response.json() output.latency = time.perf_counter() - st if "choices" in parsed_resp: - output.generated_text = parsed_resp["choices"][0][ - "text"] + output.generated_text = parsed_resp["choices"][0]["text"] elif "text" in parsed_resp: output.generated_text = parsed_resp["text"][0] else: - output.error = ("Unexpected response format: " - "neither 'choices' nor 'text' found") + output.error = ( + "Unexpected response format: " + "neither 'choices' nor 'text' found" + ) output.success = False output.success = True else: @@ -250,15 +257,17 @@ async def async_request_openai_completions( pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith( - ("completions", "profile") - ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + assert api_url.endswith(("completions", "profile")), ( + "OpenAI Completions API URL must end with 'completions' or 'profile'." + ) - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "repetition_penalty": 1.0, @@ -273,9 +282,7 @@ async def async_request_openai_completions( payload["ignore_eos"] = request_func_input.ignore_eos if request_func_input.extra_body: payload.update(request_func_input.extra_body) - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -284,8 +291,9 @@ async def async_request_openai_completions( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: first_chunk_received = False async for chunk_bytes in response.content: @@ -293,8 +301,7 @@ async def async_request_openai_completions( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": data = json.loads(chunk) @@ -314,21 +321,20 @@ async def async_request_openai_completions( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp generated_text += text or "" - elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + if usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") if first_chunk_received: output.success = True else: output.success = False output.error = ( "Never received a valid chunk to calculate TTFT." - "This response will be marked as failed!") + "This response will be marked as failed!" + ) output.generated_text = generated_text output.latency = most_recent_timestamp - st else: @@ -349,23 +355,22 @@ async def async_request_openai_chat_completions( pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith( - ("chat/completions", "profile") - ), "OpenAI Chat Completions API URL must end with 'chat/completions'." + assert api_url.endswith(("chat/completions", "profile")), ( + "OpenAI Chat Completions API URL must end with 'chat/completions'." + ) - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: content.append(request_func_input.multi_modal_content) payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "messages": [ - { - "role": "user", - "content": content - }, + {"role": "user", "content": content}, ], "temperature": 0.0, "max_completion_tokens": request_func_input.output_len, @@ -391,16 +396,16 @@ async def async_request_openai_chat_completions( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) @@ -414,13 +419,11 @@ async def async_request_openai_chat_completions( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) generated_text += content or "" elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + output.output_tokens = usage.get("completion_tokens") most_recent_timestamp = timestamp @@ -446,25 +449,28 @@ async def async_request_openai_audio( ) -> RequestFuncOutput: # Lazy import without PlaceholderModule to avoid vllm dep. import soundfile + api_url = request_func_input.api_url - assert api_url.endswith( - ("transcriptions", "translations" - )), "OpenAI Chat Completions API URL must end with 'transcriptions' " + assert api_url.endswith(("transcriptions", "translations")), ( + "OpenAI Chat Completions API URL must end with 'transcriptions' " + ) "or `translations`." - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: content = [{"type": "text", "text": request_func_input.prompt}] payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "temperature": 0.0, "max_completion_tokens": request_func_input.output_len, "stream": True, "language": "en", # Flattened due to multipart/form-data "stream_include_usage": True, - "stream_continuous_usage_stats": True + "stream_continuous_usage_stats": True, } if request_func_input.extra_body: payload.update(request_func_input.extra_body) @@ -479,9 +485,9 @@ def to_bytes(y, sr): buffer.seek(0) return buffer - with to_bytes(*request_func_input.multi_modal_content['audio']) as f: + with to_bytes(*request_func_input.multi_modal_content["audio"]) as f: form = aiohttp.FormData() - form.add_field('file', f, content_type='audio/wav') + form.add_field("file", f, content_type="audio/wav") for key, value in payload.items(): form.add_field(key, str(value)) @@ -493,24 +499,22 @@ def to_bytes(y, sr): st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, - data=form, - headers=headers) as response: + async with session.post( + url=api_url, data=form, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) if choices := data.get("choices"): - content = choices[0]["delta"].get( - "content") + content = choices[0]["delta"].get("content") # First token if ttft == 0.0: ttft = timestamp - st @@ -519,12 +523,14 @@ def to_bytes(y, sr): # Decoding phase else: output.itl.append( - timestamp - most_recent_timestamp) + timestamp - most_recent_timestamp + ) generated_text += content or "" elif usage := data.get("usage"): output.output_tokens = usage.get( - "completion_tokens") + "completion_tokens" + ) most_recent_timestamp = timestamp @@ -545,7 +551,7 @@ def to_bytes(y, sr): def get_model(pretrained_model_name_or_path: str) -> str: - if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': + if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true": from modelscope import snapshot_download from vllm.model_executor.model_loader.weight_utils import get_lock @@ -556,7 +562,8 @@ def get_model(pretrained_model_name_or_path: str) -> str: model_path = snapshot_download( model_id=pretrained_model_name_or_path, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) return model_path return pretrained_model_name_or_path @@ -569,23 +576,23 @@ def get_tokenizer( **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: if pretrained_model_name_or_path is not None and not os.path.exists( - pretrained_model_name_or_path): - pretrained_model_name_or_path = get_model( - pretrained_model_name_or_path) + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) if tokenizer_mode == "slow": if kwargs.get("use_fast", False): - raise ValueError( - "Cannot use the fast tokenizer in slow tokenizer mode.") + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False if tokenizer_mode == "mistral": try: from vllm.transformers_utils.tokenizer import MistralTokenizer except ImportError as e: - raise ImportError("MistralTokenizer requires vllm package.\n" - "Please install it with `pip install vllm` " - "to use mistral tokenizer mode.") from e - return MistralTokenizer.from_pretrained( - str(pretrained_model_name_or_path)) + raise ImportError( + "MistralTokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use mistral tokenizer mode." + ) from e + return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path)) else: return AutoTokenizer.from_pretrained( pretrained_model_name_or_path, @@ -605,10 +612,11 @@ def get_tokenizer( "tensorrt-llm": async_request_trt_llm, "scalellm": async_request_openai_completions, "sglang": async_request_openai_completions, + "llama.cpp": async_request_openai_completions, } OPENAI_COMPATIBLE_BACKENDS = [ - k for k, v in ASYNC_REQUEST_FUNCS.items() - if v in (async_request_openai_completions, - async_request_openai_chat_completions) + k + for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, async_request_openai_chat_completions) ] diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index b81c2f8192d..5d2a26cd443 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ This module defines a framework for sampling benchmark requests from various datasets. Each dataset subclass of BenchmarkDataset must implement sample @@ -9,9 +10,6 @@ - BurstGPT - HuggingFace - VisionArena - -TODO: Implement CustomDataset to parse a JSON file and convert its contents into -SampleRequest instances, similar to the approach used in ShareGPT. """ import base64 @@ -35,6 +33,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.image import convert_image_mode from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer logger = logging.getLogger(__name__) @@ -82,14 +81,12 @@ def __init__( self.dataset_path = dataset_path # Set the random seed, ensuring that a None value is replaced with the # default seed. - self.random_seed = (random_seed - if random_seed is not None else self.DEFAULT_SEED) + self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED self.data = None def apply_multimodal_chat_transformation( - self, - prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + self, prompt: str, mm_content: Optional[MultiModalDataDict] = None + ) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -111,8 +108,7 @@ def load_data(self) -> None: NotImplementedError: If a subclass does not implement this method. """ # TODO (jenniferzhao): add support for downloading data - raise NotImplementedError( - "load_data must be implemented in subclasses.") + raise NotImplementedError("load_data must be implemented in subclasses.") def get_random_lora_request( self, @@ -158,8 +154,9 @@ def get_random_lora_request( return lora_request, lora_tokenizer_cache[lora_id] or tokenizer @abstractmethod - def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int) -> list[SampleRequest]: + def sample( + self, tokenizer: PreTrainedTokenizerBase, num_requests: int + ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -177,8 +174,9 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, """ raise NotImplementedError("sample must be implemented in subclasses.") - def maybe_oversample_requests(self, requests: list[SampleRequest], - num_requests: int) -> None: + def maybe_oversample_requests( + self, requests: list[SampleRequest], num_requests: int + ) -> None: """ Oversamples the list of requests if its size is less than the desired number. @@ -189,11 +187,9 @@ def maybe_oversample_requests(self, requests: list[SampleRequest], """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, - k=num_requests - len(requests)) + additional = random.choices(requests, k=num_requests - len(requests)) requests.extend(additional) - logger.info("Oversampled requests to reach %d total samples.", - num_requests) + logger.info("Oversampled requests to reach %d total samples.", num_requests) # ----------------------------------------------------------------------------- @@ -218,14 +214,14 @@ def is_valid_sequence( """ # Check for invalid conditions prompt_too_short = prompt_len < min_len - output_too_short = (not skip_min_output_len_check) and (output_len - < min_len) + output_too_short = (not skip_min_output_len_check) and (output_len < min_len) prompt_too_long = prompt_len > max_prompt_len combined_too_long = (prompt_len + output_len) > max_total_len # Return True if none of the invalid conditions are met - return not (prompt_too_short or output_too_short or prompt_too_long - or combined_too_long) + return not ( + prompt_too_short or output_too_short or prompt_too_long or combined_too_long + ) @cache @@ -257,28 +253,28 @@ def process_image(image: Any) -> Mapping[str, Any]: Raises: ValueError: If the input is not a supported type. """ - if isinstance(image, dict) and 'bytes' in image: - image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, dict) and "bytes" in image: + image = Image.open(BytesIO(image["bytes"])) if isinstance(image, Image.Image): - image = image.convert("RGB") + image = convert_image_mode(image, "RGB") with io.BytesIO() as image_data: image.save(image_data, format="JPEG") - image_base64 = base64.b64encode( - image_data.getvalue()).decode("utf-8") + image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") return { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - }, + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, } if isinstance(image, str): - image_url = (image if image.startswith( - ("http://", "file://")) else f"file://{image}") + image_url = ( + image if image.startswith(("http://", "file://")) else f"file://{image}" + ) return {"type": "image_url", "image_url": {"url": image_url}} - raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" - " or str or dictionary with raw image bytes.") + raise ValueError( + f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes." + ) # ----------------------------------------------------------------------------- @@ -318,8 +314,11 @@ def sample( num_special_tokens = tokenizer.num_special_tokens_to_add() real_input_len = input_len - num_special_tokens - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + prefix_token_ids = ( + np.random.randint(0, vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) # New sampling logic: [X * (1 - b), X * (1 + b)] input_low = int(real_input_len * (1 - range_ratio)) @@ -329,21 +328,17 @@ def sample( # Add logging for debugging logger.info("Sampling input_len from [%s, %s]", input_low, input_high) - logger.info("Sampling output_len from [%s, %s]", output_low, - output_high) - - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, - size=num_requests) + logger.info("Sampling output_len from [%s, %s]", output_low, output_high) + + input_lens = np.random.randint(input_low, input_high + 1, size=num_requests) + output_lens = np.random.randint(output_low, output_high + 1, size=num_requests) offsets = np.random.randint(0, vocab_size, size=num_requests) requests = [] for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() + inner_seq = ( + (offsets[i] + i + np.arange(input_lens[i])) % vocab_size + ).tolist() token_sequence = prefix_token_ids + inner_seq prompt = tokenizer.decode(token_sequence) # After decoding the prompt we have to encode and decode it again. @@ -354,8 +349,9 @@ def sample( # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] # To avoid uncontrolled change of the prompt length, # the encoded sequence is truncated before being decode again. - re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:input_lens[i]] + re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ + : input_lens[i] + ] prompt = tokenizer.decode(re_encoded_sequence) total_input_len = prefix_len + int(input_lens[i]) requests.append( @@ -363,7 +359,8 @@ def sample( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), - )) + ) + ) return requests @@ -390,7 +387,8 @@ def load_data(self) -> None: self.data = json.load(f) # Filter entries with at least two conversation turns. self.data = [ - entry for entry in self.data + entry + for entry in self.data if "conversations" in entry and len(entry["conversations"]) >= 2 ] random.seed(self.random_seed) @@ -416,31 +414,123 @@ def sample( ) lora_request, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path + ) prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids prompt_len = len(prompt_ids) - new_output_len = (len(completion_ids) - if output_len is None else output_len) - if not is_valid_sequence(prompt_len, - new_output_len, - skip_min_output_len_check=output_len - is not None): + new_output_len = len(completion_ids) if output_len is None else output_len + if not is_valid_sequence( + prompt_len, + new_output_len, + skip_min_output_len_check=output_len is not None, + ): continue if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation( - prompt, None) + prompt = self.apply_multimodal_chat_transformation(prompt, None) samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=new_output_len, lora_request=lora_request, - )) + ) + ) self.maybe_oversample_requests(samples, num_requests) return samples +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. + # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. + # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset." + ) + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + # ----------------------------------------------------------------------------- # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- @@ -482,20 +572,20 @@ def sample( ) -> list: # Calculate average token length for a poem line. tokenized_lines = [tokenizer(line).input_ids for line in self.data] - avg_len = sum(len(tokens) - for tokens in tokenized_lines) / len(tokenized_lines) + avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines) # Build the base prompt. base_prompt = "Pick as many lines as you can from these poem lines:\n" base_msg = [{"role": "user", "content": base_prompt}] - base_fmt = tokenizer.apply_chat_template(base_msg, - add_generation_prompt=True, - tokenize=False) + base_fmt = tokenizer.apply_chat_template( + base_msg, add_generation_prompt=True, tokenize=False + ) base_offset = len(tokenizer(base_fmt).input_ids) if input_len <= base_offset: raise ValueError( f"'input_len' must be higher than the base prompt length " - f"({base_offset}).") + f"({base_offset})." + ) # Determine how many poem lines to use. num_input_lines = round((input_len - base_offset) / avg_len) @@ -504,21 +594,23 @@ def sample( samples = [] while len(samples) < num_requests: - extra_lines = random.choices(self.data, - k=num_input_lines - num_prefix_lines) + extra_lines = random.choices( + self.data, k=num_input_lines - num_prefix_lines + ) prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" msg = [{"role": "user", "content": prompt}] prompt_formatted = tokenizer.apply_chat_template( - msg, add_generation_prompt=True, tokenize=False) + msg, add_generation_prompt=True, tokenize=False + ) prompt_len = len(tokenizer(prompt_formatted).input_ids) if prompt_len <= input_len: samples.append( SampleRequest( - prompt=prompt_formatted - if return_prompt_formatted else prompt, + prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, - )) + ) + ) return samples @@ -538,7 +630,9 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.load_data() - def load_data(self, ): + def load_data( + self, + ): if self.dataset_path is None: raise ValueError("dataset_path must be provided for loading data.") @@ -552,8 +646,7 @@ def load_data(self, ): def _sample_loaded_data(self, num_requests: int) -> list: if num_requests <= len(self.data): - data = self.data.sample(n=num_requests, - random_state=self.random_seed) + data = self.data.sample(n=num_requests, random_state=self.random_seed) else: data = self.data.sample( n=num_requests, @@ -577,7 +670,8 @@ def sample( input_len = int(data[i][2]) output_len = int(data[i][3]) lora_req, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path + ) vocab_size = tokenizer.vocab_size # Generate a synthetic prompt: a list of token IDs computed as (i + # j) modulo vocab_size. @@ -589,7 +683,8 @@ def sample( prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, - )) + ) + ) return samples @@ -632,20 +727,23 @@ def load_data(self) -> None: class ConversationDataset(HuggingFaceDataset): """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { - 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + "lmms-lab/LLaVA-OneVision-Data", + "Aeala/ShareGPT_Vicuna_unfiltered", } IS_MULTIMODAL = True - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: # Filter examples with at least 2 conversations - filtered_data = self.data.filter( - lambda x: len(x["conversations"]) >= 2) + filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] dynamic_output = output_len is None @@ -661,24 +759,22 @@ def sample(self, completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence( - prompt_len, completion_len): + if dynamic_output and not is_valid_sequence(prompt_len, completion_len): continue - mm_content = process_image( - item["image"]) if "image" in item else None + mm_content = process_image(item["image"]) if "image" in item else None if enable_multimodal_chat: # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len and output len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -695,10 +791,8 @@ class VisionArenaDataset(HuggingFaceDataset): DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = { - "lmarena-ai/VisionArena-Chat": - lambda x: x["conversation"][0][0]["content"], - "lmarena-ai/vision-arena-bench-v0.1": - lambda x: x["turns"][0][0]["content"] + "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"], } IS_MULTIMODAL = True @@ -710,16 +804,14 @@ def sample( enable_multimodal_chat: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for item in self.data: if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) if parser_fn is None: - raise ValueError( - f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") prompt = parser_fn(item) mm_content = process_image(item["images"][0]) prompt_len = len(tokenizer(prompt).input_ids) @@ -727,15 +819,15 @@ def sample( # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -760,26 +852,36 @@ class InstructCoderDataset(HuggingFaceDataset): "likaixin/InstructCoder", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - **kwargs) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for item in self.data: if len(sampled_requests) >= num_requests: break - prompt = f"{item['instruction']}:\n{item['input']}" + prompt = f"{item['input']}\n\n{item['instruction']} Just output \ + the code, do not include any explanation." + + # apply template + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -794,38 +896,38 @@ class MTBenchDataset(HuggingFaceDataset): MT-Bench Dataset. https://huggingface.co/datasets/philschmid/mt-bench - We create a single turn dataset for MT-Bench. + We create a single turn dataset for MT-Bench. This is similar to Spec decoding benchmark setup in vLLM https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 - """ # noqa: E501 + """ # noqa: E501 DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM SUPPORTED_DATASET_PATHS = { "philschmid/mt-bench", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - **kwargs) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for item in self.data: if len(sampled_requests) >= num_requests: break - prompt = item['turns'][0] + prompt = item["turns"][0] # apply template - prompt = tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - add_generation_prompt=True, - tokenize=False) + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( @@ -833,7 +935,8 @@ def sample(self, prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -847,23 +950,27 @@ class AIMODataset(HuggingFaceDataset): """ Dataset class for processing a AIMO dataset with reasoning questions. """ + SUPPORTED_DATASET_PATHS = { - "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", - "AI-MO/NuminaMath-CoT" + "AI-MO/aimo-validation-aime", + "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list: sampled_requests = [] dynamic_output = output_len is None for item in self.data: if len(sampled_requests) >= num_requests: break - prompt, completion = item['problem'], item["solution"] + prompt, completion = item["problem"], item["solution"] prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids @@ -871,10 +978,9 @@ def sample(self, completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence(prompt_len, - completion_len, - max_prompt_len=2048, - max_total_len=32000): + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000 + ): continue sampled_requests.append( SampleRequest( @@ -882,11 +988,100 @@ def sample(self, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, original_start_marker: str = "<|editable_region_start|>" +) -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. + original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids + ), + ) + ) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples + + # ----------------------------------------------------------------------------- # ASR Dataset Implementation # ----------------------------------------------------------------------------- @@ -909,18 +1104,22 @@ class ASRDataset(HuggingFaceDataset): | AMI | Meetings | Spontaneous | ihm, sdm | +----------------+----------------------------------------+--------------------------+-----------------------------+ - """ # noqa: E501 + """ # noqa: E501 + SUPPORTED_DATASET_PATHS = { - "openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium", - "edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech" + "openslr/librispeech_asr", + "facebook/voxpopuli", + "LIUM/tedlium", + "edinburghcstr/ami", + "speechcolab/gigaspeech", + "kensho/spgispeech", } DEFAULT_OUTPUT_LEN = 128 IS_MULTIMODAL = True # TODO Whisper-specific. Abstract interface when more models are supported. - TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\ - "<|notimestamps|>" + TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" skip_long_audios: bool = True def sample( @@ -931,8 +1130,8 @@ def sample( **kwargs, ) -> list: import librosa - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] @@ -955,10 +1154,14 @@ def sample( prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) if skipped: - logger.warning("%d samples discarded from dataset due to" \ - " their length being greater than" \ - " what Whisper supports.", skipped) + logger.warning( + "%d samples discarded from dataset due to" + " their length being greater than" + " what Whisper supports.", + skipped, + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index dfd9bb1e6a4..c06857247ee 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Benchmark the latency of processing a single batch of requests.""" import argparse @@ -6,14 +7,13 @@ import json import os import time -from pathlib import Path from typing import Any, Optional import numpy as np -import torch -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm import tqdm +import vllm.envs as envs +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType @@ -21,13 +21,14 @@ from vllm.utils import FlexibleArgumentParser -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={"latency": results["latencies"]}, - extra_info={k: results[k] - for k in ["avg_latency", "percentiles"]}) + extra_info={k: results[k] for k in ["avg_latency", "percentiles"]}, + ) if pt_records: pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" write_to_json(pt_file, pt_records) @@ -42,9 +43,11 @@ def main(args: argparse.Namespace): # the engine will automatically process the request in multiple batches. llm = LLM(**dataclasses.asdict(engine_args)) assert llm.llm_engine.model_config.max_model_len >= ( - args.input_len + - args.output_len), ("Please ensure that max_model_len is greater than" - " the sum of input_len and output_len.") + args.input_len + args.output_len + ), ( + "Please ensure that max_model_len is greater than" + " the sum of input_len and output_len." + ) sampling_params = SamplingParams( n=args.n, @@ -55,18 +58,16 @@ def main(args: argparse.Namespace): detokenize=not args.disable_detokenize, ) print(sampling_params) - dummy_prompt_token_ids = np.random.randint(10000, - size=(args.batch_size, - args.input_len)) - dummy_prompts: list[PromptType] = [{ - "prompt_token_ids": batch - } for batch in dummy_prompt_token_ids.tolist()] + dummy_prompt_token_ids = np.random.randint( + 10000, size=(args.batch_size, args.input_len) + ) + dummy_prompts: list[PromptType] = [ + {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() + ] def llm_generate(): if not args.use_beam_search: - llm.generate(dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) else: llm.beam_search( dummy_prompts, @@ -79,16 +80,9 @@ def llm_generate(): def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - on_trace_ready=torch.profiler.tensorboard_trace_handler( - str(profile_dir)), - ) as p: - llm_generate() - print(p.key_averages().table(sort_by="self_cuda_time_total")) + llm.start_profile() + llm_generate() + llm.stop_profile() else: start_time = time.perf_counter() llm_generate() @@ -101,10 +95,7 @@ def run_to_completion(profile_dir: Optional[str] = None): run_to_completion(profile_dir=None) if args.profile: - profile_dir = args.profile_result_dir - if not profile_dir: - profile_dir = (Path(".") / "vllm_benchmark_result" / - f"latency_result_{time.time()}") + profile_dir = envs.VLLM_TORCH_PROFILER_DIR print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return @@ -135,7 +126,8 @@ def run_to_completion(profile_dir: Optional[str] = None): if __name__ == "__main__": parser = FlexibleArgumentParser( description="Benchmark the latency of processing a single batch of " - "requests till completion.") + "requests till completion." + ) parser.add_argument("--input-len", type=int, default=32) parser.add_argument("--output-len", type=int, default=128) parser.add_argument("--batch-size", type=int, default=8) @@ -152,22 +144,14 @@ def run_to_completion(profile_dir: Optional[str] = None): default=10, help="Number of iterations to run for warmup.", ) - parser.add_argument("--num-iters", - type=int, - default=30, - help="Number of iterations to run.") + parser.add_argument( + "--num-iters", type=int, default=30, help="Number of iterations to run." + ) parser.add_argument( "--profile", action="store_true", help="profile the generation process of a single batch", ) - parser.add_argument( - "--profile-result-dir", - type=str, - default=None, - help=("path to save the pytorch profiler output. Can be visualized " - "with ui.perfetto.dev or Tensorboard."), - ) parser.add_argument( "--output-json", type=str, @@ -177,10 +161,20 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) + # V1 enables prefix caching by default which skews the latency + # numbers. We need to disable prefix caching by default. + parser.set_defaults(enable_prefix_caching=False) args = parser.parse_args() + if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: + raise OSError( + "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " + "Please set it to a valid path to use torch profiler." + ) main(args) diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py index 21480578edb..00869fa94e7 100644 --- a/benchmarks/benchmark_long_document_qa_throughput.py +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Offline benchmark to test the long document QA throughput. @@ -76,7 +77,7 @@ def repeat_prompts(prompts, repeat_count, mode: str): - 'random': Shuffle the prompts randomly after repetition. - 'tile': Repeat the entire prompt list in sequence. Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. - - 'interleave': Repeat each prompt consecutively before moving to + - 'interleave': Repeat each prompt consecutively before moving to the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. Returns: @@ -86,20 +87,21 @@ def repeat_prompts(prompts, repeat_count, mode: str): ValueError: If an invalid mode is provided. """ print("Repeat mode: ", mode) - if mode == 'random': + if mode == "random": repeated_prompts = prompts * repeat_count random.shuffle(repeated_prompts) return repeated_prompts - elif mode == 'tile': + elif mode == "tile": return prompts * repeat_count - elif mode == 'interleave': + elif mode == "interleave": repeated_prompts = [] for prompt in prompts: repeated_prompts.extend([prompt] * repeat_count) return repeated_prompts else: - raise ValueError(f"Invalid mode: {mode}, only support " - "'random', 'tile', 'interleave'") + raise ValueError( + f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'" + ) def main(args): @@ -109,16 +111,16 @@ def main(args): # we append the document id at the beginning to avoid any of the document # being the prefix of other documents prompts = [ - str(i) + ' '.join(['hi'] * args.document_length) + str(i) + " ".join(["hi"] * args.document_length) for i in range(args.num_documents) ] prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) warmup_prompts = [ - "This is warm up request " + str(i) + \ - ' '.join(['hi'] * args.document_length) - for i in range(args.num_documents)] + "This is warm up request " + str(i) + " ".join(["hi"] * args.document_length) + for i in range(args.num_documents) + ] # Create the LLM engine engine_args = EngineArgs.from_cli_args(args) @@ -142,42 +144,52 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description= - 'Benchmark the performance with or without automatic prefix caching.') + description="Benchmark the performance with or " + "without automatic prefix caching." + ) parser.add_argument( - '--document-length', + "--document-length", type=int, # Roughly the number of tokens for a system paper, # excluding images default=20000, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') - - parser.add_argument('--num-documents', - type=int, - default=8, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') - - parser.add_argument('--output-len', type=int, default=10) - - parser.add_argument('--repeat-count', - type=int, - default=2, - help='Number of times to repeat each prompt') - - parser.add_argument("--repeat-mode", - type=str, - default='random', - help='The mode to repeat prompts. The supported ' - 'modes are "random", "tile", and "interleave". ' - 'See repeat_prompts() in the source code for details.') - - parser.add_argument("--shuffle-seed", - type=int, - default=0, - help='Random seed when the repeat mode is "random"') + help="Range of input lengths for sampling prompts, " + 'specified as "min:max" (e.g., "128:256").', + ) + + parser.add_argument( + "--num-documents", + type=int, + default=8, + help="Range of input lengths for sampling prompts, " + 'specified as "min:max" (e.g., "128:256").', + ) + + parser.add_argument("--output-len", type=int, default=10) + + parser.add_argument( + "--repeat-count", + type=int, + default=2, + help="Number of times to repeat each prompt", + ) + + parser.add_argument( + "--repeat-mode", + type=str, + default="random", + help="The mode to repeat prompts. The supported " + 'modes are "random", "tile", and "interleave". ' + "See repeat_prompts() in the source code for details.", + ) + + parser.add_argument( + "--shuffle-seed", + type=int, + default=0, + help='Random seed when the repeat mode is "random"', + ) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index f44da95d321..3e4704f0b82 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Benchmark the efficiency of prefix caching. @@ -63,8 +64,7 @@ class Request: output_len: int -def sample_tokens(tokenizer: PreTrainedTokenizerBase, - length: int) -> list[int]: +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]: vocab = tokenizer.get_vocab() all_special_ids = set(tokenizer.all_special_ids) @@ -91,8 +91,10 @@ def sample_requests_from_dataset( # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] # Shuffle the dataset. random.shuffle(dataset) @@ -113,8 +115,9 @@ def sample_requests_from_dataset( completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = (len(completion_token_ids) - if fixed_output_len is None else fixed_output_len) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) if min_len <= prompt_len <= max_len: filtered_requests.append(Request(prompt, prompt_len, output_len)) @@ -128,27 +131,27 @@ def sample_requests_from_random( fixed_output_len: Optional[int], prefix_len: int, ) -> list[Request]: - requests = [] prefix_token_ids = sample_tokens(tokenizer, prefix_len) min_len, max_len = input_length_range for i in range(num_requests): unique_part_token_ids = sample_tokens( - tokenizer, - random.randint(min_len - prefix_len, max_len - prefix_len)) + tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len) + ) prompt_token_ids = prefix_token_ids + unique_part_token_ids prompt = tokenizer.decode(prompt_token_ids) prompt_len = len(prompt_token_ids) - assert (min_len <= prompt_len <= max_len - ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + assert min_len <= prompt_len <= max_len, ( + f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + ) requests.append(Request(prompt, prompt_len, fixed_output_len)) return requests -def repeat_and_sort_requests(requests: list[Request], - repeat_count: int, - sort: bool = False) -> list[str]: +def repeat_and_sort_requests( + requests: list[Request], repeat_count: int, sort: bool = False +) -> list[str]: repeated_requests = requests * repeat_count if sort: repeated_requests.sort(key=lambda x: x[1]) @@ -159,14 +162,14 @@ def repeat_and_sort_requests(requests: list[Request], def main(args): tokenizer = get_tokenizer(args.model, trust_remote_code=True) - input_length_range = tuple(map(int, args.input_length_range.split(':'))) + input_length_range = tuple(map(int, args.input_length_range.split(":"))) random.seed(args.seed) if args.dataset_path is not None: if args.prefix_len > 0: - raise ValueError("prefix-len is not supported when " - "dataset-path is provided.") - print(f"Start to sample {args.num_prompts} prompts " - f"from {args.dataset_path}") + raise ValueError( + "prefix-len is not supported when dataset-path is provided." + ) + print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}") filtered_requests = sample_requests_from_dataset( dataset_path=args.dataset_path, num_requests=args.num_prompts, @@ -196,14 +199,16 @@ def main(args): llm = LLM(**dataclasses.asdict(engine_args)) - sampling_params = SamplingParams(temperature=0, - max_tokens=args.output_len, - detokenize=not args.disable_detokenize) + sampling_params = SamplingParams( + temperature=0, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) print("Testing filtered requests") - prompts = repeat_and_sort_requests(filtered_requests, - repeat_count=args.repeat_count, - sort=args.sort) + prompts = repeat_and_sort_requests( + filtered_requests, repeat_count=args.repeat_count, sort=args.sort + ) print("------start generating------") test_prefix( @@ -215,29 +220,35 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description= - 'Benchmark the performance with or without automatic prefix caching.') - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the dataset.") - parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--num-prompts', - type=int, - required=True, - help="Number of the prompts sampled from dataset") - parser.add_argument('--repeat-count', - type=int, - default=1, - help='Number of times to repeat each prompt') - parser.add_argument('--sort', - action='store_true', - help='Sort prompts by input length') - parser.add_argument('--input-length-range', - type=str, - required=True, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') + description="Benchmark the performance with or without " + "automatic prefix caching." + ) + parser.add_argument( + "--dataset-path", type=str, default=None, help="Path to the dataset." + ) + parser.add_argument("--output-len", type=int, default=10) + parser.add_argument( + "--num-prompts", + type=int, + required=True, + help="Number of the prompts sampled from dataset", + ) + parser.add_argument( + "--repeat-count", + type=int, + default=1, + help="Number of times to repeat each prompt", + ) + parser.add_argument( + "--sort", action="store_true", help="Sort prompts by input length" + ) + parser.add_argument( + "--input-length-range", + type=str, + required=True, + help="Range of input lengths for sampling prompts," + 'specified as "min:max" (e.g., "128:256").', + ) parser.add_argument( "--prefix-len", type=int, @@ -248,10 +259,12 @@ def main(args): "when dataset-path is not provided.", ) parser.add_argument( - '--disable-detokenize', - action='store_true', - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + "--disable-detokenize", + action="store_true", + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index 76fe00ede24..5496703f23c 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Benchmark offline prioritization.""" + import argparse import dataclasses import json @@ -13,7 +15,7 @@ from vllm.utils import FlexibleArgumentParser -#Select a equi-probable random priority +# Select a equi-probable random priority def get_random_flag(): return 0 if random.random() < 0.5 else 1 @@ -33,8 +35,10 @@ def sample_requests( # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] # Shuffle the dataset. random.shuffle(dataset) @@ -51,8 +55,9 @@ def sample_requests( completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) if prompt_len < 4 or output_len < 4: # Prune too short sequences. continue @@ -74,13 +79,16 @@ def run_vllm( disable_detokenize: bool = False, ) -> float: from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " input_len and output_len for all requests.") + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " input_len and output_len for all requests." + ) # Add the requests to the engine. prompts = [] @@ -97,7 +105,8 @@ def run_vllm( ignore_eos=True, max_tokens=output_len, detokenize=not disable_detokenize, - )) + ) + ) start = time.perf_counter() llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) @@ -111,26 +120,33 @@ def main(args: argparse.Namespace): # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) + args.tokenizer, trust_remote_code=args.trust_remote_code + ) if args.dataset is None: # Synthesize a prompt with the given input length. prompt = "hi" * (args.input_len - 1) - requests = [(prompt, args.input_len, args.output_len, - get_random_flag()) for _ in range(args.num_prompts)] + requests = [ + (prompt, args.input_len, args.output_len, get_random_flag()) + for _ in range(args.num_prompts) + ] else: - requests = sample_requests(args.dataset, args.num_prompts, tokenizer, - args.output_len) + requests = sample_requests( + args.dataset, args.num_prompts, tokenizer, args.output_len + ) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.n, - EngineArgs.from_cli_args(args), - args.disable_detokenize) + elapsed_time = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize + ) else: raise ValueError(f"Unknown backend: {args.backend}") - total_num_tokens = sum(prompt_len + output_len - for _, prompt_len, output_len, priority in requests) - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} tokens/s") + total_num_tokens = sum( + prompt_len + output_len for _, prompt_len, output_len, priority in requests + ) + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s" + ) # Output JSON results if specified if args.output_json: @@ -147,41 +163,44 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii"], - default="vllm") - parser.add_argument("--dataset", - type=str, - default=None, - help="Path to the dataset.") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=200, - help="Number of prompts to process.") parser.add_argument( - '--output-json', + "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm" + ) + parser.add_argument( + "--dataset", type=str, default=None, help="Path to the dataset." + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) + parser.add_argument( + "--num-prompts", type=int, default=200, help="Number of prompts to process." + ) + parser.add_argument( + "--output-json", type=str, default=None, - help='Path to save the throughput results in JSON format.') + help="Path to save the throughput results in JSON format.", + ) parser.add_argument( - '--disable-detokenize', - action='store_true', - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + "--disable-detokenize", + action="store_true", + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index c236d64261d..81428fb7dae 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project r"""Benchmark online serving throughput. On the server side, run one of the following commands: @@ -20,6 +21,7 @@ --endpoint /generate_stream to the end of the command above. """ + import argparse import asyncio import gc @@ -34,12 +36,16 @@ from typing import Any, Optional import numpy as np -from backend_request_func import (ASYNC_REQUEST_FUNCS, - OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, - RequestFuncOutput) from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase +from backend_request_func import ( + ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, + RequestFuncInput, + RequestFuncOutput, +) + try: from vllm.transformers_utils.tokenizer import get_tokenizer except ImportError: @@ -50,11 +56,22 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, - ConversationDataset, HuggingFaceDataset, - InstructCoderDataset, MTBenchDataset, - RandomDataset, SampleRequest, ShareGPTDataset, - SonnetDataset, VisionArenaDataset) +from benchmark_dataset import ( + AIMODataset, + ASRDataset, + BurstGPTDataset, + ConversationDataset, + CustomDataset, + HuggingFaceDataset, + InstructCoderDataset, + MTBenchDataset, + NextEditPredictionDataset, + RandomDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, +) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -117,7 +134,8 @@ async def get_request( # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}.") + f"A positive burstiness factor is expected, but given {burstiness}." + ) theta = 1.0 / (request_rate * burstiness) for request in input_requests: @@ -163,8 +181,10 @@ def calculate_metrics( # bundled together # Note : this may inflate the output token count slightly output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + tokenizer( + outputs[i].generated_text, add_special_tokens=False + ).input_ids + ) actual_output_lens.append(output_len) total_input += input_requests[i].prompt_len tpot = 0 @@ -187,16 +207,19 @@ def calculate_metrics( if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(goodput_config_dict["ttft"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(goodput_config_dict["tpot"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(goodput_config_dict["e2el"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) for req_metric in zip(*valid_metrics): is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) @@ -207,7 +230,8 @@ def calculate_metrics( warnings.warn( "All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -216,27 +240,31 @@ def calculate_metrics( request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) * - 1000, # ttfts is empty if streaming is not supported by backend + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend std_ttft_ms=np.std(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], mean_tpot_ms=np.mean(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) - for p in selected_percentiles], + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], ) return metrics, actual_output_lens @@ -269,10 +297,12 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len, test_mm_content = \ - input_requests[0].prompt, input_requests[0].prompt_len, \ - input_requests[0].expected_output_len, \ - input_requests[0].multi_modal_data + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0].prompt, + input_requests[0].prompt_len, + input_requests[0].expected_output_len, + input_requests[0].multi_modal_data, + ) assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( @@ -292,36 +322,36 @@ async def benchmark( if not test_output.success: raise ValueError( "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") + f"are correctly specified. Error: {test_output.error}" + ) else: print("Initial test run completed. Starting main benchmark run...") if lora_modules: # For each input request, choose a LoRA module at random. lora_modules = iter( - [random.choice(lora_modules) \ - for _ in range(len(input_requests))]) + [random.choice(lora_modules) for _ in range(len(input_requests))] + ) if profile: print("Starting profiler...") - profile_input = RequestFuncInput(model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body) + profile_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: print("Profiler started") - if burstiness == 1.0: - distribution = "Poisson process" - else: - distribution = "Gamma distribution" + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") @@ -333,42 +363,45 @@ async def benchmark( # and it will simplify the code in limited_request_func. # semaphore = (asyncio.Semaphore(max_concurrency) # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None async def limited_request_func(request_func_input, pbar): if semaphore is None: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): - prompt, prompt_len, output_len, mm_content = request.prompt, \ - request.prompt_len, request.expected_output_len, \ - request.multi_modal_data + prompt, prompt_len, output_len, mm_content = ( + request.prompt, + request.prompt_len, + request.expected_output_len, + request.multi_modal_data, + ) req_model_id, req_model_name = model_id, model_name if lora_modules: req_lora_module = next(lora_modules) req_model_id, req_model_name = req_lora_module, req_lora_module - request_func_input = RequestFuncInput(model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body) + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - pbar=pbar))) + limited_request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -400,22 +433,32 @@ async def limited_request_func(request_func_input, pbar): goodput_config_dict=goodput_config_dict, ) - print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - benchmark_duration)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) - print("{:<40} {:<10.2f}".format("Request throughput (req/s):", - metrics.request_throughput)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) if goodput_config_dict: - print("{:<40} {:<10.2f}".format("Request goodput (req/s):", - metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) - print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", - metrics.total_token_throughput)) + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total Token throughput (tok/s):", metrics.total_token_throughput + ) + ) result = { "duration": benchmark_duration, @@ -423,8 +466,7 @@ async def limited_request_func(request_func_input, pbar): "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "request_goodput:": - metrics.request_goodput if goodput_config_dict else None, + "request_goodput:": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], @@ -447,29 +489,35 @@ def process_one_metric( # metric. if metric_attribute_name not in selected_percentile_metrics: return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"))) + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms") + metrics, f"mean_{metric_attribute_name}_ms" + ) result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms") + metrics, f"median_{metric_attribute_name}_ms" + ) result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms") - for p, value in getattr(metrics, - f"percentiles_{metric_attribute_name}_ms"): + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") @@ -489,12 +537,14 @@ def check_goodput_args(args): raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. " "The service level objective name should be one of " - f"{str(VALID_NAMES)}. ") + f"{str(VALID_NAMES)}. " + ) if slo_val < 0: raise ValueError( f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " - "non-negative.") + "non-negative." + ) return goodput_config_dict @@ -507,31 +557,42 @@ def parse_goodput(slo_pairs): except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " - "Specify service level objectives for goodput as \"KEY:VALUE\" " + 'Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is a " - "number in milliseconds.") from err + "number in milliseconds." + ) from err return goodput_config_dict -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any], - file_name: str) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any], file_name: str +) -> None: metrics = [ - "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", - "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", - "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + "median_ttft_ms", + "mean_ttft_ms", + "std_ttft_ms", + "p99_ttft_ms", + "mean_tpot_ms", + "median_tpot_ms", + "std_tpot_ms", + "p99_tpot_ms", + "median_itl_ms", + "mean_itl_ms", + "std_itl_ms", + "p99_itl_ms", ] # These raw data might be useful, but they are rather big. They can be added # later if needed ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] pt_records = convert_to_pytorch_benchmark_format( args=args, - metrics={k: [results[k]] - for k in metrics}, + metrics={k: [results[k]] for k in metrics}, extra_info={ k: results[k] - for k in results if k not in metrics and k not in ignored_metrics - }) + for k in results + if k not in metrics and k not in ignored_metrics + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" @@ -556,34 +617,51 @@ def main(args: argparse.Namespace): api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" - tokenizer = get_tokenizer(tokenizer_id, - tokenizer_mode=tokenizer_mode, - trust_remote_code=args.trust_remote_code) + tokenizer = get_tokenizer( + tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) if args.dataset_name is None: raise ValueError( "Please specify '--dataset-name' and the corresponding " - "'--dataset-path' if required.") + "'--dataset-path' if required." + ) + + if args.dataset_name == "custom": + dataset = CustomDataset(dataset_path=args.dataset_path) + input_requests = dataset.sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.custom_output_len, + skip_chat_template=args.custom_skip_chat_template, + ) - if args.dataset_name == "sonnet": + elif args.dataset_name == "sonnet": dataset = SonnetDataset(dataset_path=args.dataset_path) # For the "sonnet" dataset, formatting depends on the backend. if args.backend == "openai-chat": - input_requests = dataset.sample(num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=False) + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") - input_requests = dataset.sample(num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=True) + "Tokenizer/model must have chat template for sonnet dataset." + ) + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + ) elif args.dataset_name == "hf": # all following datasets are implemented from the @@ -603,27 +681,37 @@ def main(args: argparse.Namespace): elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_class = AIMODataset args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: dataset_class = ASRDataset args.hf_split = "train" else: - supported_datasets = set([ - dataset_name for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ]) + supported_datasets = set( + [ + dataset_name + for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ] + ) raise ValueError( f"Unsupported dataset path: {args.dataset_path}. " "Huggingface dataset only supports dataset_path" f" from one of following: {supported_datasets}. " "Please consider contributing if you would " - "like to add support for additional dataset formats.") + "like to add support for additional dataset formats." + ) - if (dataset_class.IS_MULTIMODAL and backend not in \ - ["openai-chat", "openai-audio"]): + if dataset_class.IS_MULTIMODAL and backend not in [ + "openai-chat", + "openai-audio", + ]: # multi-modal benchmark is only available on OpenAI Chat backend. raise ValueError( - "Multi-modal content is only supported on 'openai-chat' and " \ - "'openai-audio' backend.") + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backend." + ) input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -638,26 +726,24 @@ def main(args: argparse.Namespace): else: # For datasets that follow a similar structure, use a mapping. dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). - sample(tokenizer=tokenizer, num_requests=args.num_prompts), - "random": - lambda: RandomDataset(dataset_path=args.dataset_path).sample( + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, - ) + ), } try: @@ -673,19 +759,24 @@ def main(args: argparse.Namespace): "top_p": args.top_p, "top_k": args.top_k, "min_p": args.min_p, - "temperature": args.temperature - }.items() if v is not None + "temperature": args.temperature, + }.items() + if v is not None } # Sampling parameters are only supported by openai-compatible backend. if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: raise ValueError( - "Sampling parameters are only supported by openai-compatible " - "backends.") + "Sampling parameters are only supported by openai-compatible backends." + ) if "temperature" not in sampling_params: sampling_params["temperature"] = 0.0 # Default to greedy decoding. + if args.backend == "llama.cpp": + # Disable prompt caching in llama.cpp backend + sampling_params["cache_prompt"] = False + # Avoid GC processing "static" data - reduce pause times. gc.collect() gc.freeze() @@ -705,15 +796,14 @@ def main(args: argparse.Namespace): disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, extra_body=sampling_params, - )) + ) + ) # Save config and results to json if args.save_result or args.append_result: @@ -738,8 +828,9 @@ def main(args: argparse.Namespace): "Invalid metadata format. Please use KEY=VALUE format." ) # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") + result_json["request_rate"] = ( + args.request_rate if args.request_rate < float("inf") else "inf" + ) result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -749,24 +840,34 @@ def main(args: argparse.Namespace): if not args.save_detailed: # Remove fields with too many data points for field in [ - "input_lens", "output_lens", "ttfts", "itls", - "generated_texts", "errors" + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", ]: if field in result_json: del result_json[field] + if field in benchmark_result: + del benchmark_result[field] # Save to file base_model_id = model_id.split("/")[-1] - max_concurrency_str = (f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None else "") - file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa + max_concurrency_str = ( + f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None + else "" + ) + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: file_name = args.result_filename if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) file_name = os.path.join(args.result_dir, file_name) - with open(file_name, - mode="a+" if args.append_result else "w", - encoding='utf-8') as outfile: + with open( + file_name, mode="a+" if args.append_result else "w", encoding="utf-8" + ) as outfile: # Append a newline. if args.append_result and outfile.tell() != 0: outfile.write("\n") @@ -776,7 +877,8 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput.") + description="Benchmark the online serving throughput." + ) parser.add_argument( "--backend", type=str, @@ -802,14 +904,16 @@ def main(args: argparse.Namespace): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], help="Name of the dataset to benchmark on.", ) - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the sharegpt/sonnet dataset. " - "Or the huggingface dataset ID if using HF dataset.") + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.", + ) parser.add_argument( "--max-concurrency", type=int, @@ -821,7 +925,8 @@ def main(args: argparse.Namespace): "initiated, this argument will control how many are actually allowed " "to execute at a time. This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", @@ -832,8 +937,7 @@ def main(args: argparse.Namespace): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( @@ -846,11 +950,13 @@ def main(args: argparse.Namespace): "--logprobs", type=int, default=None, - help=("Number of logprobs-per-token to compute & return as part of " - "the request. If unspecified, then either (1) if beam search " - "is disabled, no logprobs are computed & a single dummy " - "logprob is returned for each token; or (2) if beam search " - "is enabled 1 logprob per token is computed"), + help=( + "Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed" + ), ) parser.add_argument( "--request-rate", @@ -934,58 +1040,71 @@ def main(args: argparse.Namespace): "--ignore-eos", action="store_true", help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) parser.add_argument( "--percentile-metrics", type=str, default="ttft,tpot,itl", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " - "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " - "Default value is \"ttft,tpot,itl\".") + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'Default value is "ttft,tpot,itl".', + ) parser.add_argument( "--metric-percentiles", type=str, default="99", help="Comma-separated list of percentiles for selected metrics. " - "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " - "Default value is \"99\". " - "Use \"--percentile-metrics\" to select metrics.", + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99". ' + 'Use "--percentile-metrics" to select metrics.', ) parser.add_argument( "--goodput", nargs="+", required=False, - help="Specify service level objectives for goodput as \"KEY:VALUE\" " + help='Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is in " - "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' "separated by spaces. Allowed request level metric names are " - "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + '"ttft", "tpot", "e2el". For more context on the definition of ' "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) # group for dataset specific arguments + custom_group = parser.add_argument_group("custom dataset options") + custom_group.add_argument( + "--custom-output-len", + type=int, + default=256, + help="Number of output tokens per request, used only for custom dataset.", + ) + custom_group.add_argument( + "--custom-skip-chat-template", + action="store_true", + help="Skip applying chat template to prompt, used only for custom dataset.", + ) + sonnet_group = parser.add_argument_group("sonnet dataset options") sonnet_group.add_argument( "--sonnet-input-len", type=int, default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", + help="Number of input tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-output-len", type=int, default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", + help="Number of output tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-prefix-len", type=int, default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", + help="Number of prefix tokens per request, used only for sonnet dataset.", ) sharegpt_group = parser.add_argument_group("sharegpt dataset options") @@ -994,22 +1113,21 @@ def main(args: argparse.Namespace): type=int, default=None, help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.") + "from the ShareGPT dataset.", + ) random_group = parser.add_argument_group("random dataset options") random_group.add_argument( "--random-input-len", type=int, default=1024, - help= - "Number of input tokens per request, used only for random sampling.", + help="Number of input tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-output-len", type=int, default=128, - help= - "Number of output tokens per request, used only for random sampling.", + help="Number of output tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-range-ratio", @@ -1024,23 +1142,23 @@ def main(args: argparse.Namespace): "--random-prefix-len", type=int, default=0, - help=("Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]."), + help=( + "Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]." + ), ) hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - hf_group.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + hf_group.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." + ) hf_group.add_argument( "--hf-output-len", type=int, @@ -1054,52 +1172,58 @@ def main(args: argparse.Namespace): "--top-p", type=float, default=None, - help="Top-p sampling parameter. Only has effect on openai-compatible " - "backends.") + help="Top-p sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--top-k", type=int, default=None, - help="Top-k sampling parameter. Only has effect on openai-compatible " - "backends.") + help="Top-k sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--min-p", type=float, default=None, - help="Min-p sampling parameter. Only has effect on openai-compatible " - "backends.") + help="Min-p sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--temperature", type=float, default=None, help="Temperature sampling parameter. Only has effect on " "openai-compatible backends. If not specified, default to greedy " - "decoding (i.e. temperature==0.0).") + "decoding (i.e. temperature==0.0).", + ) parser.add_argument( - '--tokenizer-mode', + "--tokenizer-mode", type=str, default="auto", - choices=['auto', 'slow', 'mistral', 'custom'], + choices=["auto", "slow", "mistral", "custom"], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' - 'always use the slow tokenizer. \n* ' + "always use the slow tokenizer. \n* " '"mistral" will always use the `mistral_common` tokenizer. \n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.') - - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The model name used in the API. " - "If not specified, the model name will be the " - "same as the ``--model`` argument. ") - - parser.add_argument("--lora-modules", - nargs='+', - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. For each request, the " - "script chooses a LoRA module at random.") + '"custom" will use --tokenizer to select the preregistered tokenizer.', + ) + + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ", + ) + + parser.add_argument( + "--lora-modules", + nargs="+", + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.", + ) args = parser.parse_args() diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 9084255d244..3848ebda959 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project r"""Benchmark online serving throughput with structured outputs. On the server side, run one of the following commands: @@ -19,6 +20,7 @@ --endpoint /generate_stream to the end of the command above. """ + import argparse import asyncio import copy @@ -36,11 +38,15 @@ import datasets import numpy as np import pandas as pd -from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, - RequestFuncOutput) from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase +from backend_request_func import ( + ASYNC_REQUEST_FUNCS, + RequestFuncInput, + RequestFuncOutput, +) + try: from vllm.transformers_utils.tokenizer import get_tokenizer except ImportError: @@ -52,7 +58,8 @@ from argparse import ArgumentParser as FlexibleArgumentParser from vllm.v1.structured_output.backend_xgrammar import ( - has_xgrammar_unsupported_json_features) + has_xgrammar_unsupported_json_features, +) MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -98,6 +105,7 @@ class SampleRequest: prompt_len: The length of the prompt in tokens. expected_output_len: The expected length of the output in tokens. """ + prompt: str prompt_len: int expected_output_len: int @@ -106,32 +114,28 @@ class SampleRequest: completion: str = None -def sample_requests(tokenizer: PreTrainedTokenizerBase, - args: argparse.Namespace) -> list[SampleRequest]: - if args.dataset == 'json' or args.dataset == 'json-unique': +def sample_requests( + tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace +) -> list[SampleRequest]: + if args.dataset == "json" or args.dataset == "json-unique": if args.json_schema_path is None: dir_path = os.path.dirname(os.path.realpath(__file__)) - args.json_schema_path = os.path.join(dir_path, - "structured_schemas", - "structured_schema_1.json") + args.json_schema_path = os.path.join( + dir_path, "structured_schemas", "structured_schema_1.json" + ) json_schemas = [] with open(args.json_schema_path) as f: schema = json.load(f) - if args.dataset == 'json-unique': - json_schemas = [ - copy.deepcopy(schema) for _ in range(args.num_prompts) - ] + if args.dataset == "json-unique": + json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)] for i in range(len(json_schemas)): if "properties" not in json_schemas[i]: json_schemas[i]["properties"] = {} - json_schemas[i]["properties"][ - f"__optional_field_{uuid.uuid4()}"] = { - "type": - "string", - "description": - "An unique optional field to avoid cached schemas" - } + json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = { + "type": "string", + "description": "An unique optional field to avoid cached schemas", + } else: json_schemas = [schema] * args.num_prompts @@ -142,11 +146,13 @@ def get_schema(index: int): return json_schemas[index % len(json_schemas)] requests = [ - SampleRequest(prompt=gen_prompt(i), - prompt_len=len(tokenizer(gen_prompt(i)).input_ids), - expected_output_len=args.output_len, - schema=get_schema(i), - structure_type=args.structure_type) + SampleRequest( + prompt=gen_prompt(i), + prompt_len=len(tokenizer(gen_prompt(i)).input_ids), + expected_output_len=args.output_len, + schema=get_schema(i), + structure_type=args.structure_type, + ) for i in range(args.num_prompts) ] @@ -170,11 +176,13 @@ def get_schema(index: int): input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] @@ -188,11 +196,13 @@ def get_schema(index: int): input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=regex, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=regex, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] @@ -203,48 +213,55 @@ def get_schema(index: int): input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=choice, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=choice, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] elif args.dataset == "xgrammar_bench": requests: list[SampleRequest] = [] - dataset = datasets.load_dataset("NousResearch/json-mode-eval", - split="train") + dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train") full_dataset_len = len(dataset) def _filter_func(item): import json + schema = json.loads(item["schema"]) return not has_xgrammar_unsupported_json_features(schema) dataset = dataset.filter(_filter_func) num_filtered_out = full_dataset_len - len(dataset) - print(f"dataset has {len(dataset)} entries after filtering " - f"out {num_filtered_out} entries with unsupported features") + print( + f"dataset has {len(dataset)} entries after filtering " + f"out {num_filtered_out} entries with unsupported features" + ) len_dataset = len(dataset) for data_point_idx in range(args.num_prompts): idx = data_point_idx while idx >= len_dataset: idx -= len_dataset schema = dataset["schema"][idx] - prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], - tokenize=False, - add_generation_prompt=True) + prompt = tokenizer.apply_chat_template( + dataset["prompt"][idx], tokenize=False, add_generation_prompt=True + ) input_len = len(tokenizer(prompt).input_ids) completion = dataset["completion"][idx] requests.append( - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - structure_type=args.structure_type, - completion=completion)) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + completion=completion, + ) + ) return requests @@ -276,7 +293,8 @@ async def get_request( # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}.") + f"A positive burstiness factor is expected, but given {burstiness}." + ) theta = 1.0 / (request_rate * burstiness) for i, request in enumerate(input_requests): @@ -318,8 +336,8 @@ def calculate_metrics( # multiple output tokens may be bundled together # Note : this may inflate the output token count slightly output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) actual_output_lens.append(output_len) total_input += input_requests[i].prompt_len tpot = 0 @@ -343,16 +361,19 @@ def calculate_metrics( if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(goodput_config_dict["ttft"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(goodput_config_dict["tpot"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(goodput_config_dict["e2el"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) for req_metric in zip(*valid_metrics): is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) @@ -363,7 +384,8 @@ def calculate_metrics( warnings.warn( "All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -372,27 +394,31 @@ def calculate_metrics( request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) * - 1000, # ttfts is empty if streaming is not supported by backend + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend std_ttft_ms=np.std(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], mean_tpot_ms=np.mean(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) - for p in selected_percentiles], + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], ) return metrics, actual_output_lens @@ -429,12 +455,13 @@ def prepare_extra_body(request) -> dict: print("Starting initial single prompt test run...") structured_output_req_idx = random.sample( - range(len(input_requests)), - int(len(input_requests) * structured_output_ratio)) + range(len(input_requests)), int(len(input_requests) * structured_output_ratio) + ) test_request = input_requests[0] - test_req_extra_body = (prepare_extra_body(test_request) - if 0 in structured_output_req_idx else None) + test_req_extra_body = ( + prepare_extra_body(test_request) if 0 in structured_output_req_idx else None + ) test_input = RequestFuncInput( model=model_id, prompt=test_request.prompt, @@ -448,7 +475,8 @@ def prepare_extra_body(request) -> dict: if not test_output.success: raise ValueError( "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") + f"are correctly specified. Error: {test_output.error}" + ) else: print("Initial test run completed. Starting main benchmark run...") @@ -467,10 +495,7 @@ def prepare_extra_body(request) -> dict: if profile_output.success: print("Profiler started") - if burstiness == 1.0: - distribution = "Poisson process" - else: - distribution = "Gamma distribution" + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") @@ -482,24 +507,21 @@ def prepare_extra_body(request) -> dict: # and it will simplify the code in limited_request_func. # semaphore = (asyncio.Semaphore(max_concurrency) # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None async def limited_request_func(request_func_input, pbar): if semaphore is None: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] expected: list[str] = [] - async for i, request in get_request(input_requests, request_rate, - burstiness): - extra_body = prepare_extra_body( - request) if i in structured_output_req_idx else None + async for i, request in get_request(input_requests, request_rate, burstiness): + extra_body = ( + prepare_extra_body(request) if i in structured_output_req_idx else None + ) request_func_input = RequestFuncInput( model=model_id, prompt=request.prompt, @@ -512,8 +534,9 @@ async def limited_request_func(request_func_input, pbar): expected.append(request.completion) tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - pbar=pbar))) + limited_request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -545,54 +568,58 @@ async def limited_request_func(request_func_input, pbar): goodput_config_dict=goodput_config_dict, ) - print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - benchmark_duration)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) - print("{:<40} {:<10.2f}".format("Request throughput (req/s):", - metrics.request_throughput)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) if goodput_config_dict: - print("{:<40} {:<10.2f}".format("Request goodput (req/s):", - metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) - print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", - metrics.total_token_throughput)) + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total Token throughput (tok/s):", metrics.total_token_throughput + ) + ) result = { - "duration": - benchmark_duration, - "completed": - metrics.completed, - "total_input_tokens": - metrics.total_input, - "total_output_tokens": - metrics.total_output, - "request_throughput": - metrics.request_throughput, - "output_throughput": - metrics.output_throughput, - "total_token_throughput": - metrics.total_token_throughput, - "ttft_description": - pd.Series([output.ttft for output in outputs]).describe().to_dict(), - "tpot_description": - pd.Series([output.tpot for output in outputs]).describe().to_dict(), + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "ttft_description": pd.Series([output.ttft for output in outputs]) + .describe() + .to_dict(), + "tpot_description": pd.Series([output.tpot for output in outputs]) + .describe() + .to_dict(), "input_lens": [output.prompt_len for output in outputs], - "output_lens": - actual_output_lens, + "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], "itls": [output.itl for output in outputs], "errors": [output.error for output in outputs], } - ret = [{ - 'generated': output.generated_text, - 'expected': gt - } for output, gt in zip(outputs, expected)] + ret = [ + {"generated": output.generated_text, "expected": gt} + for output, gt in zip(outputs, expected) + ] def process_one_metric( # E.g., "ttft" @@ -606,29 +633,35 @@ def process_one_metric( # metric. if metric_attribute_name not in selected_percentile_metrics: return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"))) + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms") + metrics, f"mean_{metric_attribute_name}_ms" + ) result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms") + metrics, f"median_{metric_attribute_name}_ms" + ) result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms") - for p, value in getattr(metrics, - f"percentiles_{metric_attribute_name}_ms"): + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") @@ -638,13 +671,13 @@ def process_one_metric( def evaluate(ret, args): - def _eval_correctness_json(expected, actual): # extract json string from string using regex - import re - actual = actual.replace('\n', '').replace(' ', '').strip() + import regex as re + + actual = actual.replace("\n", "").replace(" ", "").strip() try: - actual = re.search(r'\{.*\}', actual).group() + actual = re.search(r"\{.*\}", actual).group() actual = json.loads(actual) except Exception: return False @@ -655,29 +688,33 @@ def _eval_correctness_choice(expected, actual): return actual in args.choice def _eval_correctness_regex(expected, actual): - import re + import regex as re + return re.match(args.regex, actual) is not None def _eval_correctness(expected, actual): - if args.structure_type == 'guided_json': + if args.structure_type == "guided_json": return _eval_correctness_json(expected, actual) - elif args.structure_type == 'guided_regex': + elif args.structure_type == "guided_regex": return _eval_correctness_regex(expected, actual) - elif args.structure_type == 'guided_choice': + elif args.structure_type == "guided_choice": return _eval_correctness_choice(expected, actual) else: return None scores = [] for res in ret: - score = _eval_correctness(res['expected'], res['generated']) - res['correctness'] = score + score = _eval_correctness(res["expected"], res["generated"]) + res["correctness"] = score scores.append(score) not_none_scores = [score for score in scores if score is not None] - return (sum(not_none_scores) / len(not_none_scores) * - 100) if len(not_none_scores) > 0 else None + return ( + (sum(not_none_scores) / len(not_none_scores) * 100) + if len(not_none_scores) > 0 + else None + ) def parse_goodput(slo_pairs): @@ -689,9 +726,10 @@ def parse_goodput(slo_pairs): except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " - "Specify service level objectives for goodput as \"KEY:VALUE\" " + 'Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is a " - "number in milliseconds.") from err + "number in milliseconds." + ) from err return goodput_config_dict @@ -705,12 +743,14 @@ def check_goodput_args(args): raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. " "The service level objective name should be one of " - f"{str(VALID_NAMES)}. ") + f"{str(VALID_NAMES)}. " + ) if slo_val < 0: raise ValueError( f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " - "non-negative.") + "non-negative." + ) return goodput_config_dict @@ -736,19 +776,19 @@ def main(args: argparse.Namespace): tokenizer_mode=args.tokenizer_mode, ) - if args.dataset == 'grammar': - args.structure_type = 'guided_grammar' - elif args.dataset == 'regex': - args.structure_type = 'guided_regex' - elif args.dataset == 'choice': - args.structure_type = 'guided_choice' + if args.dataset == "grammar": + args.structure_type = "guided_grammar" + elif args.dataset == "regex": + args.structure_type = "guided_regex" + elif args.dataset == "choice": + args.structure_type = "guided_choice" else: - args.structure_type = 'guided_json' + args.structure_type = "guided_json" if args.no_structured_output: args.structured_output_ratio = 0 if args.save_results: - result_file_name = f'{args.structured_output_ratio}guided' + result_file_name = f"{args.structured_output_ratio}guided" result_file_name += f"_{backend}" result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.model.split('/')[-1]}" @@ -776,36 +816,29 @@ def main(args: argparse.Namespace): disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, max_concurrency=args.max_concurrency, structured_output_ratio=args.structured_output_ratio, goodput_config_dict=goodput_config_dict, - )) + ) + ) # Save config and results to json score = evaluate(ret, args) - print("correct_rate(%)", score, '\n') + print("correct_rate(%)", score, "\n") if args.save_results: results = { - "backend": - backend, - "model_id": - model_id, - "tokenizer_id": - tokenizer_id, - "num_prompts": - args.num_prompts, - "request_rate": - args.request_rate if args.request_rate < float("inf") else "inf", - "burstiness": - args.burstiness, - "max_concurrency": - args.max_concurrency, - "correct_rate(%)": - score + "backend": backend, + "model_id": model_id, + "tokenizer_id": tokenizer_id, + "num_prompts": args.num_prompts, + "request_rate": args.request_rate + if args.request_rate < float("inf") + else "inf", + "burstiness": args.burstiness, + "max_concurrency": args.max_concurrency, + "correct_rate(%)": score, } results = {"outputs": ret, **results, **benchmark_result} @@ -814,13 +847,14 @@ def main(args: argparse.Namespace): result_file_name = args.result_filename if args.result_dir: result_file_name = os.path.join(args.result_dir, result_file_name) - with open(result_file_name, "w", encoding='utf-8') as outfile: + with open(result_file_name, "w", encoding="utf-8") as outfile: json.dump(results, outfile, indent=4) if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput.") + description="Benchmark the online serving throughput." + ) parser.add_argument( "--backend", type=str, @@ -842,16 +876,14 @@ def main(args: argparse.Namespace): default="/v1/completions", help="API endpoint.", ) - parser.add_argument("--dataset", - default='json', - choices=[ - 'json', 'json-unique', 'grammar', 'regex', - 'choice', 'xgrammar_bench' - ]) - parser.add_argument("--json-schema-path", - type=str, - default=None, - help="Path to json schema.") + parser.add_argument( + "--dataset", + default="json", + choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"], + ) + parser.add_argument( + "--json-schema-path", type=str, default=None, help="Path to json schema." + ) parser.add_argument( "--max-concurrency", type=int, @@ -863,7 +895,8 @@ def main(args: argparse.Namespace): "initiated, this argument will control how many are actually allowed " "to execute at a time. This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", type=str, @@ -873,15 +906,13 @@ def main(args: argparse.Namespace): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument( "--tokenizer-mode", type=str, default="auto", - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument( "--num-prompts", @@ -958,44 +989,51 @@ def main(args: argparse.Namespace): "--ignore-eos", action="store_true", help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) parser.add_argument( "--percentile-metrics", type=str, default="ttft,tpot,itl", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " - "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " - "Default value is \"ttft,tpot,itl\".") + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'Default value is "ttft,tpot,itl".', + ) parser.add_argument( "--metric-percentiles", type=str, default="99", help="Comma-separated list of percentiles for selected metrics. " - "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " - "Default value is \"99\". " - "Use \"--percentile-metrics\" to select metrics.", + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99". ' + 'Use "--percentile-metrics" to select metrics.', ) parser.add_argument( "--goodput", nargs="+", required=False, - help="Specify service level objectives for goodput as \"KEY:VALUE\" " + help='Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is in " - "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' "separated by spaces. Allowed request level metric names are " - "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + '"ttft", "tpot", "e2el". For more context on the definition of ' "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve") - - parser.add_argument("--no-structured-output", - action='store_true', - default=False, - help="Whether to disable JSON decoding or not.") - parser.add_argument("--structured-output-ratio", - type=float, - default=1.0, - help="Ratio of Structured Outputs requests") + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) + + parser.add_argument( + "--no-structured-output", + action="store_true", + default=False, + help="Whether to disable JSON decoding or not.", + ) + parser.add_argument( + "--structured-output-ratio", + type=float, + default=1.0, + help="Ratio of Structured Outputs requests", + ) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 486a89a6028..546ac5eaa1f 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Benchmark offline inference throughput.""" + import argparse import dataclasses import json @@ -12,19 +14,26 @@ import torch import uvloop -from benchmark_dataset import (AIMODataset, BurstGPTDataset, - ConversationDataset, InstructCoderDataset, - RandomDataset, SampleRequest, ShareGPTDataset, - SonnetDataset, VisionArenaDataset) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from PIL import Image from tqdm import tqdm -from transformers import (AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizerBase) - +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase + +from benchmark_dataset import ( + AIMODataset, + BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, + RandomDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, +) +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path @@ -162,23 +171,30 @@ def run_vllm( disable_detokenize: bool = False, ) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. prompts: list[Union[TextPrompt, TokensPrompt]] = [] sampling_params: list[SamplingParams] = [] for request in requests: prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + TokensPrompt( + prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data, + ) + if "prompt_token_ids" in request.prompt + else TextPrompt( + prompt=request.prompt, multi_modal_data=request.multi_modal_data + ) + ) sampling_params.append( SamplingParams( n=n, @@ -187,7 +203,8 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests: Optional[list[LoRARequest]] = None if engine_args.enable_lora: lora_requests = [request.lora_request for request in requests] @@ -197,10 +214,9 @@ def run_vllm( outputs = None if not use_beam_search: start = time.perf_counter() - outputs = llm.generate(prompts, - sampling_params, - lora_request=lora_requests, - use_tqdm=True) + outputs = llm.generate( + prompts, sampling_params, lora_request=lora_requests, use_tqdm=True + ) end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" @@ -216,30 +232,35 @@ def run_vllm( beam_width=n, max_tokens=output_len, ignore_eos=True, - )) + ), + ) end = time.perf_counter() return end - start, outputs def run_vllm_chat( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, list[RequestOutput]]: """ Run vLLM chat benchmark. This function is recommended ONLY for benchmarking multimodal models as it properly handles multimodal inputs and chat formatting. For non-multimodal models, use run_vllm() instead. """ from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of " - "prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests." + ) prompts = [] sampling_params: list[SamplingParams] = [] @@ -253,7 +274,8 @@ def run_vllm_chat( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) start = time.perf_counter() outputs = llm.chat(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() @@ -283,11 +305,15 @@ async def run_vllm_async(requests: list[SampleRequest], lora_requests: list[Optional[LoRARequest]] = [] for request in requests: prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + TokensPrompt( + prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data, + ) + if "prompt_token_ids" in request.prompt + else TextPrompt( + prompt=request.prompt, multi_modal_data=request.multi_modal_data + ) + ) sampling_params.append( SamplingParams( n=n, @@ -296,17 +322,16 @@ async def run_vllm_async(requests: list[SampleRequest], ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests.append(request.lora_request) generators = [] start = time.perf_counter() - for i, (prompt, sp, - lr) in enumerate(zip(prompts, sampling_params, lora_requests)): - generator = llm.generate(prompt, - sp, - lora_request=lr, - request_id=f"test{i}") + for i, (prompt, sp, lr) in enumerate( + zip(prompts, sampling_params, lora_requests) + ): + generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: @@ -325,7 +350,8 @@ def run_hf( disable_detokenize: bool = False, ) -> float: llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token @@ -348,14 +374,15 @@ def run_hf( # Check if we can add more requests to the batch. next_prompt_len = requests[i + 1].prompt_len next_output_len = requests[i + 1].expected_output_len - if (max(max_prompt_len, next_prompt_len) + - max(max_output_len, next_output_len)) <= 2048: + if ( + max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len) + ) <= 2048: # We can add more requests to the batch. continue # Generate the sequences. - input_ids = tokenizer(batch, return_tensors="pt", - padding=True).input_ids + input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), do_sample=True, @@ -385,6 +412,7 @@ def run_mii( output_len: int, ) -> float: from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) prompts = [request.prompt for request in requests] @@ -396,8 +424,9 @@ def run_mii( return end - start -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={ @@ -405,9 +434,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, "tokens_per_second": [results["tokens_per_second"]], }, extra_info={ - k: results[k] - for k in ["elapsed_time", "num_requests", "total_num_tokens"] - }) + k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" @@ -439,7 +468,8 @@ def get_requests(args, tokenizer): sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_name == "sonnet": assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") + "Tokenizer/model must have chat template for sonnet dataset." + ) dataset_cls = SonnetDataset sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["return_prompt_formatted"] = True @@ -448,21 +478,21 @@ def get_requests(args, tokenizer): elif args.dataset_name == "hf": if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: dataset_cls = VisionArenaDataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: dataset_cls = InstructCoderDataset - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_split"] = "train" elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: dataset_cls = ConversationDataset - common_kwargs['dataset_subset'] = args.hf_subset - common_kwargs['dataset_split'] = args.hf_split + common_kwargs["dataset_subset"] = args.hf_subset + common_kwargs["dataset_split"] = args.hf_split sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_cls = AIMODataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" else: raise ValueError(f"Unknown dataset name: {args.dataset_name}") # Remove None values @@ -477,10 +507,10 @@ def main(args: argparse.Namespace): random.seed(args.seed) # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) + args.tokenizer, trust_remote_code=args.trust_remote_code + ) requests = get_requests(args, tokenizer) - is_multi_modal = any(request.multi_modal_data is not None - for request in requests) + is_multi_modal = any(request.multi_modal_data is not None for request in requests) request_outputs: Optional[list[RequestOutput]] = None if args.backend == "vllm": if args.async_engine: @@ -491,23 +521,34 @@ def main(args: argparse.Namespace): AsyncEngineArgs.from_cli_args(args), args.disable_frontend_multiprocessing, args.disable_detokenize, - )) + ) + ) else: elapsed_time, request_outputs = run_vllm( - requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + requests, + args.n, + EngineArgs.from_cli_args(args), + args.disable_detokenize, + ) elif args.backend == "hf": assert args.tensor_parallel_size == 1 - elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.hf_max_batch_size, args.trust_remote_code, - args.disable_detokenize) + elapsed_time = run_hf( + requests, + args.model, + tokenizer, + args.n, + args.hf_max_batch_size, + args.trust_remote_code, + args.disable_detokenize, + ) elif args.backend == "mii": - elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, - args.output_len) + elapsed_time = run_mii( + requests, args.model, args.tensor_parallel_size, args.output_len + ) elif args.backend == "vllm-chat": elapsed_time, request_outputs = run_vllm_chat( - requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize + ) else: raise ValueError(f"Unknown backend: {args.backend}") @@ -519,28 +560,31 @@ def main(args: argparse.Namespace): for ro in request_outputs: if not isinstance(ro, RequestOutput): continue - total_prompt_tokens += len( - ro.prompt_token_ids) if ro.prompt_token_ids else 0 - total_output_tokens += sum( - len(o.token_ids) for o in ro.outputs if o) + total_prompt_tokens += ( + len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 + ) + total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o) total_num_tokens = total_prompt_tokens + total_output_tokens else: - total_num_tokens = sum(r.prompt_len + r.expected_output_len - for r in requests) + total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests) total_output_tokens = sum(r.expected_output_len for r in requests) total_prompt_tokens = total_num_tokens - total_output_tokens if is_multi_modal and args.backend != "vllm-chat": - print("\033[91mWARNING\033[0m: Multi-modal request with " - f"{args.backend} backend detected. The " - "following metrics are not accurate because image tokens are not" - " counted. See vllm-project/vllm/issues/9778 for details.") + print( + "\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details." + ) # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. # vllm-chat backend counts the image tokens now - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s" + ) print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num output tokens: {total_output_tokens}") @@ -568,7 +612,8 @@ def validate_args(args): warnings.warn( "The '--dataset' argument will be deprecated in the next release. " "Please use '--dataset-name' and '--dataset-path' instead.", - stacklevel=2) + stacklevel=2, + ) args.dataset_path = args.dataset if not getattr(args, "tokenizer", None): @@ -581,9 +626,8 @@ def validate_args(args): # === Dataset Configuration === if not args.dataset and not args.dataset_path: - print( - "When dataset path is not set, it will default to random dataset") - args.dataset_name = 'random' + print("When dataset path is not set, it will default to random dataset") + args.dataset_name = "random" if args.input_len is None: raise ValueError("input_len must be provided for a random dataset") @@ -591,41 +635,55 @@ def validate_args(args): # --hf-subset and --hf-split: only used # when dataset_name is 'hf' if args.dataset_name != "hf" and ( - getattr(args, "hf_subset", None) is not None - or getattr(args, "hf_split", None) is not None): - warnings.warn("--hf-subset and --hf-split will be ignored \ + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None + ): + warnings.warn( + "--hf-subset and --hf-split will be ignored \ since --dataset-name is not 'hf'.", - stacklevel=2) + stacklevel=2, + ) elif args.dataset_name == "hf": if args.dataset_path in ( - VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() - | ConversationDataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 - elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm-chat", ( + f"{args.dataset_path} needs to use vllm-chat as the backend." + ) # noqa: E501 + elif args.dataset_path in ( + InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm", ( + f"{args.dataset_path} needs to use vllm as the backend." + ) # noqa: E501 else: - raise ValueError( - f"{args.dataset_path} is not supported by hf dataset.") + raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") # --random-range-ratio: only used when dataset_name is 'random' - if args.dataset_name != 'random' and args.random_range_ratio is not None: - warnings.warn("--random-range-ratio will be ignored since \ + if args.dataset_name != "random" and args.random_range_ratio is not None: + warnings.warn( + "--random-range-ratio will be ignored since \ --dataset-name is not 'random'.", - stacklevel=2) + stacklevel=2, + ) # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not # set. - if args.dataset_name not in {"random", "sonnet", None - } and args.prefix_len is not None: - warnings.warn("--prefix-len will be ignored since --dataset-name\ + if ( + args.dataset_name not in {"random", "sonnet", None} + and args.prefix_len is not None + ): + warnings.warn( + "--prefix-len will be ignored since --dataset-name\ is not 'random', 'sonnet', or not set.", - stacklevel=2) + stacklevel=2, + ) # === LoRA Settings === if getattr(args, "enable_lora", False) and args.backend != "vllm": - raise ValueError( - "LoRA benchmarking is only supported for vLLM backend") + raise ValueError("LoRA benchmarking is only supported for vLLM backend") if getattr(args, "enable_lora", False) and args.lora_path is None: raise ValueError("LoRA path must be provided when enable_lora is True") @@ -635,8 +693,10 @@ def validate_args(args): if args.backend != "hf" and args.hf_max_batch_size is not None: raise ValueError("HF max batch size is only for HF backend.") - if args.backend in {"hf", "mii"} and getattr(args, "quantization", - None) is not None: + if ( + args.backend in {"hf", "mii"} + and getattr(args, "quantization", None) is not None + ): raise ValueError("Quantization is only for vLLM backend.") if args.backend == "mii" and args.dtype != "auto": @@ -644,29 +704,32 @@ def validate_args(args): if args.backend == "mii" and args.n != 1: raise ValueError("n must be 1 for MII backend.") if args.backend == "mii" and args.tokenizer != args.model: - raise ValueError( - "Tokenizer must be the same as the model for MII backend.") + raise ValueError("Tokenizer must be the same as the model for MII backend.") # --data-parallel is not supported currently. # https://github.com/vllm-project/vllm/issues/16222 if args.data_parallel_size > 1: raise ValueError( "Data parallel is not supported in offline benchmark, \ - please use benchmark serving instead") + please use benchmark serving instead" + ) if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii", "vllm-chat"], - default="vllm") + parser.add_argument( + "--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm", + ) parser.add_argument( "--dataset-name", type=str, choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], help="Name of the dataset to benchmark on.", - default="sharegpt") + default="sharegpt", + ) parser.add_argument( "--dataset", type=str, @@ -674,57 +737,70 @@ def validate_args(args): help="Path to the ShareGPT dataset, will be deprecated in\ the next release. The dataset is expected to " "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: ]]]]") - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the dataset") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.") - parser.add_argument("--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.") + "list[dict[..., value: ]]]]", + ) + parser.add_argument( + "--dataset-path", type=str, default=None, help="Path to the dataset" + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) + parser.add_argument( + "--num-prompts", type=int, default=1000, help="Number of prompts to process." + ) + parser.add_argument( + "--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.", + ) parser.add_argument( - '--output-json', + "--output-json", type=str, default=None, - help='Path to save the throughput results in JSON format.') - parser.add_argument("--async-engine", - action='store_true', - default=False, - help="Use vLLM async engine rather than LLM class.") - parser.add_argument("--disable-frontend-multiprocessing", - action='store_true', - default=False, - help="Disable decoupled async engine frontend.") + help="Path to save the throughput results in JSON format.", + ) + parser.add_argument( + "--async-engine", + action="store_true", + default=False, + help="Use vLLM async engine rather than LLM class.", + ) + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + default=False, + help="Disable decoupled async engine frontend.", + ) parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize the response (i.e. do not include " - "detokenization time in the measurement)")) + help=( + "Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)" + ), + ) # LoRA parser.add_argument( "--lora-path", type=str, default=None, help="Path to the LoRA adapters to use. This can be an absolute path, " - "a relative path, or a Hugging Face model identifier.") + "a relative path, or a Hugging Face model identifier." + ) parser.add_argument( "--prefix-len", type=int, @@ -738,7 +814,8 @@ def validate_args(args): f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) " "controls how much of the input is fixed lines versus " "random lines, but the total input length remains approximately " - "input_len tokens.") + "input_len tokens.", + ) # random dataset parser.add_argument( "--random-range-ratio", @@ -752,14 +829,12 @@ def validate_args(args): ) # hf dtaset - parser.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - parser.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + parser.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + parser.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." + ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index 45a0ddbd5d0..283f938df50 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import json @@ -7,9 +8,9 @@ from typing import Any -def convert_to_pytorch_benchmark_format(args: argparse.Namespace, - metrics: dict[str, list], - extra_info: dict[str, Any]) -> list: +def convert_to_pytorch_benchmark_format( + args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any] +) -> list: """ Save the benchmark results in the format used by PyTorch OSS benchmark with on metric per record @@ -37,12 +38,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, }, } - tp = record["benchmark"]["extra_info"]["args"].get( - "tensor_parallel_size") + tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size") # Save tensor_parallel_size parameter if it's part of the metadata if not tp and "tensor_parallel_size" in extra_info: - record["benchmark"]["extra_info"]["args"][ - "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = ( + extra_info["tensor_parallel_size"] + ) records.append(record) @@ -50,7 +51,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, class InfEncoder(json.JSONEncoder): - def clear_inf(self, o: Any): if isinstance(o, dict): return {k: self.clear_inf(v) for k, v in o.items()} @@ -66,4 +66,9 @@ def iterencode(self, o: Any, *args, **kwargs) -> Any: def write_to_json(filename: str, records: list) -> None: with open(filename, "w") as f: - json.dump(records, f, cls=InfEncoder) + json.dump( + records, + f, + cls=InfEncoder, + default=lambda o: f"<{type(o).__name__} object is not JSON serializable>", + ) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py index 9e36b0a9d3b..9ec270bbd2e 100644 --- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy @@ -23,8 +24,9 @@ # bench -def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, - **kwargs) -> TMeasurement: +def bench_fn( + label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs +) -> TMeasurement: min_run_time = 1 globals = { @@ -41,16 +43,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ).blocked_autorange(min_run_time=min_run_time) -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_int8( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: assert dtype == torch.int8 b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) - out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, - torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_compressed, e, scale_a, scale_b, torch.bfloat16 + ) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) if not torch.allclose(out, out_ref): @@ -63,54 +67,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, timers = [] # pytorch impl - bfloat16 timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16))) + bench_fn( + label, + sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16), + ) + ) # pytorch impl - float16 timers.append( - bench_fn(label, sub_label, - "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, - a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + bench_fn( + label, + sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.float16), + b.to(dtype=torch.float16), + ) + ) # cutlass impl timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass with bias timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) # cutlass sparse impl timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass sparse with bias timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16, bias)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) return timers -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_fp8( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: assert dtype == torch.float8_e4m3fn - b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, - k) + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) - out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, - torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_compressed, e, scale_a, scale_b, torch.bfloat16 + ) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) if not torch.allclose(out, out_ref): @@ -124,97 +181,165 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, # pytorch impl w. bf16 timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"))) + bench_fn( + label, + sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"), + ) + ) # pytorch impl: bf16 output, without fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + ) + ) # pytorch impl: bf16 output, with fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + ) # pytorch impl: fp16 output, without fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + ) + ) # pytorch impl: fp16 output, with fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16, - use_fast_accum=True)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True, + ) + ) # cutlass impl: bf16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass impl: bf16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass impl: fp16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.float16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.float16, + ) + ) # cutlass impl: bf16 output, with bias timers.append( - bench_fn(label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16, bias)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) # cutlass impl: fp16 output, with bias timers.append( - bench_fn(label, sub_label, - "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.float16, bias.to(dtype=torch.float16))) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.float16, + bias.to(dtype=torch.float16), + ) + ) return timers -def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: if dtype == torch.int8: return bench_int8(dtype, m, k, n, label, sub_label) if dtype == torch.float8_e4m3fn: @@ -228,12 +353,12 @@ def print_timers(timers: Iterable[TMeasurement]): compare.print() -def run(dtype: torch.dtype, - MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: +def run( + dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]] +) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})") + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})") print_timers(timers) results.extend(timers) @@ -241,10 +366,12 @@ def run(dtype: torch.dtype, # output makers -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[tuple[int, int, int]], - base_description: str, - timestamp=None): +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): print(f"== All Results {base_description} ====") print_timers(data) @@ -258,8 +385,7 @@ def make_output(data: Iterable[TMeasurement], def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, MKNs) @@ -319,7 +445,7 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: pkl.dump(all_data, f) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "int8": @@ -344,12 +470,15 @@ def to_torch_dtype(dt): Output: - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) - - parser.add_argument("--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']") + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']", + ) subparsers = parser.add_subparsers(dest="cmd") square_parser = subparsers.add_parser("square_bench") @@ -368,19 +497,19 @@ def to_torch_dtype(dt): range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py index fe4d8fdfc06..b4f3c6bf94e 100644 --- a/benchmarks/cutlass_benchmarks/utils.py +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Cutlass bench utils from collections.abc import Iterable @@ -10,8 +11,9 @@ def to_fp8(tensor: torch.Tensor) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) def to_int8(tensor: torch.Tensor) -> torch.Tensor: @@ -26,10 +28,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor: return tensor.to(dtype=torch.float16) -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 +def make_rand_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device="cuda") * 5 + b = torch.randn((n, k), device="cuda").t() * 5 if dtype == torch.int8: return to_int8(a), to_int8(b) @@ -49,9 +52,7 @@ def prune_to_2_4(tensor): # Create binary mask mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, - index=indices, - src=torch.ones_like(indices, dtype=mask.dtype)) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) # Apply mask and reshape back pruned = reshaped * mask @@ -62,10 +63,11 @@ def prune_to_2_4(tensor): return pruned.reshape(original_shape) -def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 +def make_rand_sparse_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device="cuda") * 5 + b = torch.randn((n, k), device="cuda").t() * 5 b = prune_to_2_4(b.t()).t() @@ -86,9 +88,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, return b_compressed, e, a, b -def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, - m: int, n: int, k: int) -> \ - tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: +def make_n_rand_sparse_tensors( + num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: ABs = [] for _ in range(num_tensors): b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index e7b742d8bec..cec422e8d59 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy @@ -16,7 +17,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul) + w8a8_block_fp8_matmul, +) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -25,8 +27,9 @@ # bench -def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, - **kwargs) -> TMeasurement: +def bench_fn( + label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs +) -> TMeasurement: min_run_time = 1 globals = { @@ -44,45 +47,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_int8( - dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: """Benchmark INT8-based kernels.""" assert dtype == torch.int8 a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) - azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) - azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m,), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32) bench_fns = { - "pytorch_bf16_bf16_bf16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) - ), - "pytorch_fp16_fp16_fp16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), - "cutlass_i8_i8_bf16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), - "cutlass_i8_i8_bf16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, - bias), - "cutlass_i8_i8_bf16_scaled_mm_azp": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj), - "cutlass_i8_i8_bf16_scaled_mm_azp_bias": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, None, bias), - "cutlass_i8_i8_bf16_scaled_mm_azp_pt": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, azp), - "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, azp, bias), + "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.float16), b.to(dtype=torch.float16) + ), + "cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16 + ), + "cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16, bias + ), + "cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias + ), } timers = [] @@ -96,73 +102,73 @@ def bench_int8( def bench_fp8( - dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: """Benchmark FP8-based kernels.""" assert dtype == torch.float8_e4m3fn a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) a_cont = a.contiguous() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - block_scale_a = torch.rand((m, k // 128), - device="cuda", - dtype=torch.float32) - block_scale_b = torch.rand((k // 128, n // 128), - device="cuda", - dtype=torch.float32) + + def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + block_scale_a = torch.rand( + (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32 + ) + block_scale_b = torch.rand( + ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32 + ) block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t() - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) print(m, k, n) bench_fns = { - "pytorch_bf16_bf16_bf16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) - ), - "pytorch_fp16_fp16_fp16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), - "pytorch_fp8_fp8_fp16_scaled_mm": - lambda: torch._scaled_mm( - a, b, scale_a, scale_b, out_dtype=torch.float16), - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": - lambda: torch._scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.float16, - use_fast_accum=True), - "pytorch_fp8_fp8_bf16_scaled_mm": - lambda: torch._scaled_mm( - a, b, scale_a, scale_b, out_dtype=torch.bfloat16), - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": - lambda: torch._scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True), - "cutlass_fp8_fp8_bf16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), - "cutlass_fp8_fp8_fp16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16), - "cutlass_fp8_fp8_bf16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, - bias), - "cutlass_fp8_fp8_fp16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, - bias.to(dtype=torch.float16)), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": - lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a, - block_scale_b.t(), (128, 128)), - "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": - lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major, - block_scale_b_K_major, torch.float16), + "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.float16), b.to(dtype=torch.float16) + ), + "pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16 + ), + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True + ), + "pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16 + ), + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True + ), + "cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16 + ), + "cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.float16 + ), + "cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16, bias + ), + "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) + ), + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( + a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) + ), + "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( + a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16 + ), } timers = [] @@ -175,13 +181,15 @@ def bench_fp8( return timers -def bench(dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: +def bench( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: if dtype == torch.int8: return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) if dtype == torch.float8_e4m3fn: @@ -195,27 +203,33 @@ def print_timers(timers: Iterable[TMeasurement]): compare.print() -def run(dtype: torch.dtype, - MKNs: Iterable[tuple[int, int, int]], - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: +def run( + dtype: torch.dtype, + MKNs: Iterable[tuple[int, int, int]], + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, - m, - k, - n, - f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})", - bench_kernels=bench_kernels) + timers = bench( + dtype, + m, + k, + n, + f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + bench_kernels=bench_kernels, + ) print_timers(timers) results.extend(timers) return results -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[tuple[int, int, int]], - base_description: str, - timestamp=None): +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): print(f"== All Results {base_description} ====") print_timers(data) @@ -226,8 +240,7 @@ def make_output(data: Iterable[TMeasurement], def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -285,7 +298,7 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: pkl.dump(all_data, f) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "int8": @@ -310,19 +323,21 @@ def to_torch_dtype(dt): Output: - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, + ) - parser.add_argument("--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']") + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']", + ) parser.add_argument( "--kernels", nargs="+", type=str, default=None, - help= - "Exact names of the kernels to benchmark. If not set, runs all kernels." + help="Exact names of the kernels to benchmark. If not set, runs all kernels.", ) subparsers = parser.add_subparsers(dest="cmd") @@ -343,19 +358,19 @@ def to_torch_dtype(dt): range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 3d1121df40d..25b96ef5662 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Weight Shapes are in the format # ([K, N], TP_SPLIT_DIM) @@ -42,4 +43,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} \ No newline at end of file +} diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index 980e6866891..f62d8102e2d 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os @@ -12,39 +13,37 @@ async def forward_request(url, data): async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } - async with session.post(url=url, json=data, - headers=headers) as response: + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with session.post(url=url, json=data, headers=headers) as response: if response.status == 200: # if response.headers.get('Transfer-Encoding') == 'chunked': if True: - async for chunk_bytes in response.content.iter_chunked( - 1024): + async for chunk_bytes in response.content.iter_chunked(1024): yield chunk_bytes else: content = await response.read() yield content -@app.route('/v1/completions', methods=['POST']) +@app.route("/v1/completions", methods=["POST"]) async def handle_request(): try: original_request_data = await request.get_json() prefill_request = original_request_data.copy() # change max_tokens = 1 to let it only do prefill - prefill_request['max_tokens'] = 1 + prefill_request["max_tokens"] = 1 # finish prefill - async for _ in forward_request('http://localhost:8100/v1/completions', - prefill_request): + async for _ in forward_request( + "http://localhost:8100/v1/completions", prefill_request + ): continue # return decode - generator = forward_request('http://localhost:8200/v1/completions', - original_request_data) + generator = forward_request( + "http://localhost:8200/v1/completions", original_request_data + ) response = await make_response(generator) response.timeout = None @@ -53,11 +52,12 @@ async def handle_request(): except Exception as e: import sys import traceback + exc_info = sys.exc_info() print("Error occurred in disagg prefill proxy server") print(e) print("".join(traceback.format_exception(*exc_info))) -if __name__ == '__main__': +if __name__ == "__main__": app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py index c2ad4916bf0..b1df2f25582 100644 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.py +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import itertools @@ -8,7 +9,6 @@ class RoundRobinProxy: - def __init__(self, target_ports): self.target_ports = target_ports self.port_cycle = itertools.cycle(self.target_ports) @@ -21,14 +21,15 @@ async def handle_request(self, request): try: # Forward the request async with session.request( - method=request.method, - url=target_url, - headers=request.headers, - data=request.content, + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, ) as response: # Start sending the response - resp = web.StreamResponse(status=response.status, - headers=response.headers) + resp = web.StreamResponse( + status=response.status, headers=response.headers + ) await resp.prepare(request) # Stream the response content @@ -45,11 +46,11 @@ async def handle_request(self, request): async def main(): proxy = RoundRobinProxy([8100, 8200]) app = web.Application() - app.router.add_route('*', '/{path:.*}', proxy.handle_request) + app.router.add_route("*", "/{path:.*}", proxy.handle_request) runner = web.AppRunner(app) await runner.setup() - site = web.TCPSite(runner, 'localhost', 8000) + site = web.TCPSite(runner, "localhost", 8000) await site.start() print("Proxy server started on http://localhost:8000") @@ -58,5 +59,5 @@ async def main(): await asyncio.Event().wait() -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index a7b4b9e8bf3..74fa56d076c 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json @@ -6,43 +7,41 @@ import pandas as pd if __name__ == "__main__": - data = [] - for name in ['disagg_prefill', 'chunked_prefill']: + for name in ["disagg_prefill", "chunked_prefill"]: for qps in [2, 4, 6, 8]: with open(f"results/{name}-qps-{qps}.json") as f: x = json.load(f) - x['name'] = name - x['qps'] = qps + x["name"] = name + x["qps"] = qps data.append(x) df = pd.DataFrame.from_dict(data) - dis_df = df[df['name'] == 'disagg_prefill'] - chu_df = df[df['name'] == 'chunked_prefill'] + dis_df = df[df["name"] == "disagg_prefill"] + chu_df = df[df["name"] == "chunked_prefill"] - plt.style.use('bmh') - plt.rcParams['font.size'] = 20 + plt.style.use("bmh") + plt.rcParams["font.size"] = 20 for key in [ - 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', - 'median_itl_ms', 'p99_itl_ms' + "mean_ttft_ms", + "median_ttft_ms", + "p99_ttft_ms", + "mean_itl_ms", + "median_itl_ms", + "p99_itl_ms", ]: - fig, ax = plt.subplots(figsize=(11, 7)) - plt.plot(dis_df['qps'], - dis_df[key], - label='disagg_prefill', - marker='o', - linewidth=4) - plt.plot(chu_df['qps'], - chu_df[key], - label='chunked_prefill', - marker='o', - linewidth=4) + plt.plot( + dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4 + ) + plt.plot( + chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4 + ) ax.legend() - ax.set_xlabel('QPS') + ax.set_xlabel("QPS") ax.set_ylabel(key) ax.set_ylim(bottom=0) - fig.savefig(f'results/{key}.png') + fig.savefig(f"results/{key}.png") plt.close(fig) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index 3da583a3344..90152421446 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pickle as pkl import time @@ -24,10 +25,12 @@ class bench_params_t: dtype: torch.dtype def description(self): - return (f'N {self.num_tokens} ' - f'x D {self.hidden_size} ' - f'x R {self.add_residual} ' - f'x DT {self.dtype}') + return ( + f"N {self.num_tokens} " + f"x D {self.hidden_size} " + f"x R {self.add_residual} " + f"x DT {self.dtype}" + ) def get_bench_params() -> list[bench_params_t]: @@ -38,15 +41,19 @@ def get_bench_params() -> list[bench_params_t]: DTYPES = [torch.bfloat16, torch.float] combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) - bench_params = list(map(lambda x: \ - bench_params_t(x[0], x[1], x[2], x[3]), combinations)) + bench_params = list( + map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations) + ) return bench_params # Reference impls -def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): +def unfused_int8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): # Norm torch_out = None if residual is None: @@ -58,9 +65,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, torch_out, _, _ = ops.scaled_int8_quant(torch_out) -def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): +def unfused_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): # Norm torch_out = None if residual is None: @@ -73,22 +83,27 @@ def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, def fused_impl( - rms_norm_layer: RMSNorm, # this stores the weights - x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): - out, _ = ops.rms_norm_dynamic_per_token_quant(x, - rms_norm_layer.weight, - 1e-6, - quant_dtype, - residual=residual) + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): + out, _ = ops.rms_norm_dynamic_per_token_quant( + x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual + ) # Bench functions -def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, - quant_dtype: torch.dtype, label: str, sub_label: str, - fn: Callable, description: str) -> TMeasurement: - +def bench_fn( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor, + quant_dtype: torch.dtype, + label: str, + sub_label: str, + fn: Callable, + description: str, +) -> TMeasurement: min_run_time = 1 globals = { @@ -106,43 +121,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, description=description, ).blocked_autorange(min_run_time=min_run_time) -def bench(params: bench_params_t, label: str, sub_label: str) \ - -> Iterable[TMeasurement]: +def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]: # Make inputs layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) # Make weights layer.weight.data.normal_(mean=1.0, std=0.1) # Make inputs scale = 1 / params.hidden_size - x = torch.randn(params.num_tokens, - params.hidden_size, - dtype=params.dtype, - device='cuda') * scale - residual = (torch.randn_like(x) * scale).to(device='cuda') \ - if params.add_residual else None + x = ( + torch.randn( + params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda" + ) + * scale + ) + residual = ( + (torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None + ) timers = [] # unfused int8 impl. timers.append( - bench_fn(layer, x, residual, torch.int8, label, sub_label, - unfused_int8_impl, "unfused_int8_impl")) + bench_fn( + layer, + x, + residual, + torch.int8, + label, + sub_label, + unfused_int8_impl, + "unfused_int8_impl", + ) + ) # unfused fp8 impl. timers.append( - bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, - unfused_fp8_impl, "unfused_fp8_impl")) + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + label, + sub_label, + unfused_fp8_impl, + "unfused_fp8_impl", + ) + ) # fused int8 impl. timers.append( - bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, - "fused_int8_impl")) + bench_fn( + layer, + x, + residual, + torch.int8, + label, + sub_label, + fused_impl, + "fused_int8_impl", + ) + ) # fused fp8 impl. timers.append( - bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, - fused_impl, "fused_fp8_impl")) + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + label, + sub_label, + fused_impl, + "fused_fp8_impl", + ) + ) print_timers(timers) @@ -157,13 +210,12 @@ def print_timers(timers: Iterable[TMeasurement]): def main(): - torch.set_default_device('cuda') + torch.set_default_device("cuda") bench_params = get_bench_params() timers = [] for bp in tqdm(bench_params): - timers.extend( - bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) print_timers(timers) # pickle all the results @@ -172,5 +224,5 @@ def main(): pkl.dump(timers, f) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/benchmarks/kernels/bench_fp8_gemm.py b/benchmarks/kernels/bench_fp8_gemm.py new file mode 100644 index 00000000000..b964ed242ed --- /dev/null +++ b/benchmarks/kernels/bench_fp8_gemm.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import copy +import itertools + +import torch +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant +from vllm.triton_utils import triton + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=[ + "torch-bf16", + # "fp8-tensor-w-token-a", + "fp8-tensor-w-tensor-a", + "fp8-channel-w-token-a", + # "fp8-channel-w-tensor-a", + # "fp8-tensor-w-token-a-noquant", + "fp8-tensor-w-tensor-a-noquant", + "fp8-channel-w-token-a-noquant", + # "fp8-channel-w-tensor-a-noquant", + ], + line_names=[ + "torch-bf16", + # "fp8-tensor-w-token-a", + "fp8-tensor-w-tensor-a", + "fp8-channel-w-token-a", + # "fp8-channel-w-tensor-a", + # "fp8-tensor-w-token-a-noquant", + "fp8-tensor-w-tensor-a-noquant", + "fp8-channel-w-token-a-noquant", + # "fp8-channel-w-tensor-a-noquant", + ], + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs FP8 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + # Create input tensors + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if "torch-bf16" in provider: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + + elif "fp8" in provider: + # Weights are always quantized ahead of time + if "noquant" in provider: + # For no quantization, we just measure the GEMM + if "tensor-w-token-a" in provider: + # Dynamic per-token quant for A, per-tensor quant for B + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) + assert scale_b_fp8.numel() == 1 + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( + a, use_per_token_if_dynamic=True + ) + + def run_quant(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "tensor-w-tensor-a" in provider: + # Static per-tensor quantization with fixed scales + # for both A and B + scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) + scale_b = torch.tensor([1.0], device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + assert scale_b_fp8.numel() == 1 + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + + def run_quant(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "channel-w-token-a" in provider: + # Static per-channel quantization for weights, per-token + # quant for A + scale_b = torch.tensor((N,), device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + scale_b_fp8 = scale_b_fp8.expand(N).contiguous() + assert scale_b_fp8.numel() == N + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( + a, use_per_token_if_dynamic=True + ) + + def run_quant(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "channel-w-tensor-a" in provider: + # Static per-channel quantization for weights, per-tensor + # quant for A + scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) + scale_b = torch.tensor((N,), device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + scale_b_fp8 = scale_b_fp8.expand(N).contiguous() + assert scale_b_fp8.numel() == N + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + + def run_quant(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + else: + # In these cases, we quantize the activations during the GEMM call + if "tensor-w-token-a" in provider: + # Dynamic per-token quant for A, per-tensor quant for B + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) + assert scale_b_fp8.numel() == 1 + + def run_quant(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( + a, use_per_token_if_dynamic=True + ) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "tensor-w-tensor-a" in provider: + # Static per-tensor quantization with fixed scales + # for both A and B + scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) + scale_b = torch.tensor([1.0], device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + assert scale_b_fp8.numel() == 1 + + def run_quant(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "channel-w-token-a" in provider: + # Static per-channel quantization for weights, per-token + # quant for A + scale_b = torch.tensor((N,), device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + scale_b_fp8 = scale_b_fp8.expand(N).contiguous() + assert scale_b_fp8.numel() == N + + def run_quant(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( + a, use_per_token_if_dynamic=True + ) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + elif "channel-w-tensor-a" in provider: + # Static per-channel quantization for weights, per-tensor + # quant for A + scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) + scale_b = torch.tensor((N,), device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + scale_b_fp8 = scale_b_fp8.expand(N).contiguous() + assert scale_b_fp8.numel() == N + + def run_quant(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + b_fp8 = b_fp8.t() + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), quantiles=quantiles + ) + + # Calculate TFLOP/s, two flops per multiply-add + tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3) + return tflops(ms), tflops(max_ms), tflops(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + choices=[*WEIGHT_SHAPES.keys()], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_fp8_res_n{N}_k{K}", + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index 8d20b91560d..42de062b08e 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os import sys @@ -9,32 +10,39 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.aqlm import ( - dequantize_weight, generic_dequantize_gemm, get_int_dtype, - optimized_dequantize_gemm) + dequantize_weight, + generic_dequantize_gemm, + get_int_dtype, + optimized_dequantize_gemm, +) from vllm.utils import FlexibleArgumentParser -os.environ['CUDA_VISIBLE_DEVICES'] = '0' +os.environ["CUDA_VISIBLE_DEVICES"] = "0" def torch_mult( - input: torch.Tensor, # [..., in_features] - weights: torch.Tensor, - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + weights: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, ) -> torch.Tensor: output = F.linear(input, weights) return output def dequant_out_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) if bias is None: @@ -46,40 +54,42 @@ def dequant_out_scale( flattened_output *= b_scales return flattened_output.view(orig_shape) else: - b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( - -1, weights.shape[1]) + b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) weights *= b_scales return F.linear(input, weights, bias) def dequant_weight_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( - -1, weights.shape[1]) + b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) weights *= b_scales return F.linear(input, weights, bias) def dequant_no_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) return F.linear(input, weights, bias) @@ -89,23 +99,26 @@ def dequant_no_scale( # the generic pytorch version. # Just visual comparison. def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: - n = int(parts.sum().item()) - device = torch.device('cuda:0') + device = torch.device("cuda:0") code_range = (1 << bits) // 2 ingroups = 8 - codes = torch.randint(-code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device) + codes = torch.randint( + -code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device, + ) - codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device) + codebooks = torch.randn( + size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device, + ) count = 0 for index in range(16): @@ -138,24 +151,25 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: def main(): - parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") # Add arguments - parser.add_argument("--nbooks", - type=int, - default=1, - help="Number of codebooks (default: 1)") - parser.add_argument("--bits", - type=int, - default=16, - help="Number of bits per code element (default: 16)") + parser.add_argument( + "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)" + ) + parser.add_argument( + "--bits", + type=int, + default=16, + help="Number of bits per code element (default: 16)", + ) parser.add_argument( "--test", type=bool, default=False, help="Run the decompression/dequant tester rather than benchmarking " - "(default: False)") + "(default: False)", + ) # Parse the arguments args = parser.parse_args() @@ -165,7 +179,7 @@ def main(): bits = args.bits if args.test: - dequant_test(4096, torch.tensor((4096, )), nbooks, bits) + dequant_test(4096, torch.tensor((4096,)), nbooks, bits) return # Otherwise, benchmark. @@ -184,31 +198,54 @@ def main(): with open(filename, "w") as f: sys.stdout = f - print('m | k | n | n parts', end='') + print("m | k | n | n parts", end="") for method in methods: - print(f" | {method.__name__.replace('_', ' ')} (µs)", end='') - print('') + print(f" | {method.__name__.replace('_', ' ')} (µs)", end="") + print("") # These are reasonable prefill sizes. - ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )), - (4096, (11008, 11008)), (11008, (4096, ))) + ksandpartions = ( + (4096, (4096, 4096, 4096)), + (4096, (4096,)), + (4096, (11008, 11008)), + (11008, (4096,)), + ) # reasonable ranges for m. for m in [ - 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112, - 128, 256, 512, 1024, 1536, 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 10, + 12, + 14, + 16, + 24, + 32, + 48, + 52, + 56, + 64, + 96, + 112, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ]: - print(f'{m}', file=sys.__stdout__) + print(f"{m}", file=sys.__stdout__) for ksp in ksandpartions: - run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, - methods) + run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods) sys.stdout = sys.__stdout__ -def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, - methods): - +def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods): # I didn't see visible improvements from increasing these, but feel free :) num_warmup_trials = 1 num_trials = 1 @@ -229,7 +266,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, ) n = parts.sum().item() - print(f'{m} | {k} | {n} | {parts.tolist()}', end='') + print(f"{m} | {k} | {n} | {parts.tolist()}", end="") for method in methods: best_time_us = 1e20 @@ -249,32 +286,36 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, if kernel_dur_us < best_time_us: best_time_us = kernel_dur_us - print(f' | {kernel_dur_us:.0f}', end='') + print(f" | {kernel_dur_us:.0f}", end="") - print('') + print("") -def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor, - nbooks: int, bits: int, method) -> float: - +def run_timing( + num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method +) -> float: n = int(parts.sum().item()) - device = torch.device('cuda:0') + device = torch.device("cuda:0") input = torch.randn((1, m, k), dtype=torch.float16, device=device) code_range = (1 << bits) // 2 ingroups = 8 - codes = torch.randint(-code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device) - - codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device) + codes = torch.randint( + -code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device, + ) + + codebooks = torch.randn( + size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device, + ) scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py index b23b4f3ea68..97ee0603413 100644 --- a/benchmarks/kernels/benchmark_bitblas.py +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -1,29 +1,36 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - MINIMUM_BITBLAS_VERSION) + MINIMUM_BITBLAS_VERSION, +) try: import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: - raise ImportError("bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError as e: bitblas_import_exception = e - raise ValueError("Trying to use the bitblas backend, but could not import" - f"with the following error: {bitblas_import_exception}. " - "Please install bitblas through the following command: " - f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" - ) from bitblas_import_exception + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target from vllm.utils import FlexibleArgumentParser parser = FlexibleArgumentParser( - description="Benchmark BitBLAS int4 on a specific target.") + description="Benchmark BitBLAS int4 on a specific target." +) # Add arguments to the parser parser.add_argument( @@ -32,10 +39,9 @@ default=auto_detect_nvidia_target(), help="Specify the target device for benchmarking.", ) -parser.add_argument("--group_size", - type=int, - default=None, - help="Group size for grouped quantization.") +parser.add_argument( + "--group_size", type=int, default=None, help="Group size for grouped quantization." +) parser.add_argument( "--A_dtype", type=str, @@ -82,17 +88,17 @@ choices=["nt", "nn"], help="Matrix layout, 'nt' for non-transpose A and transpose W.", ) -parser.add_argument("--with_bias", - action="store_true", - help="Include bias in the benchmark.") +parser.add_argument( + "--with_bias", action="store_true", help="Include bias in the benchmark." +) parser.add_argument( "--with_scaling", action="store_true", help="Include scaling factor in the quantization.", ) -parser.add_argument("--with_zeros", - action="store_true", - help="Include zeros in the quantization.") +parser.add_argument( + "--with_zeros", action="store_true", help="Include zeros in the quantization." +) parser.add_argument( "--zeros_mode", type=str, @@ -170,8 +176,7 @@ ] # Build test shapes with all the shared arguments -test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) - for shape in shapes] +test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes] benchmark_sets = [] benchmark_sets.extend(test_shapes) @@ -206,12 +211,12 @@ func_name = args_split[0] input_args_str = "-".join(args_split[1:]) col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2) - col_widths[1] = max(col_widths[1], - len(input_args_str) + 2, - len(headers[1]) + 2) - col_widths[2] = max(col_widths[2], - len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, - len(headers[2]) + 2) + col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2) + col_widths[2] = max( + col_widths[2], + len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, + len(headers[2]) + 2, + ) # break only if you want to measure widths from a single example; # otherwise, let it loop over all items. @@ -232,5 +237,6 @@ f"{values['BitBLAS_top20_latency']:.3f} ms", ] row_str = "".join( - [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]) + [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)] + ) print(row_str) diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py new file mode 100644 index 00000000000..35c20ee41b9 --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -0,0 +1,490 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe +kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit +activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8) +and 16-bit activations. +""" + +import nvtx +import torch +import torch.utils.benchmark as benchmark + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.scalar_type import scalar_types +from vllm.utils import FlexibleArgumentParser + +WEIGHT_SHAPES_MOE = { + "nvidia/DeepSeek-R1-FP4": [ + [256, 8, 2048, 7168], + ], +} + +DEFAULT_MODELS = [ + "nvidia/DeepSeek-R1-FP4", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) + + +def bench_run( + results: list[benchmark.Measurement], + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): + label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format( + model, num_experts, topk, per_act_token, per_out_ch, mkn + ) + ) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + _, a_fp8_scale = ops.scaled_fp8_quant(a) + + w1_fp8q = torch.empty( + (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn + ) + w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn) + w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert]) + + w1_fp8q_notransp = w1_fp8q.clone() + w2_fp8q_notransp = w2_fp8q.clone() + w1_fp8q = w1_fp8q.transpose(1, 2) + w2_fp8q = w2_fp8q.transpose(1, 2) + + score = torch.randn((m, num_experts), device=device, dtype=dtype) + + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) + + quant_blocksize = 16 + w1_blockscale = torch.empty( + (num_experts, 2 * n, k // quant_blocksize), + device=device, + dtype=torch.float8_e4m3fn, + ) + w2_blockscale = torch.empty( + (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn + ) + + # n_b_scales = 2 * n if per_out_ch else 1 + # k_b_scales = k if per_out_ch else 1 + w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8) + w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8) + + w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32) + w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32) + a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32) + a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_e = w1[expert] + w2_e = w2[expert] + w1_amax = torch.abs(w1_e).max().to(torch.float32) + w2_amax = torch.abs(w2_e).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1_e, w1_gs[expert] + ) + + w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2_e, w2_gs[expert] + ) + + def run_triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, + num_repeats: int, + ): + for _ in range(num_repeats): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + + def run_cutlass_moe_fp4( + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + w1_gs: torch.Tensor, + w2_gs: torch.Tensor, + a1_gs: torch.Tensor, + a2_gs: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + num_repeats: int, + ): + for _ in range(num_repeats): + with nvtx.annotate("cutlass_moe_fp4", color="green"): + cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + + def run_cutlass_from_graph( + a: torch.Tensor, + a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + ): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_alphas, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + + def run_triton_from_graph( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, + ): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph( + a, + w1_fp8q_notransp, + w2_fp8q_notransp, + topk_weights, + topk_ids, + w1_fp8scale, + w2_fp8scale, + a_fp8_scale, + ) + torch.cuda.synchronize() + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals = { + # Baseline params + "w1": w1, + "w2": w2, + "score": score, + "topk": topk, + "w1_fp8q_notransp": w1_fp8q_notransp, + "w2_fp8q_notransp": w2_fp8q_notransp, + "w1_fp8scale": w1_fp8scale, + "w2_fp8scale": w2_fp8scale, + "a_fp8_scale": a_fp8_scale, + # Cutlass params + "a": a, + "a1_gscale": a1_gs, + "w1_fp4": w1_fp4, + "w1_blockscale": w1_blockscale, + "w1_alphas": w1_gs, + "a2_gscale": a2_gs, + "w2_fp4": w2_fp4, + "w2_blockscale": w2_blockscale, + "w2_alphas": w2_gs, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "m": m, + "n": n, + "k": k, + "e": num_experts, + "device": device, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe_fp4": run_cutlass_moe_fp4, + "replay_graph": replay_graph, + } + + # Warmup + run_triton_moe( + a, + w1_fp8q_notransp, + w2_fp8q_notransp, + topk_weights, + topk_ids, + w1_fp8scale, + w2_fp8scale, + a_fp8_scale, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + + run_cutlass_moe_fp4( + a, + w1_fp4, + w2_fp4, + w1_blockscale, + w2_blockscale, + w1_gs, + w2_gs, + a1_gs, + a2_gs, + topk_weights, + topk_ids, + m, + n, + k, + num_experts, + device, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + bench_run( + results, + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches" + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index c92ea43e826..1be83b84e95 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import torch.utils.benchmark as benchmark @@ -6,14 +7,18 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, - fused_experts, - fused_topk) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + cutlass_moe_fp8, + fused_experts, + fused_topk, +) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = [ - "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", - "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" + "nm-testing/Mixtral-8x7B-Instruct-v0.1", + "nm-testing/deepseekv2-lite", + "ibm-granite/granite-3.0-1b-a400m", + "ibm-granite/granite-3.0-3b-a800m", ] DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] @@ -24,19 +29,27 @@ def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) -def bench_run(results: list[benchmark.Measurement], model: str, - num_experts: int, topk: int, per_act_token: bool, - per_out_ch: bool, mkn: tuple[int, int, int]): +def bench_run( + results: list[benchmark.Measurement], + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): label = "Quant Matmul" sub_label = ( - "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " - "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, - mkn)) + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format( + model, num_experts, topk, per_act_token, per_out_ch, mkn + ) + ) print(f"Testing: {sub_label}") @@ -50,35 +63,17 @@ def bench_run(results: list[benchmark.Measurement], model: str, _, a_scale = ops.scaled_fp8_quant(a) - w1_q = torch.empty((num_experts, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((num_experts, k, n), - device="cuda", - dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((num_experts, 1, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((num_experts, 1, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_experts, ), - 2 * n, - device="cuda", - dtype=torch.int64) - ab_strides2 = torch.full((num_experts, ), - n, - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) + w1_q = torch.empty( + (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn + ) + w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) + + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) for expert in range(num_experts): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) @@ -91,82 +86,120 @@ def bench_run(results: list[benchmark.Measurement], model: str, score = torch.randn((m, num_experts), device="cuda", dtype=dtype) topk_weights, topk_ids, token_expert_indices = fused_topk( - a, score, topk, renormalize=False) + a, score, topk, renormalize=False + ) - def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a_scale: torch.Tensor, num_repeats: int): + def run_triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_scale: torch.Tensor, + num_repeats: int, + ): for _ in range(num_repeats): - fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor, - num_repeats: int): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) + + def run_cutlass_moe( + a: torch.Tensor, + a_scale: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + num_repeats: int, + ): for _ in range(num_repeats): - cutlass_moe_fp8(a, - w1, - w2, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) + cutlass_moe_fp8( + a, + w1, + w2, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale, + ) def run_cutlass_from_graph( - a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + a: torch.Tensor, + a_scale: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + ): with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, w1_scale: torch.Tensor, - w2_scale: torch.Tensor, a_scale: torch.Tensor): + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return cutlass_moe_fp8( + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale, + ) + + def run_triton_from_graph( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_scale: torch.Tensor, + ): with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) def replay_graph(graph, num_repeats): for _ in range(num_repeats): @@ -176,16 +209,35 @@ def replay_graph(graph, num_repeats): cutlass_stream = torch.cuda.Stream() cutlass_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): - run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, c_strides1, - ab_strides2, c_strides2) + run_cutlass_from_graph( + a, + a_scale, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) torch.cuda.synchronize() triton_stream = torch.cuda.Stream() triton_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(triton_graph, stream=triton_stream): - run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, - topk_ids, w1_scale, w2_scale, a_scale) + run_triton_from_graph( + a, + w1_q_notransp, + w2_q_notransp, + topk_weights, + topk_ids, + w1_scale, + w2_scale, + a_scale, + ) torch.cuda.synchronize() min_run_time = 5 @@ -225,18 +277,27 @@ def replay_graph(graph, num_repeats): } # Warmup - run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, - w1_scale, w2_scale, a_scale, num_warmup) + run_triton_moe( + a, + w1_q_notransp, + w2_q_notransp, + topk_weights, + topk_ids, + w1_scale, + w2_scale, + a_scale, + num_warmup, + ) results.append( benchmark.Timer( - stmt= - "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="triton_moe", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup replay_graph(triton_graph, num_warmup) @@ -248,22 +309,35 @@ def replay_graph(graph, num_repeats): label=label, sub_label=sub_label, description="triton_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup - run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, - num_warmup) + run_cutlass_moe( + a, + a_scale, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + num_warmup, + ) results.append( benchmark.Timer( - stmt= - "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="grouped_gemm_moe", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup replay_graph(cutlass_graph, num_warmup) @@ -275,7 +349,8 @@ def replay_graph(graph, num_repeats): label=label, sub_label=sub_label, description="grouped_gemm_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) def main(args): @@ -303,8 +378,15 @@ def main(args): for per_out_ch in PER_OUT_CH_OPTS: for size_m in DEFAULT_BATCH_SIZES: mkn = (size_m, size_k, size_n) - bench_run(results, model, num_experts, topk, - per_act_token, per_out_ch, mkn) + bench_run( + results, + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) compare = benchmark.Compare(results) compare.print() @@ -312,7 +394,8 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") + description="Benchmark Marlin across specified models/shapes/batches" + ) parser.add_argument( "--models", nargs="+", @@ -320,21 +403,14 @@ def main(args): default=DEFAULT_MODELS, choices=WEIGHT_SHAPES_MOE.keys(), ) - parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) - parser.add_argument("--limit-per-act-token", - nargs="+", - type=int, - default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) args = parser.parse_args() diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index e12d74c01e4..69978ec6b23 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time @@ -10,14 +11,16 @@ @torch.inference_mode() -def main(num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, - seed: int = 0, - do_profile: bool = False, - num_warmup_iters: int = 5, - num_iters: int = 100) -> None: +def main( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: current_platform.seed_everything(seed) torch.set_default_device("cuda") @@ -56,33 +59,35 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': - parser = FlexibleArgumentParser( - description="Benchmark the layernorm kernel.") +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.") parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--add-residual", action="store_true") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument("--num-warmup-iters", type=int, default=5) - parser.add_argument("--num-iters", - type=int, - default=100, - help="Number of benchmark iterations. " - "If --profile is set, this number is ignored") + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored", + ) args = parser.parse_args() print(args) - main(num_tokens=args.num_tokens, - hidden_size=args.hidden_size, - add_residual=args.add_residual, - dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], - seed=args.seed, - do_profile=args.profile, - num_warmup_iters=args.num_warmup_iters, - num_iters=args.num_iters) + main( + num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + add_residual=args.add_residual, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + ) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index d382ede10b4..3d38d4b3534 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy @@ -20,18 +21,36 @@ from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, - lora_shrink) - from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, - _LORA_B_PTR_DICT) + from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_TP_SIZES = [1] DEFAULT_BATCH_SIZES = [ - 1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024, - 2048, 3072, 4096, 5120, 6144, 7168, 8192 + 1, + 16, + 32, + 64, + 128, + 192, + 256, + 320, + 384, + 448, + 512, + 640, + 768, + 896, + 1024, + 2048, + 3072, + 4096, + 5120, + 6144, + 7168, + 8192, ] DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384] DEFAULT_LORA_RANKS = [16] @@ -52,12 +71,9 @@ def dtype_to_str(dtype: torch.dtype): raise ValueError(f"Unsupported dtype {dtype}") -def make_rand_lora_weight_tensor(k: int, - n: int, - num_loras: int, - dtype: torch.dtype, - device: str = "cuda") -> torch.Tensor: - +def make_rand_lora_weight_tensor( + k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda" +) -> torch.Tensor: # LoRA weights column major return torch.rand((num_loras, n, k), dtype=dtype).to(device) @@ -78,18 +94,15 @@ def make_rand_tensors( A = torch.rand(a_shape, dtype=a_dtype).to(device) # LoRA weights column major - Bs = [ - torch.rand(b_shape, dtype=b_dtype).to(device) - for _ in range(num_slices) - ] + Bs = [torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices)] C = torch.zeros(c_shape, dtype=c_dtype).to(device) return A, Bs, C -def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, - sort_by_lora_id: bool, - device: str) -> torch.Tensor: +def make_prompt_lora_mapping( + num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str +) -> torch.Tensor: """ All prompts are mapped to a LoRA ID in range [0, num_active_loras). where 0 refers to first lora, 1 refers to second lora and so on. @@ -97,9 +110,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, assert num_active_loras > 0 if not sort_by_lora_id: - return torch.randint(0, - num_active_loras, (num_prompts, ), - dtype=torch.long) + return torch.randint(0, num_active_loras, (num_prompts,), dtype=torch.long) # Divide LoRAs equally and in order. part_size = num_prompts // num_active_loras @@ -110,14 +121,18 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, while len(prompt_lora_mapping) < num_prompts: prompt_lora_mapping.extend([lora_id] * part_size) lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id - return torch.tensor(prompt_lora_mapping[:num_prompts], - dtype=torch.long, - device=device) - - -def make_token_lora_mapping(num_tokens: int, num_prompts: int, - prompt_lora_mapping: torch.Tensor, - seq_len_tensor: torch.Tensor, device: str): + return torch.tensor( + prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device + ) + + +def make_token_lora_mapping( + num_tokens: int, + num_prompts: int, + prompt_lora_mapping: torch.Tensor, + seq_len_tensor: torch.Tensor, + device: str, +): """ Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor """ @@ -136,11 +151,15 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int, return torch.tensor(token_lora_mapping, dtype=torch.long, device=device) -def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, - lora_weights: list[torch.Tensor], - seq_lens_cpu: torch.Tensor, - prompt_lora_mapping_cpu: torch.Tensor, scaling: float, - add_inputs: Optional[bool]): +def ref_group_gemm( + ref_out: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + seq_lens_cpu: torch.Tensor, + prompt_lora_mapping_cpu: torch.Tensor, + scaling: float, + add_inputs: Optional[bool], +): """ Torch group gemm reference implementation to test correctness of benchmarking operations. @@ -149,7 +168,7 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, out_list = [] current_offset = 0 for lora_index, b_length in zip(range(batches), seq_lens_cpu): - x = input[current_offset:b_length + current_offset, :] + x = input[current_offset : b_length + current_offset, :] current_offset += b_length w = lora_weights[prompt_lora_mapping_cpu[lora_index]] result = torch.nn.functional.linear(x, w) @@ -168,6 +187,7 @@ class OpType(Enum): """ LoRA Ops to benchmark and its properties. """ + LORA_SHRINK = auto() LORA_EXPAND = auto() @@ -188,8 +208,9 @@ def is_expand_fn(self) -> bool: def num_slices(self) -> list[int]: return [1, 2, 3] - def mkn(self, batch_size: int, seq_length: int, hidden_size: int, - lora_rank: int) -> tuple[int, int, int]: + def mkn( + self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int + ) -> tuple[int, int, int]: num_tokens = batch_size * seq_length if self.is_shrink_fn(): m = num_tokens @@ -203,7 +224,7 @@ def mkn(self, batch_size: int, seq_length: int, hidden_size: int, return m, k, n def matmul_dtypes( - self, op_dtype: torch.dtype + self, op_dtype: torch.dtype ) -> tuple[torch.dtype, torch.dtype, torch.dtype]: """ return a type, b type and c type for A x B = C @@ -215,9 +236,14 @@ def matmul_dtypes( return torch.float32, op_dtype, op_dtype def matmul_shapes( - self, batch_size: int, seq_length: int, hidden_size: int, - lora_rank: int, num_loras: int, - num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]: + self, + batch_size: int, + seq_length: int, + hidden_size: int, + lora_rank: int, + num_loras: int, + num_slices: int, + ) -> tuple[tuple[int], tuple[int], tuple[int]]: """ Given num_slices, return the shapes of the A, B, and C matrices in A x B = C, for the op_type @@ -241,31 +267,38 @@ def bench_fn(self) -> Callable: raise ValueError(f"Unrecognized optype {self}") - def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, - lora_weights: list[torch.Tensor], - **kwargs) -> Callable: + def run_ref_group_gemm( + self, + output: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + **kwargs, + ) -> Callable: """Each benchmark operation expects the input, lora_weights and outputs - in a slightly different format. Refer to self.matmul_shapes(). - run_ref_group_gemm accounts for those differences in executing a - reference group gemm for correctness testing. + in a slightly different format. Refer to self.matmul_shapes(). + run_ref_group_gemm accounts for those differences in executing a + reference group gemm for correctness testing. """ w_dtype = lora_weights[0].dtype num_slices = len(lora_weights) if self in [OpType.LORA_SHRINK]: for slice_idx in range(num_slices): - ref_group_gemm(ref_out=output[slice_idx, :], - input=input, - lora_weights=lora_weights[slice_idx], - **kwargs) + ref_group_gemm( + ref_out=output[slice_idx, :], + input=input, + lora_weights=lora_weights[slice_idx], + **kwargs, + ) elif self in [OpType.LORA_EXPAND]: hidden_size = lora_weights[0].shape[1] for slice_idx in range(num_slices): slice_offset = slice_idx * hidden_size ref_group_gemm( - ref_out=output[:, slice_offset:slice_offset + hidden_size], + ref_out=output[:, slice_offset : slice_offset + hidden_size], input=input[slice_idx].clone().to(dtype=w_dtype), lora_weights=lora_weights[slice_idx], - **kwargs) + **kwargs, + ) else: raise ValueError(f"Unrecognized optype {self}") @@ -275,6 +308,7 @@ class BenchmarkContext: """ LoRA benchmark context """ + batch_size: int hidden_size: int num_loras: int @@ -299,17 +333,18 @@ def bench_label(self) -> str: return f"lora-{self.dtype}" def bench_sublabel(self, op_type: OpType) -> str: - m, k, n = op_type.mkn(self.batch_size, self.seq_length, - self.hidden_size, self.lora_rank) + m, k, n = op_type.mkn( + self.batch_size, self.seq_length, self.hidden_size, self.lora_rank + ) desc = { - 'bs': self.batch_size, - 'sl': self.seq_length, - 'm': m, - 'k': k, - 'n': n, - 'num_loras': self.num_loras, - 'sort_by_lora': self.sort_by_lora_id, - 'num_slices': self.num_slices, + "bs": self.batch_size, + "sl": self.seq_length, + "m": m, + "k": k, + "n": n, + "num_loras": self.num_loras, + "sort_by_lora": self.sort_by_lora_id, + "num_slices": self.num_slices, } return json.dumps(desc) @@ -319,6 +354,7 @@ class BenchmarkTensors: """ Input/Output tensors used for benchmarks """ + # matmul tensors input: torch.Tensor lora_weights_lst: list[torch.Tensor] @@ -330,23 +366,29 @@ class BenchmarkTensors: prompt_lora_mapping: torch.Tensor def io_types(self) -> str: - return (f"{dtype_to_str(self.input.dtype)}x" - f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" - f"{dtype_to_str(self.output.dtype)}") + return ( + f"{dtype_to_str(self.input.dtype)}x" + f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" + f"{dtype_to_str(self.output.dtype)}" + ) @staticmethod - def make(ctx: BenchmarkContext, - op_type: OpType, - device: str = "cuda") -> "BenchmarkTensors": - + def make( + ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" + ) -> "BenchmarkTensors": # Make input / output matmul tensors. a_shape, b_shape, c_shape = op_type.matmul_shapes( - ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank, - ctx.num_loras, ctx.num_slices) + ctx.batch_size, + ctx.seq_length, + ctx.hidden_size, + ctx.lora_rank, + ctx.num_loras, + ctx.num_slices, + ) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) - input_tensor, lora_weights, output_tensor = \ - make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type, c_type, - num_slices = ctx.num_slices) + input_tensor, lora_weights, output_tensor = make_rand_tensors( + a_shape, b_shape, c_shape, a_type, b_type, c_type, num_slices=ctx.num_slices + ) # Make metadata tensors. # Keep the metadata tensors in the CPU for further processing if needed. @@ -356,27 +398,38 @@ def make(ctx: BenchmarkContext, # Make metadata tensors involved in correctness testing. # Prepare seq lens tensor - seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1, - (ctx.batch_size, )) + seq_len_tensor = torch.randint( + ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,) + ) assert total_tokens == seq_len_tensor.sum() # Prepare prompt lora indices tensor prompt_lora_indices_tensor = make_prompt_lora_mapping( - ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu") + ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu" + ) # Make LoRAKernelMeta token_lora_indices_tensor = make_token_lora_mapping( - total_tokens, ctx.batch_size, prompt_lora_indices_tensor, - seq_len_tensor, "cpu") + total_tokens, + ctx.batch_size, + prompt_lora_indices_tensor, + seq_len_tensor, + "cpu", + ) lora_kernel_meta = LoRAKernelMeta.make( max_loras=ctx.num_loras, max_num_tokens=token_lora_indices_tensor.size(0), - device="cpu") - lora_kernel_meta.prepare_tensors( - token_lora_mapping=token_lora_indices_tensor) - - return BenchmarkTensors(input_tensor, lora_weights, output_tensor, - lora_kernel_meta, seq_len_tensor, - prompt_lora_indices_tensor) + device="cpu", + ) + lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor) + + return BenchmarkTensors( + input_tensor, + lora_weights, + output_tensor, + lora_kernel_meta, + seq_len_tensor, + prompt_lora_indices_tensor, + ) def sanity_check(self) -> None: """ @@ -386,7 +439,7 @@ def sanity_check(self) -> None: # check metadata tensors assert torch.sum(self.seq_lens) == num_tokens num_seqs = self.seq_lens.shape[0] - #assert self.seq_start_loc.shape[0] == num_seqs + # assert self.seq_start_loc.shape[0] == num_seqs assert self.prompt_lora_mapping.shape[0] == num_seqs assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens @@ -430,8 +483,11 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]: _, num_tokens, _, num_slices = self.metadata() # Sanity check matrix shapes. - i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) # Expected input shape [num_tokens, hidden_size] assert len(i_shape) == 2 assert i_shape[0] == num_tokens @@ -445,16 +501,17 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]: assert o_shape == (num_slices, num_tokens, lora_rank) return { - 'inputs': self.input, - 'lora_a_weights': self.lora_weights_lst, - 'output_tensor': self.output, - 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping, - 'token_indices_sorted_by_lora_ids': - self.lora_kernel_meta.token_indices_sorted_by_lora_ids, - 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora, - 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc, - 'lora_ids': self.lora_kernel_meta.active_lora_ids, - 'scaling': 1.0, + "inputs": self.input, + "lora_a_weights": self.lora_weights_lst, + "output_tensor": self.output, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "token_indices_sorted_by_lora_ids": ( + self.lora_kernel_meta.token_indices_sorted_by_lora_ids + ), + "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora, + "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, + "lora_ids": self.lora_kernel_meta.active_lora_ids, + "scaling": 1.0, } def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: @@ -464,8 +521,11 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: _, num_tokens, _, num_slices = self.metadata() # Sanity check matrix shapes. - i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) # Expected input shape : [num_slices, num_tokens, lora_rank] assert len(i_shape) == 3 assert i_shape[0] == num_slices @@ -480,22 +540,23 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: assert o_shape == (num_tokens, hidden_size * num_slices) return { - 'inputs': self.input, - 'lora_b_weights': self.lora_weights_lst, - 'output_tensor': self.output, - 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping, - 'token_indices_sorted_by_lora_ids': - self.lora_kernel_meta.token_indices_sorted_by_lora_ids, - 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora, - 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc, - 'lora_ids': self.lora_kernel_meta.active_lora_ids, - 'offset_start': 0, - 'add_inputs': add_inputs, + "inputs": self.input, + "lora_b_weights": self.lora_weights_lst, + "output_tensor": self.output, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "token_indices_sorted_by_lora_ids": ( + self.lora_kernel_meta.token_indices_sorted_by_lora_ids + ), + "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora, + "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, + "lora_ids": self.lora_kernel_meta.active_lora_ids, + "offset_start": 0, + "add_inputs": add_inputs, } - def bench_fn_kwargs(self, - op_type: OpType, - add_inputs: Optional[bool] = None) -> dict[str, Any]: + def bench_fn_kwargs( + self, op_type: OpType, add_inputs: Optional[bool] = None + ) -> dict[str, Any]: if op_type.is_shrink_fn(): assert add_inputs is None else: @@ -507,8 +568,9 @@ def bench_fn_kwargs(self, return self.as_lora_expand_kwargs(add_inputs) raise ValueError(f"Unrecognized optype {self}") - def test_correctness(self, op_type: OpType, - expand_fn_add_inputs: Optional[bool]) -> bool: + def test_correctness( + self, op_type: OpType, expand_fn_add_inputs: Optional[bool] + ) -> bool: """ Test correctness of op_type implementation against a grouped gemm reference implementation. @@ -518,8 +580,7 @@ def test_correctness(self, op_type: OpType, ref_output = self.output.clone() self.output.zero_() - op_type.bench_fn()( - **self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) + op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) op_type.run_ref_group_gemm( ref_output, @@ -528,7 +589,8 @@ def test_correctness(self, op_type: OpType, seq_lens_cpu=seq_lens_cpu, prompt_lora_mapping_cpu=prompt_lora_mapping_cpu, scaling=1.0, - add_inputs=expand_fn_add_inputs) + add_inputs=expand_fn_add_inputs, + ) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -539,13 +601,14 @@ def test_correctness(self, op_type: OpType, return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol) -def bench_optype(ctx: BenchmarkContext, - arg_pool_size: int, - op_type: OpType, - cuda_graph_nops: Optional[int] = None, - expand_fn_add_inputs: Optional[bool] = None, - test_correctness: bool = False) -> TMeasurement: - +def bench_optype( + ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + cuda_graph_nops: Optional[int] = None, + expand_fn_add_inputs: Optional[bool] = None, + test_correctness: bool = False, +) -> TMeasurement: assert arg_pool_size >= 1 if op_type.is_shrink_fn(): assert expand_fn_add_inputs is None @@ -553,17 +616,17 @@ def bench_optype(ctx: BenchmarkContext, assert expand_fn_add_inputs is not None # BenchmarkContext -> BenchmarkTensors - bench_tensors : list[BenchmarkTensors] = \ - [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)] + bench_tensors: list[BenchmarkTensors] = [ + BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) + ] for bt in bench_tensors: bt.sanity_check() # Test correctness of our implementation. if test_correctness: - assert all([ - bt.test_correctness(op_type, expand_fn_add_inputs) - for bt in bench_tensors - ]) + assert all( + [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] + ) # BenchmarkTensors -> dict (kwargs) kwargs_list = [ @@ -585,40 +648,49 @@ def bench_optype(ctx: BenchmarkContext, for k, v in _kwargs.items(): kwargs[k].values.append(v) - describe_args = (f"add_inputs={expand_fn_add_inputs}" - if expand_fn_add_inputs is not None else "") - description = ( - f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})") + describe_args = ( + f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else "" + ) + description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})" cuda_graph_params = None if cuda_graph_nops: cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) timer = None - with Bench(cuda_graph_params, - ctx.bench_label(), ctx.bench_sublabel(op_type), description, - op_type.bench_fn(), **kwargs) as bench: + with Bench( + cuda_graph_params, + ctx.bench_label(), + ctx.bench_sublabel(op_type), + description, + op_type.bench_fn(), + **kwargs, + ) as bench: timer = bench.run() return timer -def bench_torch_mm(ctx: BenchmarkContext, - arg_pool_size: int, - op_type: OpType, - cuda_graph_nops: Optional[int] = None) -> TMeasurement: +def bench_torch_mm( + ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + cuda_graph_nops: Optional[int] = None, +) -> TMeasurement: """ Benchmark basic torch.mm as a roofline. When all the input tokens have the same LoRA ID, the LoRA kernels are just - a matmul. This torch.mm benchmark serves as a roofline for that case. + a matmul. This torch.mm benchmark serves as a roofline for that case. input op_type is used in determining the m, k, n dimensions for the matmul. """ - batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size, - ctx.hidden_size, - ctx.lora_rank, - ctx.seq_length, - ctx.dtype) + batch_size, hidden_size, lora_rank, seq_length, dtype = ( + ctx.batch_size, + ctx.hidden_size, + ctx.lora_rank, + ctx.seq_length, + ctx.dtype, + ) m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank) # For a fairer comparison. @@ -632,18 +704,24 @@ def bench_torch_mm(ctx: BenchmarkContext, Cs.append(torch.rand((m, n), dtype=dtype).to("cuda")) # Make torch.mm kwargs - mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)} + mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)} description = ( f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}" f"x{dtype_to_str(dtype)}" - f"=>{dtype_to_str(dtype)})") + f"=>{dtype_to_str(dtype)})" + ) cuda_graph_params = None if cuda_graph_nops: cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) - with Bench(cuda_graph_params, ctx.bench_label(), - ctx.bench_sublabel(op_type), description, torch.mm, - **mm_kwargs) as bench: + with Bench( + cuda_graph_params, + ctx.bench_label(), + ctx.bench_sublabel(op_type), + description, + torch.mm, + **mm_kwargs, + ) as bench: return bench.run() @@ -660,8 +738,7 @@ def use_cuda_graph_recommendation() -> str: """ -def print_timers(timers: list[TMeasurement], - args: Optional[argparse.Namespace] = None): +def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None): compare = TBenchmark.Compare(timers) compare.print() @@ -670,22 +747,23 @@ def print_timers(timers: list[TMeasurement], f"Note : The timings reported above is for {args.cuda_graph_nops} " "consecutive invocations of the benchmarking functions. " f"Please divide by {args.cuda_graph_nops} for single invocation " - "timings.") + "timings." + ) - print("Note on Comparison with torch.mm : The torch.mm numbers are " - "benchmark numbers of a simple matmul emulating the single lora " - "case. It is provided as a roofline for comparing our LoRA Kernel " - "implementations. It is expected that the LoRA kernels will be " - "slower than torch.mm in cases where num_loras is big. But for " - "small num_loras the goal should be to match the torch.mm numbers.") + print( + "Note on Comparison with torch.mm : The torch.mm numbers are " + "benchmark numbers of a simple matmul emulating the single lora " + "case. It is provided as a roofline for comparing our LoRA Kernel " + "implementations. It is expected that the LoRA kernels will be " + "slower than torch.mm in cases where num_loras is big. But for " + "small num_loras the goal should be to match the torch.mm numbers." + ) def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): - if args.cuda_graph_nops is not None: assert args.cuda_graph_nops > 0 - print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA " - "Graph") + print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph") else: print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}") @@ -697,21 +775,30 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): for bench_op in bench_ops: for num_slices in bench_op.num_slices(): _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( - num_slices) + num_slices + ) # Benchmark torch.mm as a roofline seq_len_timers.append( - bench_torch_mm(_ctx, args.arg_pool_size, bench_op, - args.cuda_graph_nops)) + bench_torch_mm( + _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops + ) + ) # Benchmark bench_op - expand_fn_add_inputs = [ - None - ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + expand_fn_add_inputs = ( + [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + ) for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( - bench_optype(_ctx, args.arg_pool_size, bench_op, - args.cuda_graph_nops, add_input_arg, - args.test_correctness)) + bench_optype( + _ctx, + args.arg_pool_size, + bench_op, + args.cuda_graph_nops, + add_input_arg, + args.test_correctness, + ) + ) print_timers(seq_len_timers) timers.extend(seq_len_timers) @@ -733,13 +820,17 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): pickle.dump(timers, f) -def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], - args: argparse.Namespace) -> list[BenchmarkContext]: - +def as_benchmark_contexts( + hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace +) -> list[BenchmarkContext]: ctxs: list[BenchmarkContext] = [] for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa - args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, - args.sort_by_lora_id): + args.batch_sizes, + list(hidden_sizes), + lora_ranks, + args.num_loras, + args.sort_by_lora_id, + ): ctxs.append( BenchmarkContext( batch_size=batch_size, @@ -747,13 +838,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], lora_rank=lora_rank, num_loras=num_loras, num_active_loras=args.num_active_loras - if args.num_active_loras else num_loras, + if args.num_active_loras + else num_loras, # To be filled based on the OpType to benchmark seq_length=None, sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, # To be filled based on the OpType to benchmark - num_slices=None)) + num_slices=None, + ) + ) return ctxs @@ -761,13 +855,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], def run_list_bench(args: argparse.Namespace): print(args) - print("List bench :\n" - f" Hidden Sizes {args.hidden_sizes}" - f" LoRA Ranks {args.lora_ranks}") + print( + "List bench :\n" + f" Hidden Sizes {args.hidden_sizes}" + f" LoRA Ranks {args.lora_ranks}" + ) # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args) + hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args + ) run(args, bench_contexts) @@ -776,19 +873,22 @@ def run_range_bench(args: argparse.Namespace): print(args) hidden_sizes = list( - range(args.hidden_sizes_start, args.hidden_sizes_end + 1, - args.hidden_sizes_increment)) + range( + args.hidden_sizes_start, + args.hidden_sizes_end + 1, + args.hidden_sizes_increment, + ) + ) lora_ranks = list( - range(args.lora_ranks_start, args.lora_ranks_end + 1, - args.lora_ranks_increment)) + range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment) + ) - print("Range bench :\n" - f" Hidden Sizes {hidden_sizes}" - f" LoRA Ranks {lora_ranks}") + print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}") # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args) + hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args + ) run(args, bench_contexts) @@ -806,21 +906,19 @@ def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]: # Get all hidden sizes hidden_sizes: set[int] = set() for model_name, tp_size in product(args.models, args.tp_sizes): - hidden_sizes = hidden_sizes.union( - hidden_sizes_from_model(model_name, tp_size)) + hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size)) - print("Model bench :\n" - f" Hidden Sizes {hidden_sizes}" - f" LoRA Ranks {args.lora_ranks}") + print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}") # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args) + hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args + ) run(args, bench_contexts) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "torch.float16": @@ -830,14 +928,15 @@ def to_torch_dtype(dt): raise ValueError("unsupported dtype") def get_bool(s: str) -> bool: - return s.lower() in ['true', '1'] + return s.lower() in ["true", "1"] def add_common_command_args(p: argparse.ArgumentParser): p.add_argument( "--dtype", type=to_torch_dtype, required=True, - help="Available options are ['torch.float16', 'torch.bfloat16']") + help="Available options are ['torch.float16', 'torch.bfloat16']", + ) p.add_argument( "--arg-pool-size", @@ -845,56 +944,66 @@ def add_common_command_args(p: argparse.ArgumentParser): default=32, help="Run profiles with a pool of input/output/meta tensors instead" "of simply reusing the same tensors for all runs. A bigger arg-pool" - "mitigates hardware caching effects during benchmarking.") + "mitigates hardware caching effects during benchmarking.", + ) p.add_argument( "--cuda-graph-nops", type=int, - help=("when set profiling is done using cudagraph, " - "with the given number of operations in a graph." - "Note that the measurement returned is the time " - "taken for N consecutive executions of the benchmarking " - "functions, where N is the value of this argument.")) - p.add_argument("--num-loras", - nargs="+", - type=int, - default=DEFAULT_NUM_LORAS) - p.add_argument("--num-active-loras", - type=int, - default=None, - help="Active LoRAs. When None, all LoRAs are active") - p.add_argument("--sort-by-lora-id", - nargs="+", - type=get_bool, - default=DEFAULT_SORT_BY_LORA_IDS) - p.add_argument("--op-types", - nargs="+", - type=OpType.from_str, - default=list(OpType)) - p.add_argument('--seq-lengths', - nargs="+", - type=int, - default=DEFAULT_SEQ_LENGTHS) - p.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) - p.add_argument("--expand-fn-add-inputs", - nargs="+", - type=get_bool, - default=DEFAULT_EXPAND_FN_ADD_INPUTS) + help=( + "when set profiling is done using cudagraph, " + "with the given number of operations in a graph." + "Note that the measurement returned is the time " + "taken for N consecutive executions of the benchmarking " + "functions, where N is the value of this argument." + ), + ) + p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS) + p.add_argument( + "--num-active-loras", + type=int, + default=None, + help="Active LoRAs. When None, all LoRAs are active", + ) + p.add_argument( + "--sort-by-lora-id", + nargs="+", + type=get_bool, + default=DEFAULT_SORT_BY_LORA_IDS, + ) + p.add_argument( + "--op-types", nargs="+", type=OpType.from_str, default=list(OpType) + ) + p.add_argument( + "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS + ) + p.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + p.add_argument( + "--expand-fn-add-inputs", + nargs="+", + type=get_bool, + default=DEFAULT_EXPAND_FN_ADD_INPUTS, + ) p.add_argument( - '-o', - '--output-directory', + "-o", + "--output-directory", type=str, - help=("Output directory to store a the list of benchmarking" - "TMeasurement objects as a pickle file")) + help=( + "Output directory to store a the list of benchmarking" + "TMeasurement objects as a pickle file" + ), + ) p.add_argument( "--test-correctness", - action='store_true', - help=("When enabled, the benchmarking functions are tested" - "for correctness before the actual benchmarking")) + action="store_true", + help=( + "When enabled, the benchmarking functions are tested" + "for correctness before the actual benchmarking" + ), + ) parser = FlexibleArgumentParser( description=f""" @@ -910,50 +1019,45 @@ def add_common_command_args(p: argparse.ArgumentParser): range_bench example: python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, + ) subparsers = parser.add_subparsers(dest="cmd", required=True) list_parser = subparsers.add_parser("list_bench") - list_parser.add_argument("--hidden-sizes", - nargs="+", - type=int, - default=DEFAULT_HIDDEN_SIZES) - list_parser.add_argument("--lora-ranks", - nargs="+", - type=int, - default=DEFAULT_LORA_RANKS) + list_parser.add_argument( + "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES + ) + list_parser.add_argument( + "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS + ) add_common_command_args(list_parser) list_parser.set_defaults(func=run_list_bench) range_parser = subparsers.add_parser("range_bench") range_parser.add_argument("--hidden-sizes-start", type=int, required=True) range_parser.add_argument("--hidden-sizes-end", type=int, required=True) - range_parser.add_argument("--hidden-sizes-increment", - type=int, - required=True) + range_parser.add_argument("--hidden-sizes-increment", type=int, required=True) range_parser.add_argument("--lora-ranks-start", type=int, required=True) range_parser.add_argument("--lora-ranks-end", type=int, required=True) - range_parser.add_argument("--lora-ranks-increment", - type=int, - required=True) + range_parser.add_argument("--lora-ranks-increment", type=int, required=True) add_common_command_args(range_parser) range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--lora-ranks", - nargs="+", - type=int, - default=DEFAULT_LORA_RANKS) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS + ) add_common_command_args(model_parser) model_parser.set_defaults(func=run_model_bench) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index a661ea9d7e6..0f896f187ec 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy @@ -20,12 +21,18 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, - marlin_zero_points) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + marlin_permute_scales, + marlin_zero_points, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace) + MarlinWorkspace, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) + pack_rows, + quantize_weights, +) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser @@ -82,12 +89,14 @@ def rand_data(shape, dtype=torch.float16, scale=1): return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") -def quantize_and_pack(atype: torch.dtype, - w: torch.Tensor, - wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], - zero_points: bool = False): +def quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False, +): assert wtype.is_integer(), "TODO: support floating point weights" w_ref, w_q, w_s, w_zp = quantize_weights( @@ -96,21 +105,24 @@ def quantize_and_pack(atype: torch.dtype, group_size=group_size, zero_points=zero_points, # to match how the kernel applies zps - ref_zero_points_after_scales=True) + ref_zero_points_after_scales=True, + ) w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) return w_ref, w_q, w_s, w_zp -def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, - group_size: Optional[int]) -> list[BenchmarkTensors]: +def create_bench_tensors( + shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] +) -> list[BenchmarkTensors]: m, n, k = shape # we want to make sure that weights don't fit into L2 cache between runs so # we construct enough weights to exceed L2 cache, which is 50mb on a H100 # so we target total weight size > 2*50mb - num_weights = math.ceil(2 * 50 * 1024**2 * 8 / - (k * n * types.weight_type.size_bits)) + num_weights = math.ceil( + 2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits) + ) a = rand_data((m, k), types.act_type, scale=5) @@ -124,8 +136,13 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, w = w.to(torch.float16) w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( - a.dtype, w, types.weight_type, types.group_scale_type, group_size, - types.group_zero_type is not None) + a.dtype, + w, + types.weight_type, + types.group_scale_type, + group_size, + types.group_zero_type is not None, + ) if not a.dtype.is_floating_point: aiinfo = torch.iinfo(a.dtype) @@ -133,21 +150,30 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, w_ref = w_ref.to(torch.float32) - w_ch_s = None if types.channel_scale_type is None else\ - rand_data((n,), types.channel_scale_type) - w_tok_s = None if types.token_scale_type is None else\ - rand_data((m,), types.token_scale_type) + w_ch_s = ( + None + if types.channel_scale_type is None + else rand_data((n,), types.channel_scale_type) + ) + w_tok_s = ( + None + if types.token_scale_type is None + else rand_data((m,), types.token_scale_type) + ) benchmark_tensors.append( - BenchmarkTensors(w_ref=w_ref, - a=a, - w_q=w_q_packed, - wtype=types.weight_type, - w_g_s=w_s, - w_g_zp=w_zp, - group_size=group_size, - w_ch_s=w_ch_s, - w_tok_s=w_tok_s)) + BenchmarkTensors( + w_ref=w_ref, + a=a, + w_q=w_q_packed, + wtype=types.weight_type, + w_g_s=w_s, + w_g_zp=w_zp, + group_size=group_size, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) + ) return benchmark_tensors @@ -170,50 +196,57 @@ def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() return lambda: ops.cutlass_scaled_mm( - bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) + bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16 + ) def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: device = bt.a.device - workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = MarlinWorkspace( + bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ) if bt.w_g_zp is None: w_zp = torch.empty(0, dtype=torch.int, device=device) else: - w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.wtype.size_bits) + w_zp = marlin_zero_points( + bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits + ) if bt.group_size is None: w_s = torch.tensor([], device="cuda", dtype=torch.half) else: - w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.group_size) + w_s = marlin_permute_scales( + bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size + ) sort_indices = torch.empty(0, dtype=torch.int, device=device) g_idx = torch.empty(0, dtype=torch.int, device=device) - w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.wtype.size_bits) + w_q = ops.gptq_marlin_repack( + bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits + ) if bt.a.dtype.is_floating_point: assert bt.w_ch_s is None assert bt.w_tok_s is None assert bt.group_size is not None - fn = lambda: ops.gptq_marlin_gemm(a=bt.a, - b_q_weight=w_q, - b_scales=w_s, - b_zeros=w_zp, - g_idx=g_idx, - perm=sort_indices, - workspace=workspace.scratch, - b_q_type=bt.wtype, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0], - is_k_full=True, - is_zp_float=False) + fn = lambda: ops.gptq_marlin_gemm( + a=bt.a, + b_q_weight=w_q, + b_scales=w_s, + b_zeros=w_zp, + g_idx=g_idx, + perm=sort_indices, + workspace=workspace.scratch, + b_q_type=bt.wtype, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + is_k_full=True, + is_zp_float=False, + ) else: assert bt.a.dtype == torch.int8 assert bt.wtype == scalar_types.uint4b8 @@ -221,36 +254,35 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: if bt.w_ch_s is not None: s_ch = bt.w_ch_s.to(torch.float32) else: - s_ch = torch.ones(bt.w_ref.shape[1], - dtype=torch.float32, - device=device) + s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) if bt.w_tok_s is not None: s_tok = bt.w_tok_s.to(torch.float32) else: - s_tok = torch.ones(bt.a.shape[0], - dtype=torch.float32, - device=device) - - fn = lambda: ops.marlin_qqq_gemm(a=bt.a, - b_q_weight=w_q, - s_group=w_s, - s_tok=s_tok, - s_ch=s_ch, - workspace=workspace.scratch, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0]) + s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) + + fn = lambda: ops.marlin_qqq_gemm( + a=bt.a, + b_q_weight=w_q, + s_group=w_s, + s_tok=s_tok, + s_ch=s_ch, + workspace=workspace.scratch, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + ) return fn -def machete_create_bench_fn(bt: BenchmarkTensors, - out_type=torch.dtype, - schedule=None) -> Callable: +def machete_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: w_q = bt.w_q.t().contiguous().t() # make col major - w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, - None if bt.w_g_s is None else bt.w_g_s.dtype) + w_q = ops.machete_prepack_B( + w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype + ) w_g_zp = bt.w_g_zp if w_g_zp is not None: @@ -275,26 +307,24 @@ def machete_create_bench_fn(bt: BenchmarkTensors, # bench -def bench_fns(label: str, sub_label: str, description: str, - fns: list[Callable]): - +def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]): min_run_time = 1 if not NVTX_PROFILE else 0.1 res = TBenchmark.Timer( stmt=""" for fn in fns: fn() """, - globals={ - "fns": fns - }, + globals={"fns": fns}, label=label, sub_label=sub_label, description=description, ).blocked_autorange(min_run_time=min_run_time) if NVTX_PROFILE: - with nvtx.annotate("mm-bench"), nvtx.annotate( - f"{label}|{sub_label}|{description}"): + with ( + nvtx.annotate("mm-bench"), + nvtx.annotate(f"{label}|{sub_label}|{description}"), + ): fns[0]() return res @@ -304,19 +334,20 @@ def bench_fns(label: str, sub_label: str, description: str, _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None -def bench(types: TypeConfig, - group_size: int, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - sweep_schedules: bool = True) -> list[TMeasurement]: +def bench( + types: TypeConfig, + group_size: int, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + sweep_schedules: bool = True, +) -> list[TMeasurement]: benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) sub_label += f", L={len(benchmark_tensors)}" - name_type_string = f"W{types.weight_type}"+\ - f"-A{terse_type_name(types.act_type)}" + name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}" if types.group_scale_type is not None: name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" if types.group_zero_type is not None: @@ -332,31 +363,45 @@ def bench(types: TypeConfig, # pytorch impl timers.append( bench_fns( - label, sub_label, "torch.matmul (fp16)", - [torch_matmul_f16_create_bench_fn(bt) - for bt in benchmark_tensors])) + label, + sub_label, + "torch.matmul (fp16)", + [torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: timers.append( bench_fns( - label, sub_label, - f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ - cutlass_scaled_mm_create_bench_fn(bt) - for bt in benchmark_tensors - ])) + label, + sub_label, + f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", + [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) if types.act_type != torch.float8_e4m3fn: timers.append( - bench_fns(label, sub_label, f"marlin ({name_type_string})", - [marlin_create_bench_fn(bt) - for bt in benchmark_tensors])) + bench_fns( + label, + sub_label, + f"marlin ({name_type_string})", + [marlin_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) # machete timers.append( - bench_fns(label, sub_label, f"machete ({name_type_string})", [ - machete_create_bench_fn(bt, out_type=types.output_type) - for bt in benchmark_tensors - ])) + bench_fns( + label, + sub_label, + f"machete ({name_type_string})", + [ + machete_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS @@ -371,7 +416,8 @@ def bench(types: TypeConfig, group_zeros_type=types.group_zero_type, token_scales_type=types.token_scale_type, channel_scales_type=types.channel_scale_type, - out_type=types.output_type) + out_type=types.output_type, + ) if schedules is None or len(schedules) == 0: raise ValueError("No schedules found to sweep") @@ -383,11 +429,17 @@ def bench(types: TypeConfig, if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: continue - res = bench_fns(label, sub_label, "machete_best", [ - machete_create_bench_fn( - bt, out_type=types.output_type, schedule=schedule) - for bt in benchmark_tensors - ]) + res = bench_fns( + label, + sub_label, + "machete_best", + [ + machete_create_bench_fn( + bt, out_type=types.output_type, schedule=schedule + ) + for bt in benchmark_tensors + ], + ) results_row = { "M": m, @@ -398,10 +450,8 @@ def bench(types: TypeConfig, "median": res.median, } if _SWEEP_SCHEDULES_RESULTS is None: - _SWEEP_SCHEDULES_RESULTS = pd.DataFrame( - columns=results_row.keys()) - _SWEEP_SCHEDULES_RESULTS.\ - loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row + _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys()) + _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row print(f" {res.median:5.5} ", schedule) if not best or res.median < best.median: @@ -422,8 +472,9 @@ def print_timers(timers: list[TMeasurement]): def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: types = TypeConfig( act_type=args.act_type, - weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ - else scalar_types.uint4, + weight_type=scalar_types.uint4b8 + if args.group_zero_type is None + else scalar_types.uint4, output_type=args.out_type, group_scale_type=args.group_scale_type, group_zero_type=args.group_zero_type, @@ -433,14 +484,16 @@ def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: results: list[TMeasurement] = [] for m, k, n in MKNs: - timers = bench(types, - args.group_size, - m, - k, - n, - f"{args.act_type}-gemm", - f"MKN=({m}x{k}x{n})", - sweep_schedules=args.sweep_schedules) + timers = bench( + types, + args.group_size, + m, + k, + n, + f"{args.act_type}-gemm", + f"MKN=({m}x{k}x{n})", + sweep_schedules=args.sweep_schedules, + ) print_timers(timers) results.extend(timers) @@ -454,7 +507,6 @@ def make_output( base_description: str, timestamp=None, ): - print(f"== All Results {base_description} ====") print_timers(data) @@ -468,8 +520,7 @@ def make_output( def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, args.sweep_schedules, MKNs) @@ -479,8 +530,9 @@ def run_square_bench(args): def run_range_bench(args): m_start, k_start, n_start = (int(x) for x in args.dim_start.split(",")) m_end, k_end, n_end = (int(x) for x in args.dim_end.split(",")) - m_increment, k_increment, n_increment = \ - (int(x) for x in args.dim_increment.split(",")) + m_increment, k_increment, n_increment = ( + int(x) for x in args.dim_increment.split(",") + ) Ms = list(range(m_start, m_end + 1, m_increment)) Ks = list(range(k_start, k_end + 1, k_increment)) Ns = list(range(n_start, n_end + 1, n_increment)) @@ -492,7 +544,6 @@ def run_range_bench(args): def run_model_bench(args): - print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") @@ -535,10 +586,13 @@ def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]: with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: args_dict = vars(args) args_dict.pop("func") - pkl.dump({ - "args": args_dict, - "results": all_results, - }, f) + pkl.dump( + { + "args": args_dict, + "results": all_results, + }, + f, + ) if __name__ == "__main__": @@ -554,7 +608,6 @@ def to_torch_dtype(dt): }[dt] class ToTorchDtype(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, to_torch_dtype(values)) @@ -580,32 +633,32 @@ def __call__(self, parser, namespace, values, option_string=None): "--act-type", action=ToTorchDtype, required=True, - choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], + choices=["bfloat16", "float16", "int8", "float8_e4m3fn"], ) parser.add_argument( "--group-scale-type", action=ToTorchDtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--group-zero-type", type=to_torch_dtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--channel-scale-type", action=ToTorchDtype, - choices=['float'], + choices=["float"], ) parser.add_argument( "--token-scale-type", action=ToTorchDtype, - choices=['float'], + choices=["float"], ) parser.add_argument( "--out-type", action=ToTorchDtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--group-size", @@ -618,9 +671,11 @@ def __call__(self, parser, namespace, values, option_string=None): action="store_true", help="Run a sweep over all supported schedules", ) - parser.add_argument("--sweep-csv-out", - help="CSV to store sweep results", - default="sch_sweep_results.csv") + parser.add_argument( + "--sweep-csv-out", + help="CSV to store sweep results", + default="sch_sweep_results.csv", + ) subparsers = parser.add_subparsers(dest="cmd", required=True) square_parser = subparsers.add_parser("square_bench") @@ -634,17 +689,20 @@ def __call__(self, parser, namespace, values, option_string=None): "--dim-start", type=str, required=True, - help="Start value for M,K,N as common separated list") + help="Start value for M,K,N as common separated list", + ) range_parser.add_argument( "--dim-end", type=str, required=True, - help="End value (inclusive) for M,K,N as common separated list") + help="End value (inclusive) for M,K,N as common separated list", + ) range_parser.add_argument( "--dim-increment", type=str, required=True, - help="Increment value for M,K,N as common separated list") + help="Increment value for M,K,N as common separated list", + ) range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") @@ -655,14 +713,12 @@ def __call__(self, parser, namespace, values, option_string=None): default=DEFAULT_MODELS, choices=WEIGHT_SHAPES.keys(), ) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 1e785ac8fc7..9ea1fddae2a 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import torch.utils.benchmark as benchmark @@ -6,19 +7,34 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) + GPTQ_MARLIN_24_MAX_PARALLEL, + GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, + GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.utils.allspark_utils import ( - ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES) + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + MARLIN_SUPPORTED_GROUP_SIZES, + query_marlin_supported_quant_types, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace, marlin_quantize) + MarlinWorkspace, + marlin_quantize, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( - marlin_24_quantize) + marlin_24_quantize, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) + gptq_pack, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser @@ -29,22 +45,29 @@ K_FULL_OPTS = [False, True] -def bench_run(results: list[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, quant_type: ScalarType, - group_size: int, size_m: int, size_k: int, size_n: int): +def bench_run( + results: list[benchmark.Measurement], + model: str, + act_order: bool, + is_k_full: bool, + quant_type: ScalarType, + group_size: int, + size_m: int, + size_k: int, + size_n: int, +): label = "Quant Matmul" - sub_label = ("{}, act={} k_full={}, q={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, - str(quant_type), group_size, size_m, - size_k, size_n)) + sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( + model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n + ) print(f"Testing: {sub_label}") a = torch.randn(size_m, size_k).to(torch.half).cuda() b = torch.rand(size_k, size_n).to(torch.half).cuda() - a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) + a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda() # Marlin quant ( @@ -57,14 +80,16 @@ def bench_run(results: list[benchmark.Measurement], model: str, ) = marlin_quantize(b, quant_type, group_size, act_order) # Marlin_24 quant - (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( + marlin_24_quantize(b, quant_type, group_size) + ) marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) # GPTQ quant - (w_ref, q_w, s, g_idx, - rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) + (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( + b, quant_type, group_size, act_order + ) q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" @@ -74,32 +99,37 @@ def bench_run(results: list[benchmark.Measurement], model: str, (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) # Prepare - marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + marlin_workspace = MarlinWorkspace( + size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ) - marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_MAX_PARALLEL) + marlin_24_workspace = MarlinWorkspace( + size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL + ) marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) # AllSpark W8A16 quant - as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES - and group_size == -1 and not act_order and is_k_full) + as_supported_case = ( + quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 + and not act_order + and is_k_full + ) if as_supported_case: properties = torch.cuda.get_device_properties(b.device.index) sm_count = properties.multi_processor_count sm_version = properties.major * 10 + properties.minor - supported_arch = (sm_version >= 80 and sm_version < 90) + supported_arch = sm_version >= 80 and sm_version < 90 as_supported_case = as_supported_case and supported_arch if supported_arch: has_zp = False - w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, - has_zp) + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) qw = qw.to(torch.uint8) - qw_reorder, s_reorder, zp_reorder = \ - ops.allspark_repack_weight( - qw, s, zp, has_zp) + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( + qw, s, zp, has_zp + ) CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD globals = { @@ -136,8 +166,7 @@ def bench_run(results: list[benchmark.Measurement], model: str, "zp_reorder": zp_reorder if as_supported_case else None, "sm_count": sm_count if as_supported_case else None, "sm_version": sm_version if as_supported_case else None, - "CUBLAS_M_THRESHOLD": - CUBLAS_M_THRESHOLD if as_supported_case else None, + "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None, # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, @@ -158,60 +187,63 @@ def bench_run(results: list[benchmark.Measurement], model: str, label=label, sub_label=sub_label, description="pytorch_gemm", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp16", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp32", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) - if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES - and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): + if ( + quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES + ): results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 + stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_24_gemm", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 + stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_repack", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) if as_supported_case: results.append( benchmark.Timer( - stmt= - "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 + stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="allspark_w8a16_gemm_fp32", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) def main(args): @@ -233,37 +265,50 @@ def main(args): continue for act_order in ACT_ORDER_OPTS: - if len(args.limit_act_order - ) > 0 and act_order not in args.limit_act_order: + if ( + len(args.limit_act_order) > 0 + and act_order not in args.limit_act_order + ): continue for is_k_full in K_FULL_OPTS: - if len(args.limit_k_full - ) > 0 and is_k_full not in args.limit_k_full: + if ( + len(args.limit_k_full) > 0 + and is_k_full not in args.limit_k_full + ): continue - for quant_type in query_marlin_supported_quant_types( - False): - if len(args.limit_num_bits) > 0 and \ - quant_type.size_bits not in args.limit_num_bits: + for quant_type in query_marlin_supported_quant_types(False): + if ( + len(args.limit_num_bits) > 0 + and quant_type.size_bits not in args.limit_num_bits + ): continue for group_size in MARLIN_SUPPORTED_GROUP_SIZES: - if len( - args.limit_group_size - ) > 0 and group_size not in args.limit_group_size: + if ( + len(args.limit_group_size) > 0 + and group_size not in args.limit_group_size + ): continue # For act_order, the group_size must be less than # size_k - if act_order and (group_size == size_k - or group_size == -1): + if act_order and (group_size == size_k or group_size == -1): continue for size_m in args.batch_sizes: - bench_run(results, model, act_order, is_k_full, - quant_type, group_size, size_m, - size_k, size_n) + bench_run( + results, + model, + act_order, + is_k_full, + quant_type, + group_size, + size_m, + size_k, + size_n, + ) compare = benchmark.Compare(results) compare.print() @@ -274,7 +319,8 @@ def main(args): # if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") + description="Benchmark Marlin across specified models/shapes/batches" + ) parser.add_argument( "--models", nargs="+", @@ -282,10 +328,9 @@ def main(args): default=DEFAULT_MODELS, choices=WEIGHT_SHAPES.keys(), ) - parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 1884a80a407..6cb55b35993 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import json @@ -6,15 +7,16 @@ from contextlib import nullcontext from datetime import datetime from itertools import product +from types import SimpleNamespace from typing import Any, TypedDict import ray import torch from ray.experimental.tqdm_ray import tqdm -from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config from vllm.triton_utils import triton from vllm.utils import FlexibleArgumentParser @@ -30,56 +32,60 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_config(config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: List[int] = None, - use_deep_gemm: bool = False) -> float: +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: List[int] = None, + use_deep_gemm: bool = False, +) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: - w1 = torch.randint(-127, - 127, ( - num_experts, - shard_intermediate_size, - hidden_size, - ), - dtype=torch.int8) - w2 = torch.randint(-127, - 127, ( - num_experts, - hidden_size, - shard_intermediate_size // 2, - ), - dtype=torch.int8) + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + ) else: - w1 = torch.randn(num_experts, - shard_intermediate_size, - hidden_size, - dtype=init_dtype) - w2 = torch.randn(num_experts, - hidden_size, - shard_intermediate_size // 2, - dtype=init_dtype) - gating_output = torch.randn(num_iters, - num_tokens, - num_experts, - dtype=torch.float32) + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + ) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) w1_scale = None w2_scale = None a1_scale = None a2_scale = None if use_int8_w8a16: - w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size), - dtype=torch.float32) + w1_scale = torch.randn( + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 + ) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) if use_fp8_w8a8: if block_quant_shape: @@ -92,10 +98,14 @@ def benchmark_config(config: BenchmarkConfig, n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w1 = (K + block_k - 1) // block_k k_tiles_w2 = (N + block_k - 1) // block_k - w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1), - dtype=torch.float32) * factor_for_scale - w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2), - dtype=torch.float32) * factor_for_scale + w1_scale = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_scale = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) else: w1_scale = torch.randn(num_experts, dtype=torch.float32) w2_scale = torch.randn(num_experts, dtype=torch.float32) @@ -113,10 +123,12 @@ def prepare(i: int): def run(): from vllm.model_executor.layers.fused_moe import override_config + with override_config(config): if use_deep_gemm: topk_weights, topk_ids, token_expert_indices = fused_topk( - x, input_gating, topk, False) + x, input_gating, topk, False + ) return fused_experts( x, w1, @@ -212,8 +224,7 @@ def get_rocm_tuning_space(use_fp16): return param_ranges -def get_configs_compute_bound(use_fp16, - block_quant_shape) -> list[dict[str, int]]: +def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]: configs: list[BenchmarkConfig] = [] if current_platform.is_rocm(): @@ -249,20 +260,25 @@ def get_configs_compute_bound(use_fp16, if block_quant_shape is not None and not use_fp16: block_n, block_k = block_quant_shape[0], block_quant_shape[1] for config in configs[:]: - if config["BLOCK_SIZE_K"] % block_k != 0 or config[ - "BLOCK_SIZE_N"] % block_n != 0: + if ( + config["BLOCK_SIZE_K"] % block_k != 0 + or config["BLOCK_SIZE_N"] % block_n != 0 + ): configs.remove(config) return configs -def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, - search_space, is_fp16, topk): +def prune_rocm_search_space( + num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk +): N1, K1 = shard_intermediate_size, hidden_size N2, K2 = hidden_size, shard_intermediate_size // 2 - pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, - search_space, is_fp16) - pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, - search_space, is_fp16) + pruned_space_1 = prune_rocm_configs( + num_tokens * topk, N1, K1, search_space, is_fp16 + ) + pruned_space_2 = prune_rocm_configs( + num_tokens * topk, N2, K2, search_space, is_fp16 + ) search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) return search_space @@ -300,14 +316,14 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True): SPLIT_K = config.get("SPLIT_K", 1) GROUP_M = config.get("GROUP_SIZE_M") if is_fp16: - if (matrix_instr_nonkdim > BLOCK_SIZE_M - or matrix_instr_nonkdim > BLOCK_SIZE_N): + if ( + matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N + ): continue - if (matrix_instr_nonkdim >= M - and matrix_instr_nonkdim != BLOCK_SIZE_M): + if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: continue - if (matrix_instr_nonkdim >= N - and matrix_instr_nonkdim != BLOCK_SIZE_N): + if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: continue # Skip BLOCK_SIZE that is too large compare to M/N # unless BLOCK_SIZE is already small enough @@ -328,8 +344,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True): continue # out of shared memory resource # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + - BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + LDS = ( + BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + ) if LDS > 65536: continue # Skip small block sizes and num_warps for large gemm @@ -363,7 +381,6 @@ def merge_unique_dicts(list1, list2): @ray.remote(num_gpus=1) class BenchmarkWorker: - def __init__(self, seed: int) -> None: torch.set_default_device("cuda") current_platform.seed_everything(seed) @@ -387,36 +404,40 @@ def benchmark( use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) - dtype_str = get_config_dtype_str(dtype, - use_int8_w8a16=use_int8_w8a16, - use_fp8_w8a8=use_fp8_w8a8) + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. - op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, - dtype_str) + op_config = get_moe_configs( + num_experts, shard_intermediate_size // 2, dtype_str + ) if op_config is None: - config = get_default_config(num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype_str, - is_marlin=False) + config = get_default_config( + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype_str, + is_marlin=False, + ) else: - config = op_config[min(op_config.keys(), - key=lambda x: abs(x - num_tokens))] - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=100, - block_quant_shape=block_quant_shape, - use_deep_gemm=use_deep_gemm) + config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + ) return config, kernel_time def tune( @@ -437,10 +458,14 @@ def tune( best_time = float("inf") if current_platform.is_rocm(): is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) - search_space = prune_rocm_search_space(num_tokens, - shard_intermediate_size, - hidden_size, search_space, - is_fp16, topk) + search_space = prune_rocm_search_space( + num_tokens, + shard_intermediate_size, + hidden_size, + search_space, + is_fp16, + topk, + ) need_device_guard = False if current_platform.is_rocm(): @@ -448,8 +473,7 @@ def tune( if visible_device != f"{self.device_id}": need_device_guard = True - with torch.cuda.device( - self.device_id) if need_device_guard else nullcontext(): + with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): for config in tqdm(search_space): try: kernel_time = benchmark_config( @@ -464,7 +488,8 @@ def tune( use_int8_w8a16, num_iters=20, block_quant_shape=block_quant_shape, - use_deep_gemm=use_deep_gemm) + use_deep_gemm=use_deep_gemm, + ) except triton.runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. continue @@ -480,42 +505,44 @@ def tune( def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: return { - "BLOCK_SIZE_M": - config["BLOCK_SIZE_M"], - "BLOCK_SIZE_N": - config["BLOCK_SIZE_N"], - "BLOCK_SIZE_K": - config["BLOCK_SIZE_K"], - "GROUP_SIZE_M": - config["GROUP_SIZE_M"], - "num_warps": - config["num_warps"], - "num_stages": - config["num_stages"], - **({ - "waves_per_eu": config["waves_per_eu"] - } if "waves_per_eu" in config else {}), - **({ - "matrix_instr_nonkdim": config["matrix_instr_nonkdim"] - } if "matrix_instr_nonkdim" in config else {}), - **({ - "kpack": config["kpack"] - } if "kpack" in config else {}), + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + **( + {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + ), + **( + {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]} + if "matrix_instr_nonkdim" in config + else {} + ), + **({"kpack": config["kpack"]} if "kpack" in config else {}), } -def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, - shard_intermediate_size: int, hidden_size: int, topk: int, - dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, - block_quant_shape: List[int]) -> None: - dtype_str = get_config_dtype_str(dtype, - use_int8_w8a16=use_int8_w8a16, - use_fp8_w8a8=use_fp8_w8a8) +def save_configs( + configs: dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_quant_shape: List[int], +) -> None: + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. - filename = get_config_file_name(num_experts, shard_intermediate_size // 2, - dtype_str, block_quant_shape) + filename = get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape + ) print(f"Writing best config to {filename}...") with open(filename, "w") as f: @@ -524,18 +551,20 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, def get_weight_block_size_safety(config, default_value=None): - - quantization_config = getattr(config, 'quantization_config', {}) + quantization_config = getattr(config, "quantization_config", {}) if isinstance(quantization_config, dict): - return quantization_config.get('weight_block_size', default_value) + return quantization_config.get("weight_block_size", default_value) return default_value def main(args: argparse.Namespace): print(args) - config = AutoConfig.from_pretrained( - args.model, trust_remote_code=args.trust_remote_code) + config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) + if args.model_prefix: + config = getattr(config, args.model_prefix) + config = SimpleNamespace(**config) + if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k @@ -546,15 +575,12 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif (config.architectures[0] == "DeepseekV3ForCausalLM" - or config.architectures[0] == "DeepseekV2ForCausalLM"): + elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif config.architectures[0] in [ - "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" - ]: + elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size @@ -569,15 +595,35 @@ def main(args: argparse.Namespace): shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + dtype = ( + torch.float16 + if current_platform.is_rocm() + else getattr(torch, config.torch_dtype) + ) use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" block_quant_shape = get_weight_block_size_safety(config) if args.batch_size is None: batch_sizes = [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ] else: batch_sizes = [args.batch_size] @@ -588,7 +634,8 @@ def main(args: argparse.Namespace): # Ray will set ROCR_VISIBLE_DEVICES for device visibility logger.warning( "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility." - "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.") + "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES." + ) val = os.environ["HIP_VISIBLE_DEVICES"] os.environ["ROCR_VISIBLE_DEVICES"] = val del os.environ["HIP_VISIBLE_DEVICES"] @@ -615,25 +662,59 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: start = time.time() configs = _distribute( - "tune", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, - block_quant_shape, use_deep_gemm) - for batch_size in batch_sizes]) + "tune", + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + search_space, + block_quant_shape, + use_deep_gemm, + ) + for batch_size in batch_sizes + ], + ) best_configs = { - M: sort_config(config) - for M, config in zip(batch_sizes, configs) + M: sort_config(config) for M, config in zip(batch_sizes, configs) } - save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, - block_quant_shape) + save_configs( + best_configs, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + block_quant_shape, + ) end = time.time() print(f"Tuning took {end - start:.2f} seconds") else: outputs = _distribute( "benchmark", - [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, - use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm) - for batch_size in batch_sizes]) + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + block_quant_shape, + use_deep_gemm, + ) + for batch_size in batch_sizes + ], + ) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}, config: {config}") @@ -642,23 +723,21 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: if __name__ == "__main__": parser = FlexibleArgumentParser() - parser.add_argument("--model", - type=str, - default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--tp-size", - "-tp", - "--tensor-parallel-size", - type=int, - default=2) - parser.add_argument("--dtype", - type=str, - choices=["auto", "fp8_w8a8", "int8_w8a16"], - default="auto") + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument( + "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2 + ) + parser.add_argument( + "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" + ) parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--tune", action="store_true") parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--model-prefix", type=str, required=False) args = parser.parse_args() main(args) diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index 937df962465..dba1f3943b9 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse from typing import Any, TypedDict @@ -8,7 +9,9 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _moe_permute, _moe_unpermute_and_reduce) + _moe_permute, + _moe_unpermute_and_reduce, +) from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize @@ -27,15 +30,17 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_permute(num_tokens: int, - num_experts: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - use_customized_permute: bool = False) -> float: +def benchmark_permute( + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False, +) -> float: # init_dtype = torch.float16 if use_fp8_w8a8 else dtype hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) # output_hidden_states = torch.empty_like(hidden_states) @@ -46,36 +51,41 @@ def benchmark_permute(num_tokens: int, align_block_size = None qhidden_states = hidden_states - gating_output = torch.randn(num_iters, - num_tokens, - num_experts, - dtype=torch.float32) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) topk_weights, topk_ids, token_expert_indices = fused_topk( - qhidden_states, input_gating, topk, False) + qhidden_states, input_gating, topk, False + ) def prepare(i: int): input_gating.copy_(gating_output[i]) def run(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - m_indices) = moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( + moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + ) else: - (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qhidden_states, None, topk_ids, - num_experts, None, align_block_size) + ( + permuted_hidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = _moe_permute( + qhidden_states, None, topk_ids, num_experts, None, align_block_size + ) # JIT compilation & warmup run() @@ -111,15 +121,17 @@ def run(): return avg -def benchmark_unpermute(num_tokens: int, - num_experts: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - use_customized_permute: bool = False) -> float: +def benchmark_unpermute( + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False, +) -> float: # init_dtype = torch.float16 if use_fp8_w8a8 else dtype hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) output_hidden_states = torch.empty_like(hidden_states) @@ -133,46 +145,74 @@ def benchmark_unpermute(num_tokens: int, input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) topk_weights, topk_ids, token_expert_indices = fused_topk( - qhidden_states, input_gating, topk, False) + qhidden_states, input_gating, topk, False + ) def prepare(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - m_indices) = moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( + moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + ) # convert to fp16/bf16 as gemm output - return (permuted_hidden_states.to(dtype), first_token_off, - inv_perm_idx, m_indices) + return ( + permuted_hidden_states.to(dtype), + first_token_off, + inv_perm_idx, + m_indices, + ) else: - (permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qhidden_states, None, topk_ids, - num_experts, None, align_block_size) + ( + permuted_qhidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = _moe_permute( + qhidden_states, None, topk_ids, num_experts, None, align_block_size + ) # convert to fp16/bf16 as gemm output - return (permuted_qhidden_states.to(dtype), a1q_scale, - sorted_token_ids, expert_ids, inv_perm) + return ( + permuted_qhidden_states.to(dtype), + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) def run(input: tuple): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, - m_indices) = input - moe_unpermute(permuted_hidden_states, topk_weights, topk_ids, - inv_perm_idx, first_token_off, topk, num_experts, - num_experts) + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input + moe_unpermute( + permuted_hidden_states, + topk_weights, + topk_ids, + inv_perm_idx, + first_token_off, + topk, + num_experts, + num_experts, + ) else: - (permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = input - _moe_unpermute_and_reduce(output_hidden_states, - permuted_hidden_states, inv_perm, - topk_weights) + ( + permuted_hidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = input + _moe_unpermute_and_reduce( + output_hidden_states, permuted_hidden_states, inv_perm, topk_weights + ) # JIT compilation & warmup input = prepare() @@ -209,7 +249,6 @@ def run(input: tuple): @ray.remote(num_gpus=1) class BenchmarkWorker: - def __init__(self, seed: int) -> None: torch.set_default_device("cuda") current_platform.seed_everything(seed) @@ -241,7 +280,8 @@ def benchmark( use_fp8_w8a8, use_int8_w8a16, num_iters=100, - use_customized_permute=use_customized_permute) + use_customized_permute=use_customized_permute, + ) unpermute_time = benchmark_unpermute( num_tokens, num_experts, @@ -251,15 +291,15 @@ def benchmark( use_fp8_w8a8, use_int8_w8a16, num_iters=100, - use_customized_permute=use_customized_permute) + use_customized_permute=use_customized_permute, + ) return permute_time, unpermute_time def get_weight_block_size_safety(config, default_value=None): - - quantization_config = getattr(config, 'quantization_config', {}) + quantization_config = getattr(config, "quantization_config", {}) if isinstance(quantization_config, dict): - return quantization_config.get('weight_block_size', default_value) + return quantization_config.get("weight_block_size", default_value) return default_value @@ -267,20 +307,21 @@ def main(args: argparse.Namespace): print(args) config = AutoConfig.from_pretrained( - args.model, trust_remote_code=args.trust_remote_code) + args.model, trust_remote_code=args.trust_remote_code + ) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k elif config.architectures[0] == "JambaForCausalLM": E = config.num_experts topk = config.num_experts_per_tok - elif (config.architectures[0] == "DeepseekV3ForCausalLM" - or config.architectures[0] == "DeepseekV2ForCausalLM"): + elif ( + config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM" + ): E = config.n_routed_experts topk = config.num_experts_per_tok - elif config.architectures[0] in [ - "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" - ]: + elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]: E = config.num_experts topk = config.num_experts_per_tok @@ -299,8 +340,24 @@ def main(args: argparse.Namespace): if args.batch_size is None: batch_sizes = [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ] else: batch_sizes = [args.batch_size] @@ -321,9 +378,21 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: return ray.get(outputs) outputs = _distribute( - "benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8, - use_int8_w8a16, use_customized_permute) - for batch_size in batch_sizes]) + "benchmark", + [ + ( + batch_size, + E, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + use_customized_permute, + ) + for batch_size in batch_sizes + ], + ) for batch_size, (permute, unpermute) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}") @@ -333,13 +402,12 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: if __name__ == "__main__": parser = FlexibleArgumentParser() - parser.add_argument("--model", - type=str, - default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--dtype", - type=str, - choices=["auto", "fp8_w8a8", "int8_w8a16"], - default="auto") + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument( + "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" + ) parser.add_argument("--use-customized-permute", action="store_true") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 2625239b08e..7e0376c18ec 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random import time @@ -9,8 +10,11 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + FlexibleArgumentParser, + create_kv_caches_with_random, +) logger = init_logger(__name__) @@ -38,19 +42,15 @@ def main( current_platform.seed_everything(seed) scale = float(1.0 / (head_size**0.5)) - query = torch.empty(num_seqs, - num_query_heads, - head_size, - dtype=dtype, - device=device) + query = torch.empty( + num_seqs, num_query_heads, head_size, dtype=dtype, device=device + ) query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 alibi_slopes = None if use_alibi: - alibi_slopes = torch.randn(num_query_heads, - dtype=torch.float, - device=device) + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device) seq_lens = [seq_len for _ in range(num_seqs)] max_seq_len = max(seq_lens) @@ -61,24 +61,23 @@ def main( block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] block_tables_lst.append(block_table) - block_tables = torch.tensor(block_tables_lst, - dtype=torch.int, - device=device) + block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device) # Create the KV cache. - key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device) + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. @@ -86,11 +85,11 @@ def main( if version == "v2": if current_platform.is_rocm(): global PARTITION_SIZE - if not args.custom_paged_attn: + if not args.custom_paged_attn and not current_platform.is_navi(): PARTITION_SIZE = 1024 else: PARTITION_SIZE = PARTITION_SIZE_ROCM - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -110,9 +109,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = torch.tensor(1.0, - dtype=torch.float32, - device=device) + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) for _ in range(num_iters): if version == "v1": @@ -166,6 +163,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: scale, block_tables, seq_lens, + None, block_size, max_seq_len, alibi_slopes, @@ -195,30 +193,29 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': - logger.warning("This script benchmarks the paged attention kernel. " - "By default this is no longer used in vLLM inference.") +if __name__ == "__main__": + logger.warning( + "This script benchmarks the paged attention kernel. " + "By default this is no longer used in vLLM inference." + ) - parser = FlexibleArgumentParser( - description="Benchmark the paged attention kernel.") - parser.add_argument("--version", - type=str, - choices=["v1", "v2"], - default="v2") + parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") + parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) - parser.add_argument("--head-size", - type=int, - choices=[64, 80, 96, 112, 120, 128, 192, 256], - default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument( @@ -228,10 +225,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: default="auto", help="Data type for kv cache storage. If 'auto', will use model " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " - "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") - parser.add_argument("--custom-paged-attn", - action="store_true", - help="Use custom paged attention") + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)", + ) + parser.add_argument( + "--custom-paged-attn", action="store_true", help="Use custom paged attention" + ) args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index b643897a60e..6ab26f5f1ad 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time @@ -10,15 +11,17 @@ @torch.inference_mode() -def main(num_tokens: int, - hidden_size: int, - static_scale: bool, - quant_dtype: torch.dtype, - dtype: torch.dtype, - seed: int = 0, - do_profile: bool = False, - num_warmup_iters: int = 5, - num_iters: int = 100) -> None: +def main( + num_tokens: int, + hidden_size: int, + static_scale: bool, + quant_dtype: torch.dtype, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: current_platform.seed_everything(seed) torch.set_default_device("cuda") @@ -56,7 +59,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "int8": @@ -66,37 +69,40 @@ def to_torch_dtype(dt): raise ValueError(f"Unsupported dtype: {dt}") parser = FlexibleArgumentParser( - description="Benchmark the quantization (fp8 or int8) kernel.") + description="Benchmark the quantization (fp8 or int8) kernel." + ) parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--static-scale", action="store_true") - parser.add_argument("--quant-dtype", - type=str, - choices=["fp8", "int8"], - default="int8") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") + parser.add_argument( + "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8" + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument("--num-warmup-iters", type=int, default=5) - parser.add_argument("--num-iters", - type=int, - default=100, - help="Number of benchmark iterations. " - "If --profile is set, this number is ignored") + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored", + ) args = parser.parse_args() print(args) - main(num_tokens=args.num_tokens, - hidden_size=args.hidden_size, - static_scale=args.static_scale, - quant_dtype=to_torch_dtype(args.quant_dtype), - dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], - seed=args.seed, - do_profile=args.profile, - num_warmup_iters=args.num_warmup_iters, - num_iters=args.num_iters) + main( + num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + static_scale=args.static_scale, + quant_dtype=to_torch_dtype(args.quant_dtype), + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + ) diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py index 09a319ccf1d..4cf633a8135 100644 --- a/benchmarks/kernels/benchmark_rmsnorm.py +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from typing import Optional, Union @@ -12,7 +13,6 @@ class HuggingFaceRMSNorm(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -114,23 +114,19 @@ def rmsnorm_vllm( def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): dtype = torch.bfloat16 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="cuda") + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda") residual = torch.randn_like(x) if use_residual else None output_naive = rmsnorm_naive( - x.clone(), weight, - residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) output_flashinfer = rmsnorm_flashinfer( - x.clone(), weight, - residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) output_vllm = rmsnorm_vllm( - x.clone(), weight, - residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) if use_residual: output_naive = output_naive[0] @@ -141,9 +137,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): print(f"FlashInfer output={output_flashinfer}") print(f"vLLM output={output_vllm}") - if torch.allclose(output_naive, output_flashinfer, atol=1e-2, - rtol=1e-2) and torch.allclose( - output_naive, output_vllm, atol=1e-2, rtol=1e-2): + if torch.allclose( + output_naive, output_flashinfer, atol=1e-2, rtol=1e-2 + ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): print("✅ All implementations match") else: print("❌ Implementations differ") @@ -152,12 +148,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): batch_size_range = [2**i for i in range(0, 7, 2)] seq_length_range = [2**i for i in range(6, 11, 1)] head_num_range = [32, 48] -configs = list( - itertools.product(head_num_range, batch_size_range, seq_length_range)) +configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range)) def get_benchmark(use_residual): - @triton.testing.perf_report( triton.testing.Benchmark( x_names=["head_num", "batch_size", "seq_len"], @@ -167,19 +161,15 @@ def get_benchmark(use_residual): line_names=["HuggingFace", "FlashInfer", "vLLM"], styles=[("blue", "-"), ("green", "-"), ("red", "-")], ylabel="us", - plot_name= - f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", args={}, - )) + ) + ) def benchmark(head_num, batch_size, seq_len, provider): dtype = torch.bfloat16 hidden_size = head_num * 128 # assuming head_dim = 128 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="cuda") + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda") residual = torch.randn_like(x) if use_residual else None @@ -240,9 +230,9 @@ def benchmark(head_num, batch_size, seq_len, provider): default=4096, help="Hidden size (2nd dimension) of the sequence", ) - parser.add_argument("--use-residual", - action="store_true", - help="Whether to use residual connection") + parser.add_argument( + "--use-residual", action="store_true", help="Whether to use residual connection" + ) parser.add_argument( "--save-path", type=str, @@ -253,10 +243,12 @@ def benchmark(head_num, batch_size, seq_len, provider): args = parser.parse_args() # Run correctness test - calculate_diff(batch_size=args.batch_size, - seq_len=args.seq_len, - hidden_size=args.hidden_size, - use_residual=args.use_residual) + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual, + ) # Get the benchmark function with proper use_residual setting benchmark = get_benchmark(args.use_residual) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 05d24fc4b16..b81baf17a8c 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from itertools import accumulate from typing import Optional @@ -6,8 +7,7 @@ import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, - get_rope) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -23,7 +23,7 @@ def benchmark_rope_kernels_multi_lora( seed: int, device: str, max_position: int = 8192, - base: int = 10000, + base: float = 10000, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) @@ -32,40 +32,49 @@ def benchmark_rope_kernels_multi_lora( # silulating serving 4 LoRAs scaling_factors = [1, 2, 4, 8] # batched RoPE can take multiple scaling factors - batched_rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_style, { - "rope_type": "linear", - "factor": tuple(scaling_factors) - }) + batched_rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + {"rope_type": "linear", "factor": tuple(scaling_factors)}, + ) # non-batched RoPE takes only one scaling factor, we create multiple # instances to simulate the same behavior non_batched_ropes: list[RotaryEmbedding] = [] for scaling_factor in scaling_factors: non_batched_ropes.append( - get_rope(head_size, rotary_dim, max_position, base, is_neox_style, - { - "rope_type": "linear", - "factor": (scaling_factor, ) - })) + get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + {"rope_type": "linear", "factor": (scaling_factor,)}, + ) + ) positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) + query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) key = torch.randn_like(query) # create query offsets for batched RoPE, we concat multiple kv cache # together and each query needs to find the right kv cache of its type offset_map = torch.tensor( list( - accumulate([0] + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ]))) - query_types = torch.randint(0, - len(scaling_factors), (batch_size, seq_len), - device=device) + accumulate( + [0] + + [ + max_position * scaling_factor * 2 + for scaling_factor in scaling_factors[:-1] + ] + ) + ) + ) + query_types = torch.randint( + 0, len(scaling_factors), (batch_size, seq_len), device=device + ) # map query types to offsets query_offsets = offset_map[query_types] # the kernel takes flattened offsets @@ -86,27 +95,28 @@ def benchmark_rope_kernels_multi_lora( torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the rotary embedding kernels.") + description="Benchmark the rotary embedding kernels." + ) parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--num-heads", type=int, default=8) - parser.add_argument("--head-size", - type=int, - choices=[64, 80, 96, 112, 120, 128, 192, 256], - default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) - parser.add_argument("--dtype", - type=str, - choices=["bfloat16", "float"], - default="float") + parser.add_argument( + "--dtype", type=str, choices=["bfloat16", "float"], default="float" + ) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--device", - type=str, - choices=["cuda:0", "cuda:1"], - default="cuda:0") + parser.add_argument( + "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" + ) args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 70190ba24d9..18c459c31d3 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project WEIGHT_SHAPES = { "ideal": [[4 * 256 * 32, 256 * 32]], diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 8f07bc8ca52..4fcdbadd65e 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from sglang quantization/tuning_block_wise_kernel.py import argparse @@ -14,14 +15,16 @@ import triton from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - _w8a8_block_fp8_matmul) + _w8a8_block_fp8_matmul, +) from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser mp.set_start_method("spawn", force=True) -assert current_platform.is_cuda( -), "Only support tune w8a8 block fp8 kernel on CUDA device." +assert current_platform.is_cuda(), ( + "Only support tune w8a8 block fp8 kernel on CUDA device." +) DTYPE_MAP = { "float32": torch.float32, @@ -40,7 +43,7 @@ def w8a8_block_matmul( config: dict[str, Any], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - """This function performs matrix multiplication with + """This function performs matrix multiplication with block-wise quantization. It takes two input tensors `A` and `B` with scales `As` and `Bs`. @@ -51,7 +54,7 @@ def w8a8_block_matmul( B: The input tensor, e.g., weight. As: The per-token-group quantization scale for `A`. Bs: The per-block quantization scale for `B`. - block_size: The block size for per-block quantization. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. output_dytpe: The dtype of the returned tensor. @@ -71,18 +74,18 @@ def w8a8_block_matmul( assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N, ) + C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) def grid(META): - return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * - triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) if A.dtype == torch.float8_e4m3fn: kernel = _w8a8_block_fp8_matmul else: - raise RuntimeError( - "Currently, only support tune w8a8 block fp8 kernel.") + raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") kernel[grid]( A, @@ -119,14 +122,16 @@ def get_configs_compute_bound(): for block_n in [32, 64, 128, 256]: for num_warps in [4, 8]: for group_size in [1, 16, 32, 64]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) return configs @@ -165,15 +170,9 @@ def get_weight_shapes(tp_size): return weight_shapes -def benchmark_config(A, - B, - As, - Bs, - block_size, - config, - out_dtype=torch.float16, - num_iters=10): - +def benchmark_config( + A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10 +): def run(): w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) @@ -206,26 +205,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): fp8_max, fp8_min = fp8_info.max, fp8_info.min A_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * - fp8_max) + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) B_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * - fp8_max) + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) else: - raise RuntimeError( - "Currently, only support tune w8a8 block fp8 kernel.") + raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - As = torch.rand(M, k_tiles, dtype=torch.float32, - device="cuda") * factor_for_scale - Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") * - factor_for_scale) + As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") + * factor_for_scale + ) best_config = None best_time = float("inf") @@ -267,7 +266,8 @@ def save_configs( device_name = current_platform.get_device_name().replace(" ", "_") json_file_name = ( f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8," - f"block_shape=[{block_n},{block_k}].json") + f"block_shape=[{block_n},{block_k}].json" + ) config_file_path = os.path.join(save_path, json_file_name) print(f"Writing best config to {config_file_path}...") @@ -295,8 +295,7 @@ def tune_on_gpu(args_dict): search_space = get_configs_compute_bound() search_space = [ - config for config in search_space - if block_k % config["BLOCK_SIZE_K"] == 0 + config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0 ] start = time.time() @@ -312,15 +311,11 @@ def tune_on_gpu(args_dict): out_dtype, search_space, input_type, - ) for batch_size in tqdm(batch_sizes, - desc=f"GPU {gpu_id} - Batch sizes") + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") ] - best_configs = { - M: config - for M, config in zip(batch_sizes, benchmark_results) - } - save_configs(N, K, block_n, block_k, best_configs, save_path, - input_type) + best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)} + save_configs(N, K, block_n, block_k, best_configs, save_path, input_type) end = time.time() print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") @@ -376,13 +371,14 @@ def main(args): process_args = [] for gpu_id in range(num_gpus): - process_args.append({ - "gpu_id": gpu_id, - "batch_sizes": batches_per_gpu[gpu_id], - "weight_shapes": - weight_shapes, # Each GPU processes all weight shapes - "args": args, - }) + process_args.append( + { + "gpu_id": gpu_id, + "batch_sizes": batches_per_gpu[gpu_id], + "weight_shapes": weight_shapes, # Each GPU processes all weight shapes + "args": args, + } + ) ctx = mp.get_context("spawn") with ctx.Pool(num_gpus) as pool: @@ -398,13 +394,11 @@ def main(args): python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8 Then copy to model_executor/layers/quantization/utils/configs """, - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, + ) parser.add_argument("--tp-size", "-tp", type=int, default=8) - parser.add_argument("--input-type", - type=str, - choices=["fp8"], - default="fp8") + parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8") parser.add_argument( "--out-dtype", type=str, diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index 5fa55bb974e..e67ce054531 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # fmt: off # ruff: noqa: E501 import time @@ -11,7 +12,9 @@ # Import vLLM functions from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) from vllm.triton_utils import triton diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index bd62173a7b3..9a4da0ef5a8 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math import pickle -import re from collections import defaultdict import matplotlib.pyplot as plt import pandas as pd +import regex as re import seaborn as sns from torch.utils.benchmark import Measurement as TMeasurement @@ -14,13 +15,14 @@ if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the latency of processing a single batch of ' - 'requests till completion.') - parser.add_argument('filename', type=str) + description="Benchmark the latency of processing a single batch of " + "requests till completion." + ) + parser.add_argument("filename", type=str) args = parser.parse_args() - with open(args.filename, 'rb') as f: + with open(args.filename, "rb") as f: data = pickle.load(f) raw_results: list[TMeasurement] = data["results"] @@ -38,11 +40,7 @@ raise Exception("MKN not found") kernel = v.task_spec.description - results[KN].append({ - "kernel": kernel, - "batch_size": M, - "median": v.median - }) + results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median}) rows = int(math.ceil(len(results) / 2)) fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) @@ -50,14 +48,16 @@ for axs_idx, (shape, data) in enumerate(results.items()): plt.sca(axs[axs_idx]) df = pd.DataFrame(data) - sns.lineplot(data=df, - x="batch_size", - y="median", - hue="kernel", - style="kernel", - markers=True, - dashes=False, - palette="Dark2") + sns.lineplot( + data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2", + ) plt.title(f"Shape: {shape}") plt.ylabel("time (median, s)") plt.tight_layout() diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index ac64f786f18..4bbb36bb435 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses from collections.abc import Iterable @@ -23,6 +24,7 @@ class ArgPool: For every invocation during a benchmarking run, it will choose a different value from the list. """ + values: Iterable[Any] def __getitem__(self, index): @@ -30,9 +32,7 @@ def __getitem__(self, index): class Bench: - class ArgsIterator: - def __init__(self, args_list, kwargs_list): assert len(args_list) == len(kwargs_list) self.args_list = args_list @@ -53,10 +53,16 @@ def reset(self): def n_args(self): return self.n - def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], - label: str, sub_label: str, description: str, fn: Callable, - *args, **kwargs): - + def __init__( + self, + cuda_graph_params: Optional[CudaGraphBenchParams], + label: str, + sub_label: str, + description: str, + fn: Callable, + *args, + **kwargs, + ): self.cuda_graph_params = cuda_graph_params self.use_cuda_graph = self.cuda_graph_params is not None self.label = label @@ -67,10 +73,8 @@ def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], # Process args self._args = args self._kwargs = kwargs - self.args_list, self.kwargs_list = self.collapse_argpool( - *args, **kwargs) - self.args_iterator = self.ArgsIterator(self.args_list, - self.kwargs_list) + self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs) + self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list) # Cudagraph runner self.g = None @@ -100,16 +104,13 @@ def collapse_argpool(self, *args, **kwargs): for i in range(argpool_size): # collapse args; Just pick the ith value - args_list[i] = tuple([ - arg[i] if isinstance(arg, ArgPool) else arg - for arg in args_list[i] - ]) + args_list[i] = tuple( + [arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]] + ) # collapse kwargs kwargs_i = kwargs_list[i] - arg_pool_keys = [ - k for k, v in kwargs_i.items() if isinstance(v, ArgPool) - ] + arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)] for k in arg_pool_keys: # again just pick the ith value kwargs_i[k] = kwargs_i[k][i] @@ -142,7 +143,7 @@ def get_cuda_graph_runner(self): def run_cudagrah(self) -> TMeasurement: assert self.use_cuda_graph - globals = {'g': self.g} + globals = {"g": self.g} return TBenchmark.Timer( stmt="g.replay()", @@ -162,15 +163,15 @@ def run_eager(self) -> TMeasurement: has_arg_pool = self.args_iterator.n_args > 1 if has_arg_pool: - setup = ''' + setup = """ args_iterator.reset() args_it = args_iterator.__next__() - ''' - stmt = ''' + """ + stmt = """ args, kwargs = next(args_it) fn(*args, **kwargs) - ''' - globals = {'fn': self.fn, 'args_iterator': self.args_iterator} + """ + globals = {"fn": self.fn, "args_iterator": self.args_iterator} else: # no arg pool. Just use the args and kwargs directly self.args_iterator.reset() @@ -178,10 +179,10 @@ def run_eager(self) -> TMeasurement: args, kwargs = next(args_it) setup = "" - stmt = ''' + stmt = """ fn(*args, **kwargs) - ''' - globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} + """ + globals = {"fn": self.fn, "args": args, "kwargs": kwargs} return TBenchmark.Timer( stmt=stmt, diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index 89b05d5882a..a27f02394af 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Weight Shapes are in the format # ([K, N], TP_SPLIT_DIM) @@ -48,4 +49,50 @@ ([16384, 106496], 1), ([53248, 16384], 0), ], + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], } diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py index 5f94552e9dc..0957a9c65f0 100644 --- a/benchmarks/overheads/benchmark_hashing.py +++ b/benchmarks/overheads/benchmark_hashing.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import cProfile import pstats @@ -7,9 +8,8 @@ from vllm.utils import FlexibleArgumentParser # A very long prompt, total number of tokens is about 15k. -LONG_PROMPT = ["You are an expert in large language models, aren't you?" - ] * 1000 -LONG_PROMPT = ' '.join(LONG_PROMPT) +LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000 +LONG_PROMPT = " ".join(LONG_PROMPT) def main(args): @@ -30,32 +30,35 @@ def main(args): print("------start generating------") for i in range(3): - profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', - globals(), locals()) + profiler.runctx( + "llm.generate(LONG_PROMPT, sampling_params)", globals(), locals() + ) # analyze the runtime of hashing function stats = pstats.Stats(profiler) - stats.sort_stats('cumulative') + stats.sort_stats("cumulative") total_time = 0 total_calls = 0 for func in stats.stats: - if 'hash_of_block' in func[2]: + if "hash_of_block" in func[2]: total_time = stats.stats[func][3] total_calls = stats.stats[func][0] percentage = (total_time / stats.total_tt) * 100 - print(f"Hashing took {total_time:.2f} seconds," - f"{percentage:.2f}% of the total runtime.") + print( + f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime." + ) if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the performance of hashing function in' - 'automatic prefix caching.') - parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) - parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='enable prefix caching') + description="Benchmark the performance of hashing function in" + "automatic prefix caching." + ) + parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k") + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--output-len", type=int, default=10) + parser.add_argument( + "--enable-prefix-caching", action="store_true", help="enable prefix caching" + ) args = parser.parse_args() main(args) diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml new file mode 100644 index 00000000000..65b1e09a247 --- /dev/null +++ b/benchmarks/pyproject.toml @@ -0,0 +1,49 @@ +# This local pyproject file is part of the migration from yapf to ruff format. +# It uses the same core rules as the main pyproject.toml file, but with the +# following differences: +# - ruff line length is overridden to 88 +# - deprecated typing ignores (UP006, UP035) have been removed + +[tool.ruff] +line-length = 88 + +[tool.ruff.lint.per-file-ignores] +"vllm/third_party/**" = ["ALL"] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + +[tool.ruff.lint.isort] +known-first-party = ["vllm"] + +[tool.ruff.format] +docstring-code-format = true \ No newline at end of file diff --git a/benchmarks/run_structured_output_benchmark.sh b/benchmarks/run_structured_output_benchmark.sh index 53dc7ed70b9..b043ab83e46 100755 --- a/benchmarks/run_structured_output_benchmark.sh +++ b/benchmarks/run_structured_output_benchmark.sh @@ -1,32 +1,98 @@ #!/bin/bash -# Define the model to use -MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"} - -# Define the backend to use -BACKEND=${2:-"vllm"} - -# Define the dataset to use -DATASET=${3:-"xgrammar_bench"} - +# default values +MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"} +BACKEND=${BACKEND:-"vllm"} +DATASET=${DATASET:-"xgrammar_bench"} SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -OUTPUT_DIR=${4:-"$SCRIPT_DIR/structured_output_benchmark_results"} +OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"} +PORT=${PORT:-8000} +STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1} +TOTAL_SECONDS=${TOTAL_SECONDS:-90} +MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300} +TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"} -GUIDED_RATIO=${5:-0.5} +usage() { + echo "Usage: $0 [options]" + echo "Options:" + echo " --model MODEL Model to benchmark (default: $MODEL)" + echo " --backend BACKEND Backend to use (default: $BACKEND)" + echo " --dataset DATASET Dataset to use (default: $DATASET)" + echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)" + echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)" + echo " --port PORT Port to use (default: $PORT)" + echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)" + echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)" + echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)" + echo " -h, --help Show this help message and exit" + exit 0 +} + +# parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --backend) + BACKEND="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --max-new-tokens) + MAX_NEW_TOKENS="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --structured-output-ratio) + STRUCTURED_OUTPUT_RATIO="$2" + shift 2 + ;; + --tokenizer-mode) + TOKENIZER_MODE="$2" + shift 2 + ;; + --total-seconds) + TOTAL_SECONDS="$2" + shift 2 + ;; + -h|--help) + usage + ;; + *) + echo "Unknown argument: $1\n" + usage + ;; + esac +done # Create output directory if it doesn't exist mkdir -p "$OUTPUT_DIR" # Define QPS values to test -QPS_VALUES=(70 60 50 25 20 15 10) +QPS_VALUES=(25 20 15 10 5 1) # Common parameters COMMON_PARAMS="--backend $BACKEND \ --model $MODEL \ --dataset $DATASET \ - --structured-output-ratio $GUIDED_RATIO \ + --structured-output-ratio $STRUCTURED_OUTPUT_RATIO \ --save-results \ - --result-dir $OUTPUT_DIR" + --result-dir $OUTPUT_DIR \ + --output-len $MAX_NEW_TOKENS \ + --port $PORT \ + --tokenizer-mode $TOKENIZER_MODE" echo "Starting structured output benchmark with model: $MODEL" echo "Backend: $BACKEND" @@ -45,12 +111,15 @@ for qps in "${QPS_VALUES[@]}"; do # Construct filename for this run FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" + NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc) + NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part + echo "Running benchmark with $NUM_PROMPTS prompts" + # Run the benchmark python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \ --request-rate $qps \ --result-filename "$FILENAME" \ - --tokenizer-mode ${TOKENIZER_MODE:-"auto"} \ - --port ${PORT:-8000} + --num-prompts $NUM_PROMPTS echo "Completed benchmark with QPS: $qps" echo "----------------------------------------" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 00670bd398b..fb763db9fc3 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) FetchContent_MakeAvailable(oneDNN) + list(APPEND LIBS dnnl) +elseif(POWER10_FOUND) + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.7.2 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + set(DNNL_CPU_RUNTIME "OMP") + + FetchContent_MakeAvailable(oneDNN) + list(APPEND LIBS dnnl) endif() @@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) +elseif(POWER10_FOUND) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) endif() # @@ -214,4 +245,4 @@ define_gpu_extension_target( WITH_SOABI ) -message(STATUS "Enabling C extension.") \ No newline at end of file +message(STATUS "Enabling C extension.") diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index b04e4c2d06e..a4edd5b96fe 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -46,22 +46,38 @@ else() endif() +# Ensure the vllm/vllm_flash_attn directory exists before installation +install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS) + +# Make sure vllm-flash-attn install rules are nested under vllm/ +# This is here to support installing all components under the same prefix with cmake --install. +# setup.py installs every component separately but uses the same prefix for all. +# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3, +# and these statements don't hurt when installing neither component. +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS) +install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) +install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS) + # Fetch the vllm-flash-attn library FetchContent_MakeAvailable(vllm-flash-attn) message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") +# Restore the install prefix +install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) + # Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in # case only one is built, in the case both are built redundant work is done) install( DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm_flash_attn + DESTINATION vllm/vllm_flash_attn COMPONENT _vllm_fa2_C FILES_MATCHING PATTERN "*.py" ) install( DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm_flash_attn + DESTINATION vllm/vllm_flash_attn COMPONENT _vllm_fa3_C FILES_MATCHING PATTERN "*.py" ) diff --git a/cmake/hipify.py b/cmake/hipify.py index a15577125eb..55d378f5b11 100755 --- a/cmake/hipify.py +++ b/cmake/hipify.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # # A command line tool for running pytorch's hipify preprocessor on CUDA diff --git a/cmake/utils.cmake b/cmake/utils.cmake index c9cd099b82a..6d90555f296 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -76,7 +76,7 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS) set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc) add_custom_target( hipify${NAME} - COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS} + COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS} DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS} BYPRODUCTS ${HIP_SRCS} COMMENT "Running hipify on ${NAME} extension source files.") @@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs) "${multiValueArgs}" ${ARGN} ) foreach(_ARCH ${arg_CUDA_ARCHS}) - string(REPLACE "." "" _ARCH "${_ARCH}") - set_gencode_flag_for_srcs( - SRCS ${arg_SRCS} - ARCH "compute_${_ARCH}" - CODE "sm_${_ARCH}") + # handle +PTX suffix: generate both sm and ptx codes if requested + string(FIND "${_ARCH}" "+PTX" _HAS_PTX) + if(NOT _HAS_PTX EQUAL -1) + string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}") + string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "compute_${_STRIPPED_ARCH}") + else() + string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + endif() endforeach() if (${arg_BUILD_PTX_FOR_ARCH}) @@ -251,7 +266,10 @@ endmacro() # # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form # `.[letter]` compute the "loose intersection" with the -# `TGT_CUDA_ARCHS` list of gencodes. +# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in +# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there +# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the +# architecture in `SRC_CUDA_ARCHS`. # The loose intersection is defined as: # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } # where `<=` is the version comparison operator. @@ -268,44 +286,63 @@ endmacro() # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) # OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" # +# Example With PTX: +# SRC_CUDA_ARCHS="8.0+PTX" +# TGT_CUDA_ARCHS="9.0" +# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) +# OUT_CUDA_ARCHS="8.0+PTX" +# function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) - list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) - set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS}) + set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}") + set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS}) + + # handle +PTX suffix: separate base arch for matching, record PTX requests + set(_PTX_ARCHS) + foreach(_arch ${_SRC_CUDA_ARCHS}) + if(_arch MATCHES "\\+PTX$") + string(REPLACE "+PTX" "" _base "${_arch}") + list(APPEND _PTX_ARCHS "${_base}") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") + list(APPEND _SRC_CUDA_ARCHS "${_base}") + endif() + endforeach() + list(REMOVE_DUPLICATES _PTX_ARCHS) + list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS set(_CUDA_ARCHS) - if ("9.0a" IN_LIST SRC_CUDA_ARCHS) - list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") - if ("9.0" IN_LIST TGT_CUDA_ARCHS_) - list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0") + if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") + if ("9.0" IN_LIST TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") set(_CUDA_ARCHS "9.0a") endif() endif() - if ("10.0a" IN_LIST SRC_CUDA_ARCHS) - list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a") + if ("10.0a" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a") if ("10.0" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0") + list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0") set(_CUDA_ARCHS "10.0a") endif() endif() - list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that # is less or equal to ARCH (but has the same major version since SASS binary # compatibility is only forward compatible within the same major version). - foreach(_ARCH ${TGT_CUDA_ARCHS_}) + foreach(_ARCH ${_TGT_CUDA_ARCHS}) set(_TMP_ARCH) # Extract the major version of the target arch string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") - foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) + foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS}) # Extract the major version of the source arch string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") - # Check major-version match AND version-less-or-equal + # Check version-less-or-equal, and allow PTX arches to match across majors if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) - if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) + if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) set(_TMP_ARCH "${_SRC_ARCH}") endif() else() @@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endforeach() list(REMOVE_DUPLICATES _CUDA_ARCHS) + + # reapply +PTX suffix to architectures that requested PTX + set(_FINAL_ARCHS) + foreach(_arch ${_CUDA_ARCHS}) + if(_arch IN_LIST _PTX_ARCHS) + list(APPEND _FINAL_ARCHS "${_arch}+PTX") + else() + list(APPEND _FINAL_ARCHS "${_arch}") + endif() + endforeach() + set(_CUDA_ARCHS ${_FINAL_ARCHS}) + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) endfunction() diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83..55e65967970 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ + if (num_tokens == 0) { \ + return; \ + } \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \ diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index eb216dc8baf..79a546554fa 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -172,7 +172,7 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in + // For example, if the thread group size is 4, then the first thread in // the group has 0, 4, 8, ... th vectors of the query, and the second thread // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because // q is split from a qkv tensor, it may not be contiguous. @@ -259,7 +259,7 @@ __device__ void paged_attention_kernel( // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in + // For example, if the thread group size is 4, then the first thread in // the group has 0, 4, 8, ... th vectors of the key, and the second thread // has 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 14e5edd7e28..6bee9e4ce11 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output, const uint pack_size = 16 / sizeof(scalar_t); TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); + TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1, + "output heads must be contiguous in memory"); + TORCH_CHECK( + prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1, + "prefix_output heads must be contiguous in memory"); + TORCH_CHECK( + suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1, + "suffix_output heads must be contiguous in memory"); float* output_lse_ptr = nullptr; if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index 6743af0cf2d..f4b6b19f4b2 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options( {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, hw_info, - -1, // split_kv + 1, // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute diff --git a/csrc/attention/vertical_slash_index.cu b/csrc/attention/vertical_slash_index.cu new file mode 100644 index 00000000000..c1b45b143f4 --- /dev/null +++ b/csrc/attention/vertical_slash_index.cu @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include + +#include + +#include + +__device__ int64_t save_blocks(int* block_offset, int64_t range_start, + int64_t range_end, int64_t block_size, + int64_t input_block_count, int64_t kv_seqlen) { + if (range_start >= kv_seqlen) { + return input_block_count; + } + if (range_end > kv_seqlen) { + range_end = kv_seqlen; + } + int64_t current_block_count = input_block_count; + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[current_block_count++] = idx; + } + return current_block_count; +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count, + block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, + BLOCK_SIZE_N, NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * This function builds the index of each row of blocks from vertical indices + * and slash indices. The vertical indices are treated as points, while the + * slash indices are converted as ranges. The output consists of the merged + * ranges and separate column indices, where the ranges are represented by + * block indices. + * + * The implementation is referenced from the original MInference repo: + * https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu. + */ +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + block_count.data_ptr(), block_offset.data_ptr(), + column_count.data_ptr(), column_index.data_ptr(), batch_size, + num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash, + causal); +} + +__global__ void convert_vertical_slash_indexes_kernel_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + const int* per_head_vertical_topkv, const int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + // MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S + // above is buffer size, use to compute offset) + NNZ_S = per_head_slash_topkv[head_idx]; + NNZ_V = per_head_vertical_topkv[head_idx]; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* per_head_vertical_topkv, int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel_mergehead<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, + per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset, + column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, + NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * Like the above convert_vertical_slash_indexes, but with + * pre-computed vertical and slash counts. + */ +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, // [N_HEADS, ] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64_mergehead( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + vertical_indices_count.data_ptr(), + slash_indices_count.data_ptr(), block_count.data_ptr(), + block_offset.data_ptr(), column_count.data_ptr(), + column_index.data_ptr(), batch_size, num_heads, num_rows, + block_size_M, block_size_N, nnz_vertical, nnz_slash, causal); +} diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index c2ae554c9f8..d0f85e23609 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8); static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); +static inline constexpr auto kFE2M1f = + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = @@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp index a8e1be37eb4..089b9840ea2 100644 --- a/csrc/cpu/cpu_types_vsx.hpp +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -4,6 +4,7 @@ #include #include +#include #include namespace vec_op { @@ -62,6 +63,10 @@ typedef struct f32x4x4_t { __vector float val[4]; } f32x4x4_t; +typedef struct i32x4x4_t { + __vector int32_t val[4]; +} i32x4x4_t; + struct FP32Vec8; struct FP32Vec16; @@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec { vec_xst(reg.val[0], 0, (signed short*)ptr); vec_xst(reg.val[1], 16, (signed short*)ptr); } + + void save(void* ptr, const int elem_num) const { + const int clamped_elem = std::max(0, std::min(elem_num, 16)); + + // Calculate elements to store in each 128-bit part (8 elements each) + const int elements_val0 = std::min(clamped_elem, 8); + const int elements_val1 = std::max(clamped_elem - 8, 0); + + // Convert elements to bytes (2 bytes per element) + const size_t bytes_val0 = elements_val0 * sizeof(signed short); + const size_t bytes_val1 = elements_val1 * sizeof(signed short); + + signed short* dest = static_cast(ptr); + // Store the first part using vec_xst_len + if (bytes_val0 > 0) { + vec_xst_len(reg.val[0], dest, bytes_val0); + } + // Store the second part if needed + if (bytes_val1 > 0) { + vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1); + } + } }; const static __vector signed short zero = vec_splats((signed short)0); @@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec { } }; +struct INT32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + i32x4x4_t reg; + int32_t values[VEC_ELEM_NUM]; + }; + + i32x4x4_t reg; + + explicit INT32Vec16(const void* data_ptr) { + reg.val[0] = vec_xl(0, reinterpret_cast(data_ptr)); + reg.val[1] = + vec_xl(16, reinterpret_cast(data_ptr)); + reg.val[2] = + vec_xl(32, reinterpret_cast(data_ptr)); + reg.val[3] = + vec_xl(48, reinterpret_cast(data_ptr)); + } + + void save(int32_t* ptr) const { + vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr)); + } + + void save(int32_t* ptr, const int elem_num) const { + const int elements_in_chunk1 = + (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = + (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = + (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = + (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = + static_cast(elements_in_chunk1 * sizeof(int32_t)); + const size_t bytes_chunk2 = + static_cast(elements_in_chunk2 * sizeof(int32_t)); + const size_t bytes_chunk3 = + static_cast(elements_in_chunk3 * sizeof(int32_t)); + const size_t bytes_chunk4 = + static_cast(elements_in_chunk4 * sizeof(int32_t)); + + vec_xst_len(reg.val[0], reinterpret_cast(ptr), bytes_chunk1); + vec_xst_len(reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), + bytes_chunk2); + vec_xst_len(reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), + bytes_chunk3); + vec_xst_len(reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), + bytes_chunk4); + } +}; + struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { @@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const INT32Vec16& v) { + reg.val[0] = vec_ctf(v.reg.val[0], 0); + reg.val[1] = vec_ctf(v.reg.val[1], 0); + reg.val[2] = vec_ctf(v.reg.val[2], 0); + reg.val[3] = vec_ctf(v.reg.val[3], 0); + } + FP32Vec16 operator*(const FP32Vec16& b) const { return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1]), @@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec { vec_div(reg.val[3], b.reg.val[3])})); } + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(f32x4x4_t( + {vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])), + vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])), + vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])), + vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))})); + } + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]), + vec_max(reg.val[1], b.reg.val[1]), + vec_max(reg.val[2], b.reg.val[2]), + vec_max(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 max(const FP32Vec16& b, int elem_num) const { + FP32Vec16 result; + + // Create a vector of element indices for each chunk + __vector unsigned int indices = {0, 1, 2, 3}; + __vector unsigned int elem_num_vec = + vec_splats(static_cast(elem_num)); + + // Compute masks for each chunk + __vector unsigned int chunk_offset0 = {0, 0, 0, + 0}; // Chunk 0: Elements 0-3 + __vector unsigned int chunk_offset1 = {4, 4, 4, + 4}; // Chunk 1: Elements 4-7 + __vector unsigned int chunk_offset2 = {8, 8, 8, + 8}; // Chunk 2: Elements 8-11 + __vector unsigned int chunk_offset3 = {12, 12, 12, + 12}; // Chunk 3: Elements 12-15 + + // Compute masks for each chunk + __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); + __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); + __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); + __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); + + // Apply masks to compute the result for each chunk + result.reg.val[0] = vec_sel(this->reg.val[0], + vec_max(this->reg.val[0], b.reg.val[0]), mask0); + result.reg.val[1] = vec_sel(this->reg.val[1], + vec_max(this->reg.val[1], b.reg.val[1]), mask1); + result.reg.val[2] = vec_sel(this->reg.val[2], + vec_max(this->reg.val[2], b.reg.val[2]), mask2); + result.reg.val[3] = vec_sel(this->reg.val[3], + vec_max(this->reg.val[3], b.reg.val[3]), mask3); + + return FP32Vec16(result.reg); + } + + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]), + vec_min(reg.val[1], b.reg.val[1]), + vec_min(reg.val[2], b.reg.val[2]), + vec_min(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 min(const FP32Vec16& b, int elem_num) const { + FP32Vec16 result; + + vector unsigned int indices = {0, 1, 2, 3}; + vector unsigned int elem_num_vec = + vec_splats(static_cast(elem_num)); + + vector unsigned int chunk_offset0 = {0, 0, 0, 0}; + vector unsigned int chunk_offset1 = {4, 4, 4, 4}; + vector unsigned int chunk_offset2 = {8, 8, 8, 8}; + vector unsigned int chunk_offset3 = {12, 12, 12, 12}; + + vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); + vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); + vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); + vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); + + result.reg.val[0] = vec_sel(this->reg.val[0], + vec_min(this->reg.val[0], b.reg.val[0]), mask0); + result.reg.val[1] = vec_sel(this->reg.val[1], + vec_min(this->reg.val[1], b.reg.val[1]), mask1); + result.reg.val[2] = vec_sel(this->reg.val[2], + vec_min(this->reg.val[2], b.reg.val[2]), mask2); + result.reg.val[3] = vec_sel(this->reg.val[3], + vec_min(this->reg.val[3], b.reg.val[3]), mask3); + + return FP32Vec16(result.reg); + } + + FP32Vec16 abs() const { + return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]), + vec_abs(reg.val[2]), vec_abs(reg.val[3])})); + } + + float reduce_max() { + __vector float max01 = vec_max(reg.val[0], reg.val[1]); + __vector float max23 = vec_max(reg.val[2], reg.val[3]); + __vector float max_all = vec_max(max01, max23); + __vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8)); + temp = vec_max(temp, vec_sld(temp, temp, 4)); + return vec_extract(temp, 0); + } + + float reduce_min() { + __vector float min01 = vec_min(reg.val[0], reg.val[1]); + __vector float min23 = vec_min(reg.val[2], reg.val[3]); + __vector float min_all = vec_min(min01, min23); + __vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8)); + temp = vec_min(temp, vec_sld(temp, temp, 4)); + return vec_extract(temp, 0); + } + float reduce_sum() const { AliasReg ar; ar.reg = reg; @@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec { vec_xst(reg.val[2], 32, ptr); vec_xst(reg.val[3], 48, ptr); } + + void save(float* ptr, const int elem_num) const { + const int elements_in_chunk1 = + (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = + (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = + (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = + (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = + static_cast(elements_in_chunk1 * sizeof(float)); + const size_t bytes_chunk2 = + static_cast(elements_in_chunk2 * sizeof(float)); + const size_t bytes_chunk3 = + static_cast(elements_in_chunk3 * sizeof(float)); + const size_t bytes_chunk4 = + static_cast(elements_in_chunk4 * sizeof(float)); + + vec_xst_len(reg.val[0], ptr, bytes_chunk1); + vec_xst_len(reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), + bytes_chunk2); + vec_xst_len(reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), + bytes_chunk3); + vec_xst_len(reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), + bytes_chunk4); + } +}; + +struct INT8Vec16 : public Vec { + constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16 + + union AliasReg { + __vector signed char reg; + int8_t values[VEC_NUM_ELEM]; + }; + + __vector signed char reg; + + explicit INT8Vec16(const FP32Vec16& vec) { + __vector signed int ret[4]; + ret[0] = vec_cts(vec.reg.val[0], 0); + ret[1] = vec_cts(vec.reg.val[1], 0); + ret[2] = vec_cts(vec.reg.val[2], 0); + ret[3] = vec_cts(vec.reg.val[3], 0); + + __vector signed short packed1 = vec_packs(ret[0], ret[1]); + __vector signed short packed2 = vec_packs(ret[2], ret[3]); + + reg = vec_packs(packed1, packed2); + } + + void save(void* ptr) const { + *reinterpret_cast<__vector signed char*>(ptr) = reg; + } + void save(signed char* ptr, const int elem_num) { + vec_xst_len(reg, ptr, static_cast(elem_num)); + } }; template diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index cf67847b45b..9a613ba588d 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -19,6 +19,7 @@ namespace vec_op { #define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 8a59e884d6c..74bb014cf39 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -9,7 +9,8 @@ void rotary_embedding_impl( scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, /// head_size] or [num_tokens, num_heads, /// head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr (optional) or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // @@ -85,10 +86,13 @@ void rotary_embedding_impl( compute_loop(token_head, cache_ptr, query); } - for (int i = 0; i < num_kv_heads; ++i) { - const int head_idx = i; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - compute_loop(token_head, cache_ptr, key); + if (key != nullptr) { + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = + token_idx * key_stride + head_idx * head_size; + compute_loop(token_head, cache_ptr, key); + } } } } @@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl( scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, /// head_size] or [num_tokens, num_heads, /// head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr (optional) or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // @@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl( } } + if (key == nullptr) { + return; + } + #pragma omp parallel for collapse(2) for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { @@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl( }; // namespace void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, + std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = positions.numel(); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; - int num_kv_heads = key.size(-1) / head_size; - int64_t key_stride = key.stride(-2); + int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads; + int64_t key_stride = key.has_value() ? key->stride(-2) : 0; int64_t query_stride = query.stride(-2); VLLM_DISPATCH_FLOATING_TYPES( @@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, if (is_neox) { rotary_embedding_impl( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size, num_tokens); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size, num_tokens); } else { rotary_embedding_gptj_impl( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size, num_tokens); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size, num_tokens); } CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 6751e7e55fc..f61dbcc948e 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output, } } +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + if constexpr (!PerChannel) { + zp_scale_val *= *b_scale; + } + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} +#elif defined(__powerpc64__) +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } +} +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = static_cast(azp_val); + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } +} +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t a_scale_vec(a_scale); + cvt_vec_t b_scale_vec(*b_scale); + cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; + + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} template void dynamic_quant_epilogue(const float* input, scalar_t* output, const float* a_scale, const float* b_scale, @@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, const float* scale, const int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") + TORCH_CHECK( + false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.") } template @@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, float* scale, int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") + TORCH_CHECK( + false, + "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.") } template @@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output, const float a_scale, const float* b_scale, const int32_t* azp_with_adj, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") + TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.") } template @@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output, const int32_t* azp, const int32_t* azp_with_adj, const scalar_t* bias, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") + TORCH_CHECK(false, + "dynamic_quant_epilogue requires AVX512/powerpc64 support.") } #endif } // namespace @@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant( } }); } + +#if defined(__powerpc64__) +void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const std::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm_ppc64le only supports INT8 inputs."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + // We dont need this + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && + bias->dim() == 1); + } + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { + torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter + bias + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Compute C=s_a * C_inter + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, + c.size(0), c.size(1)); + } + }); +} + +#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 7ae7e3386b4..447e826bc1c 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, const std::optional& azp, const std::optional& bias); +#if defined(__powerpc64__) +void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const std::optional& bias); +#endif + void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens); @@ -117,7 +125,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( "rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," + " Tensor!? key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); @@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b_scales, Tensor azp_adj," " Tensor? azp, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); +#elif defined(__powerpc64__) + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); + ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCPU, + &dynamic_scaled_int8_quant); + // W8A8 GEMM, supporting symmetric quantization. + ops.def( + "cutlass_scaled_mm(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_azp(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor azp_adj," + " Tensor? azp, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index dbe0e30f5cb..195872e8edd 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -15,15 +15,6 @@ cutlassGetStatusString(error)); \ } -/** - * Panic wrapper for unwinding CUDA runtime errors - */ -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \ - } - inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { int max_shared_mem_per_block_opt_in = 0; cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in, @@ -59,3 +50,13 @@ struct enable_sm90_only : Kernel { #endif } }; + +template +struct enable_sm100_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index d64f0d0a5c2..1dd7101acc2 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum from typing import Union diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index dc6e0769b87..f7b75c48373 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -65,5 +65,19 @@ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__) + #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index fb6882f3e7c..d073dd6d2de 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(weight.is_contiguous()); + int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 98daf1a1b8e..f62d08c17c6 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -13,6 +13,10 @@ #include #include +#ifdef USE_ROCM + namespace cub = hipcub; +#endif + #include "static_switch.h" @@ -501,15 +505,9 @@ void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { auto kernel = &causal_conv1d_fwd_kernel; if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - #else - // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. C10_CUDA_CHECK(cudaFuncSetAttribute( (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif } kernel<<>>(params); diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index bd0a34119c8..0c9df925bdb 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -321,7 +321,7 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { auto kernel = &selective_scan_fwd_kernel; if (kSmemSize >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); } kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h deleted file mode 100644 index a217401b3d7..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h +++ /dev/null @@ -1,1616 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include - -#include - -#include "core/scalar_type.hpp" - -namespace marlin_moe { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales -using FragZP = Vec; - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { - half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - int zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); - - if constexpr (has_zp) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp) { - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) { - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = frag_qzp[k % 2][1]; - } - - frag_zp_0 = dequant(zp_quant_0); - frag_zp_1 = dequant(zp_quant_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - // Apply zero-point to frag_b0 - if constexpr (has_zp) { - sub_zp(frag_b0, frag_zp[j], 0); - } - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - - start_pipes(); - } - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} - -#else - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - cfg_max_m_blocks); \ - } - -#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - -#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu deleted file mode 100644 index 77bc0dd90ed..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku4.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = true; - - if (false) { - } - AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) - AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) - AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) - AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h deleted file mode 100644 index 833fadf3772..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu deleted file mode 100644 index f7e57b03759..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku4b8.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = false; - - if (false) { - } - GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h deleted file mode 100644 index 494da8f10e2..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu deleted file mode 100644 index a901f0b11cd..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku8b128.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = false; - - if (false) { - } - GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h deleted file mode 100644 index f3018aa0c1a..00000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu deleted file mode 100644 index 5f12483e951..00000000000 --- a/csrc/moe/marlin_moe_ops.cu +++ /dev/null @@ -1,588 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "core/exception.hpp" -#include "core/scalar_type.hpp" -#include "core/registration.h" -#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" -#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" -#include "marlin_kernels/marlin_moe_kernel_ku4.h" - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_moe { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// For a given "a" of size [M,K] performs a permutation of the K columns based -// on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; - int finish_row = start_row + block_rows; - if (finish_row > size_m) { - finish_row = size_m; - } - int cur_block_rows = finish_row - start_row; - - int row_stride = size_k * sizeof(half) / 16; - - auto permute_row = [&](int row) { - int iters = size_k / blockDim.x; - int rest = size_k % blockDim.x; - - int offset = row * row_stride; - - half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); - half* out_half = reinterpret_cast(out_int4_ptr + offset); - - int base_k = 0; - - for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - - base_k += blockDim.x; - } - - if (rest) { - if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - } - } - }; - - for (int i = 0; i < cur_block_rows; i++) { - int cur_row = start_row + i; - if (cur_row < size_m) { - permute_row(cur_row); - } - } -} - -__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, - int* __restrict__ expert_offsets, - int topk_length, int block_size) { - int expert_id = threadIdx.x; - int num_experts = blockDim.x; - - int occurrences = 0; - for (int i = 0; i < topk_length; ++i) { - occurrences += (topk_ids[i] == expert_id); - } - expert_offsets[expert_id + 1] = occurrences; - __syncthreads(); - - if (threadIdx.x == 0) { - int tot_offset = 0; - expert_offsets[0] = 0; - for (int i = 0; i < num_experts; ++i) { - tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; - expert_offsets[i + 1] = tot_offset; - } - } - __syncthreads(); -} - -#else - -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, - int* __restrict__ expert_offsets, - int topk_length, int block_size) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N - {64, 64, 128}, // Reduce both 2X -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X - {64, 64, 128}, // Reduce N 4X, same K -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { - bool cache_scales_chunk = has_act_order && !is_k_full; - - int tb_n = th_config.thread_n; - int tb_k = th_config.thread_k; - - // Get max scale groups per thread-block - int tb_groups; - if (group_size == -1) { - tb_groups = 1; - } else if (group_size == 0) { - tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size - } else { - tb_groups = ceildiv(tb_k, group_size); - } - - if (cache_scales_chunk) { - int load_groups = - tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 4; - - } else { - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * STAGES; - } -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = ceildiv(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * STAGES; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - bool has_act_order, bool is_k_full, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - -#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ - else if (KERNEL_FUNCTION( \ - q_type, thread_n_blocks, thread_k_blocks, has_act_order, \ - group_blocks, num_threads, blocks, max_shared_mem, stream, \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks)) { \ - } - -void marlin_mm_moe(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, void* zp, - const void* g_idx, const void* perm, void* a_tmp, - void* expert_offsets, int prob_m, int prob_n, int prob_k, - void* workspace, vllm::ScalarType const& q_type, - bool has_act_order, bool is_k_full, bool has_zp, - int num_groups, int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - int num_bits = q_type.size_bits(); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; - } else { - // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); - } - - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = 0; - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } else { - TORCH_CHECK(group_size == 0); - group_blocks = 0; - } - - } else { - if (group_size == -1) { - group_blocks = -1; - } else { - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - } - - int tot_m = prob_m; - - const int* topk_ids_ptr = (const int*)topk_ids; - int* expert_offsets_ptr = (int*)expert_offsets; - compute_expert_offsets<<<1, num_experts, 0, stream>>>( - topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); - - bool do_permute_a = has_act_order; - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by - // having a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; - } - - int pack_factor = 32 / q_type.size_bits(); - - for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - const int4* A_ptr = (const int4*)A; - int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = - (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; - int4* C_ptr = (int4*)C; - const float* topk_weights_ptr = (const float*)topk_weights; - const int* sorted_ids_ptr = (const int*)sorted_ids; - const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; - const int4* zp_ptr = - (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx; - const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; - const int* perm_ptr = (const int*)perm + prob_k * expert_idx; - int* locks = (int*)workspace; - - if (do_permute_a) { - // Permute A columns - int topk_rows = replicate_input ? tot_m : tot_m * topk; - int block_rows = ceildiv(topk_rows, blocks); - permute_cols_kernel<<>>( - A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); - A_ptr = a_tmp_ptr; - } - - int tot_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < tot_m_blocks; - m_block += 4 * exec_cfg.max_m_blocks) { - if (false) { - } - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - } - } -} - -} // namespace marlin_moe - -torch::Tensor marlin_gemm_moe( - const torch::Tensor& a, const torch::Tensor& b_q_weights, - const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, - const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - torch::Tensor& b_zeros, const torch::Tensor& g_idx, - const torch::Tensor& perm, torch::Tensor& workspace, - vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, - int64_t moe_block_size, bool replicate_input, bool apply_weights) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - bool has_zp = b_zeros.size(1) != 0; - if (has_zp) { - TORCH_CHECK( - b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); - } else { - TORCH_CHECK( - b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str()); - } - - int pack_factor = 32 / b_q_type.size_bits(); - - int max_par = 4; - - int dev = a.get_device(); - - auto options_dtype = - torch::TensorOptions().dtype(a.dtype()).device(a.device()); - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); - torch::Tensor a_tmp = - replicate_input ? torch::zeros({size_m, size_k}, options_dtype) - : torch::zeros({size_m, topk, size_k}, options_dtype); - torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - bool has_act_order = g_idx.size(1) != 0; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); - TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), - " is not size_n = ", size_n); - num_groups = b_scales.size(1); - - TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), - "if is_k_full is false, has_act_order must be true"); - - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; - } else { - group_size = 0; - } - - } else { - if (num_groups > 1) { - TORCH_CHECK( - size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); - group_size = size_k / num_groups; - } else { - group_size = -1; - } - } - - // Verify b_zeros - if (has_zp) { - int rank = b_zeros.sizes().size(); - TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); - TORCH_CHECK(b_zeros.size(1) == num_groups, - "b_zeros dim 1 = ", b_zeros.size(1), - " is not num_groups = ", num_groups); - TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, - "b_zeros dim 2 = ", b_zeros.size(2), - " is not size_n / pack_factor = ", size_n / pack_factor); - } - - marlin_moe::marlin_mm_moe( - a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), - topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, - num_experts, topk, moe_block_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, - replicate_input, apply_weights); - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_gemm_moe", &marlin_gemm_moe); -} diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 902bcd9dfd2..49f33718a21 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import glob import itertools import os @@ -31,7 +32,10 @@ # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. -SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"] +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] @@ -39,7 +43,7 @@ # = 0 : act order case # = -1 : channelwise quantization # > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 2, 4, 8] +GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] DTYPES = ["fp16", "bf16"] @@ -72,6 +76,12 @@ def generate_new_kernels(): # for fp8 if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index c40c33d01f3..537282aba8c 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,17 +7,18 @@ #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ - const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index c9e199bcea1..1c255396099 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -301,9 +301,11 @@ __global__ void Marlin( int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids const int32_t* __restrict__ expert_ids_ptr, // moe expert ids const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens @@ -341,6 +343,16 @@ __global__ void Marlin( extern __shared__ int4 sh[]; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || + w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = + !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); + + scalar_t2 global_scale; + constexpr bool has_act_order = group_blocks == 0; constexpr int pack_factor = 32 / w_type.size_bits(); @@ -348,7 +360,8 @@ __global__ void Marlin( constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; - const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int scales_expert_stride = + prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); @@ -460,9 +473,16 @@ __global__ void Marlin( if (mul_topk_weights) { #pragma unroll for (int i = 0; i < 4; i++) { - sh_block_topk_weights[tid4 * 4 + i] = - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + int idx = tid4 * 4 + i; + idx = idx < block_num_valid_tokens ? idx : 0; + if constexpr (w_type == vllm::kFE2M1f) { + sh_block_topk_weights[idx] = __hmul2( + global_scale, Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); + } else { + sh_block_topk_weights[idx] = Dtype::num2num2( + Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + } } } } @@ -493,6 +513,11 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[expert_id]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; if constexpr (has_zp) { @@ -606,7 +631,7 @@ __global__ void Marlin( constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -664,7 +689,8 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } @@ -688,10 +714,20 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; else @@ -801,7 +837,7 @@ __global__ void Marlin( sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < act_s_max_num_groups) { + if (sh_num_groups > act_s_max_num_groups) { sh_num_groups = act_s_max_num_groups; } @@ -1021,12 +1057,19 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + if constexpr (w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } } } @@ -1199,22 +1242,7 @@ __global__ void Marlin( }; auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { - if constexpr (has_zp && is_zp_float || !has_zp) { - dequant(q, frag_b_ptr); - } else { - static_assert(has_zp && !is_zp_float); - static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id()); - // If (has_zp && !is_zp_float), - // we use not-zp version `dequant` function - // to improve numerical accuracy. - // Since both weight and zero point are dequanted using this logic, - // the final dequanted weight would be correct. - if constexpr (w_type_id == vllm::kU4.id()) { - dequant(q, frag_b_ptr); - } else if constexpr (w_type_id == vllm::kU8.id()) { - dequant(q, frag_b_ptr); - } - } + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. @@ -1244,13 +1272,23 @@ __global__ void Marlin( dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); } } - if constexpr (has_zp && is_zp_float) { + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { if (is_new_zp) { reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; } } + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -1259,7 +1297,10 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { @@ -1272,6 +1313,11 @@ __global__ void Marlin( dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { static_assert(group_blocks != -1); @@ -1279,7 +1325,8 @@ __global__ void Marlin( act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1287,7 +1334,7 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && group_blocks != -1) { + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); @@ -1554,10 +1601,17 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && !has_zp) { + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } + if constexpr (w_type == vllm::kFE2M1f) { + if (!mul_topk_weights) { + res = __hmul2(res, global_scale); + } + } + if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; @@ -1648,7 +1702,9 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - fetch_col_scale_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } } } fetch_to_shared(i, i, i < slice_iters, i); @@ -1711,17 +1767,20 @@ __global__ void Marlin( if constexpr (has_act_order) { slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_act_order_scales_to_shared(false, first_group_id, - last_group_id); - __syncthreads(); + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } } } if (slice_iters == 0) { @@ -1737,7 +1796,8 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1747,7 +1807,8 @@ __global__ void Marlin( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1771,7 +1832,8 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && !has_zp) { + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 00b4e934cc3..2cff04f699b 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -291,6 +291,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ @@ -338,6 +339,21 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) + #define BIGGROUP_GET_IF(W_TYPE) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ @@ -394,6 +410,8 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, BIGGROUP_GET_IF(vllm::kFE4M3fn) + FP4_GET_IF(vllm::kFE2M1f) + ACT_GET_IF(vllm::kU4B8) ACT_GET_IF(vllm::kU8B128) @@ -465,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* zp, void* g_idx, void* perm, void* a_tmp, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, @@ -479,14 +497,16 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, bool m_block_size_8 = moe_block_size == 8; if (has_zp) { - TORCH_CHECK(q_type == vllm::kU4, - "q_type must be u4 when has_zp = True. Got = ", q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { - TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || - q_type == vllm::kFE4M3fn, - "q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " - "False. Got = ", - q_type.str()); + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, @@ -519,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -627,7 +648,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); @@ -639,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -790,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm( } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + torch::Tensor b_zeros; if (b_zeros_or_none.has_value()) { b_zeros = b_zeros_or_none.value(); @@ -802,13 +835,14 @@ torch::Tensor moe_wna16_marlin_gemm( if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || - b_q_type == vllm::kFE4M3fn, - "b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = " - "False. Got = ", + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. Got = ", b_q_type.str()); } @@ -854,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), @@ -866,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm( at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index d7be769458e..6b6a9d04a60 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, } if (use_global_memory) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors @@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, cumsum_buffer.data_ptr()); }); } else if (use_i16) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // set dynamic shared mem auto kernel = @@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, topk_ids.numel()); }); } else { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { auto kernel = vllm::moe::moe_align_block_size_kernel; @@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, TORCH_CHECK(num_experts == 256, "sgl_moe_align_block_size kernel only supports deepseek v3."); - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors auto options_int = diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 0bae119a7c4..c4faef73106 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -28,4 +28,10 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); -#endif \ No newline at end of file +#endif + +bool moe_permute_unpermute_supported(); + +void shuffle_rows(const torch::Tensor& input_tensor, + const torch::Tensor& dst2src_map, + torch::Tensor& output_tensor); \ No newline at end of file diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 76d5f0eab02..68f429fac18 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -5,6 +5,9 @@ #include "permute_unpermute_kernels/dispatch.h" #include "core/registration.h" +// moe_permute kernels require at least CUDA 12.0 +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + void moe_permute( const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& topk_weights, //[n_token, topk] @@ -127,7 +130,101 @@ void moe_unpermute( }); } +template +__global__ void shuffleInputRowsKernel(const T* input, + const int32_t* dst2src_map, T* output, + int64_t num_src_rows, + int64_t num_dst_rows, int64_t num_cols) { + int64_t dest_row_idx = blockIdx.x; + int64_t const source_row_idx = dst2src_map[dest_row_idx]; + + if (blockIdx.x < num_dst_rows) { + // Load 128-bits per thread + constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8; + using DataElem = cutlass::Array; + + // Duplicate and permute rows + auto const* source_row_ptr = + reinterpret_cast(input + source_row_idx * num_cols); + auto* dest_row_ptr = + reinterpret_cast(output + dest_row_idx * num_cols); + + int64_t const start_offset = threadIdx.x; + int64_t const stride = blockDim.x; + int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD; + + for (int elem_index = start_offset; elem_index < num_elems_in_col; + elem_index += stride) { + dest_row_ptr[elem_index] = source_row_ptr[elem_index]; + } + } +} + +void shuffle_rows(const torch::Tensor& input_tensor, + const torch::Tensor& dst2src_map, + torch::Tensor& output_tensor) { + TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(), + "Input and output tensors must have the same data type"); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + int64_t const blocks = output_tensor.size(0); + int64_t const threads = 256; + int64_t const num_dest_rows = output_tensor.size(0); + int64_t const num_src_rows = input_tensor.size(0); + int64_t const num_cols = input_tensor.size(1); + + TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)), + "num_cols must be divisible by 128 / " + "sizeof(input_tensor.scalar_type()) / 8"); + + MOE_DISPATCH(input_tensor.scalar_type(), [&] { + shuffleInputRowsKernel<<>>( + reinterpret_cast(input_tensor.data_ptr()), + dst2src_map.data_ptr(), + reinterpret_cast(output_tensor.data_ptr()), num_src_rows, + num_dest_rows, num_cols); + }); +} + +#else + +void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, + torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indicies, + const std::optional& expert_map, + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, + torch::Tensor& src_row_id2dst_row_id_map, + torch::Tensor& m_indices) { + TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); +} + +void moe_unpermute(const torch::Tensor& input, + const torch::Tensor& topk_weights, torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indicies, + const std::optional& expert_map, + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, + torch::Tensor& src_row_id2dst_row_id_map, + torch::Tensor& m_indices) { + TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); +} + +#endif + +bool moe_permute_unpermute_supported() { +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + return true; +#else + return false; +#endif +} + TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} \ No newline at end of file +} diff --git a/csrc/moe/permute_unpermute_kernels/dispatch.h b/csrc/moe/permute_unpermute_kernels/dispatch.h index 41932cdd85b..d0f1ea4aded 100644 --- a/csrc/moe/permute_unpermute_kernels/dispatch.h +++ b/csrc/moe/permute_unpermute_kernels/dispatch.h @@ -14,12 +14,13 @@ __VA_ARGS__(); \ break; \ } -#define MOE_DISPATCH_FLOAT_CASE(...) \ - MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ - MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \ - MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) +#define MOE_DISPATCH_FLOAT_CASE(...) \ + MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) #define MOE_DISPATCH(TYPE, ...) \ MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__)) @@ -39,6 +40,11 @@ template <> struct ScalarType2CudaType { using type = __nv_bfloat16; }; +// uint8 for packed fp4 +template <> +struct ScalarType2CudaType { + using type = uint8_t; +}; // #if __CUDA_ARCH__ >= 890 // fp8 diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index aa353d0f043..de2c153882d 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -1,6 +1,9 @@ #include "moe_permute_unpermute_kernel.h" +// moe_permute kernels require at least CUDA 12.0 +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + // CubKeyValueSorter definition begin CubKeyValueSorter::CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} @@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size, int num_experts) { auto tidx = threadIdx.x; auto bidx = blockIdx.x; - auto lidx = tidx & 31; - auto widx = tidx >> 5; - auto warp_count = (blockDim.x + 31) >> 5; auto offset = bidx * blockDim.x; auto bound = min(offset + blockDim.x, size); extern __shared__ int smem_expert_map[]; @@ -226,4 +226,6 @@ void getMIndices(int64_t* expert_first_token_offset, expert_first_token_offset, align_expert_first_token_offset, m_indices, num_local_expert, align_block_size); } -} \ No newline at end of file +} + +#endif diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index de9747b6025..10be47966f6 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__ } } -template -__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, - int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +template +__launch_bounds__(TPB) __global__ void moeTopK( + const float* inputs_after_softmax, + const bool* finished, + float* output, + IndType* indices, + int* source_rows, + const int num_experts, + const int k, + const int start_expert, + const int end_expert) { using cub_kvp = cub::KeyValuePair; @@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax 2) This implementation assumes k is small, but will work for any k. */ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, int* source_rows, const int k, const int start_expert, const int end_expert) { // We begin by enforcing compile time assertions and setting up compile time constants. @@ -397,8 +405,8 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; @@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f token_expert_indices, num_tokens, topk, 0, num_experts, \ stream); +template void topkGatingSoftmaxKernelLauncher( const float* gating_output, float* topk_weights, - int* topk_indicies, + IndType* topk_indicies, int* token_expert_indices, float* softmax_workspace, const int num_tokens, @@ -493,14 +502,44 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + + if(topk_indices.scalar_type() == at::ScalarType::Int) + { + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } + else if (topk_indices.scalar_type() == at::ScalarType::UInt32) + { + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } + else { + assert(topk_indices.scalar_type() == at::ScalarType::Int64); + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 2a8b9bb39ca..a74eb3720cf 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -10,7 +10,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Calculate the result of moe by summing up the partial results // from all selected experts. - m.def("moe_sum(Tensor! input, Tensor output) -> ()"); + m.def("moe_sum(Tensor input, Tensor! output) -> ()"); m.impl("moe_sum", torch::kCUDA, &moe_sum); // Aligning the number of tokens to be processed by each expert such @@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," - "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? " + "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! num_tokens_past_padded," @@ -76,7 +77,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " "expert_first_token_offset, int n_expert, int n_local_expert,int " "topk, Tensor! hidden_states)->()"); - // conditionally compiled so impl registration is in source file + + m.def("moe_permute_unpermute_supported() -> bool"); + m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); + + // Row shuffle for MoE + m.def( + "shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! " + "output_tensor) -> ()"); + m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 59ae0937604..6905ef6e591 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse); + +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal); + +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, int64_t context_size, + int64_t block_size_M, int64_t block_size_N, bool causal); #endif void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, @@ -67,6 +92,11 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void apply_repetition_penalties_(torch::Tensor& logits, + const torch::Tensor& prompt_mask, + const torch::Tensor& output_mask, + const torch::Tensor& repetition_penalties); + void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, double epsilon); @@ -86,13 +116,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out, std::optional residual); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, + std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, - torch::Tensor& cos_sin_cache, bool is_neox, - int64_t rot_dim, + std::optional key, + int64_t head_size, torch::Tensor& cos_sin_cache, + bool is_neox, int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets); void silu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -178,6 +208,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, torch::Tensor num_tokens_post_padded, int64_t type, int64_t row, int64_t top_k, int64_t tokens); +torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W, + torch::Tensor topk_ids, int64_t top_k, + int64_t type, int64_t row, int64_t tokens); + int64_t ggml_moe_get_block_size(int64_t type); #ifndef USE_ROCM @@ -204,11 +238,18 @@ void cutlass_moe_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets); + void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k); + const int64_t num_experts, const int64_t n, const int64_t k, + const std::optional& blockscale_offsets); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, @@ -231,6 +272,12 @@ std::vector cutlass_sparse_compress(torch::Tensor const& a); void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale); + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index c085d31a3e9..266f2a0667a 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* cache_ptr, const int head_size, const int num_heads, const int num_kv_heads, const int rot_dim, const int token_idx, - const int64_t query_stride, const int64_t key_stride) { + const int64_t query_stride, const int64_t key_stride, + const int64_t head_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding( const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int64_t token_head = + token_idx * query_stride + head_idx * head_stride; const int rot_offset = i % embed_dim; apply_token_rotary_embedding( query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } - const int nk = num_kv_heads * embed_dim; - for (int i = threadIdx.x; i < nk; i += blockDim.x) { - const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - const int rot_offset = i % embed_dim; - apply_token_rotary_embedding( - key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + if (key != nullptr) { + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = + token_idx * key_stride + head_idx * head_stride; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } } } @@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t head_stride, const int num_heads, const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; @@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel( apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + token_idx, query_stride, key_stride, head_stride); } template @@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // // 2] const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - // or [num_tokens] const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t head_stride, const int num_heads, const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; @@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel( apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + token_idx, query_stride, key_stride, head_stride); } } // namespace vllm @@ -127,10 +136,12 @@ void rotary_embedding( // [num_tokens, num_heads * head_size] or // [batch_size, seq_len, num_heads, head_size] or // [num_tokens, num_heads, head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] + std::optional key, + // null or + // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { @@ -138,40 +149,46 @@ void rotary_embedding( int64_t num_tokens = positions.numel(); int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key. + // Make sure num_tokens dim is consistent across positions, query, and key TORCH_CHECK( positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); if (positions_ndim == 1) { - TORCH_CHECK( - query.size(0) == positions.size(0) && key.size(0) == positions.size(0), - "query, key and positions must have the same number of tokens"); + TORCH_CHECK(query.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)), + "query, key and positions must have the same number of tokens"); } if (positions_ndim == 2) { TORCH_CHECK( query.size(0) == positions.size(0) && - key.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)) && query.size(1) == positions.size(1) && - key.size(1) == positions.size(1), + (!key.has_value() || key->size(1) == positions.size(1)), "query, key and positions must have the same batch_size and seq_len"); } // Make sure head_size is valid for query and key // hidden_size = num_heads * head_size int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.numel() / num_tokens; + int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0); // Make sure query and key have consistent number of heads int num_heads = query_hidden_size / head_size; - int num_kv_heads = key_hidden_size / head_size; + int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; TORCH_CHECK(num_heads % num_kv_heads == 0); int rot_dim = cos_sin_cache.size(1); int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.stride(seq_dim_idx); + int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = + (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -181,15 +198,16 @@ void rotary_embedding( if (is_neox) { vllm::rotary_embedding_kernel<<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, - query_stride, key_stride, num_heads, num_kv_heads, head_size); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, key_stride, + head_stride, num_heads, num_kv_heads, head_size); } else { vllm::rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, head_stride, num_heads, num_kv_heads, head_size); } }); } @@ -204,10 +222,12 @@ void batched_rotary_embedding( // [num_tokens, num_heads * head_size] or // [batch_size, seq_len, num_heads, head_size] or // [num_tokens, num_heads, head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] + std::optional + key, // null or + // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox, int64_t rot_dim, @@ -221,38 +241,44 @@ void batched_rotary_embedding( "cos_sin_cache_offsets"); int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key. + // Make sure num_tokens dim is consistent across positions, query, and key TORCH_CHECK( positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); if (positions_ndim == 1) { - TORCH_CHECK( - query.size(0) == positions.size(0) && key.size(0) == positions.size(0), - "query, key and positions must have the same number of tokens"); + TORCH_CHECK(query.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)), + "query, key and positions must have the same number of tokens"); } if (positions_ndim == 2) { TORCH_CHECK( query.size(0) == positions.size(0) && - key.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)) && query.size(1) == positions.size(1) && - key.size(1) == positions.size(1), + (!key.has_value() || key->size(1) == positions.size(1)), "query, key and positions must have the same batch_size and seq_len"); } // Make sure head_size is valid for query and key int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.numel() / num_tokens; + int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0); // Make sure query and key have concistent number of heads int num_heads = query_hidden_size / head_size; - int num_kv_heads = key_hidden_size / head_size; + int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; TORCH_CHECK(num_heads % num_kv_heads == 0); int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.stride(seq_dim_idx); + int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = + (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -263,16 +289,18 @@ void batched_rotary_embedding( vllm::batched_rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } else { vllm::batched_rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } }); } diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index acc3d672202..67e9149c137 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel( void silu_and_mul_quant(torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., 2 * d] torch::Tensor& scale) { - TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn || + out.dtype() == torch::kFloat8_e4m3fnuz); TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); TORCH_CHECK(input.size(-1) % 2 == 0); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index e7978582718..bf46cce60a2 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) { float dst = std::nearbyint(x); // saturate - dst = std::clamp(dst, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // dst = std::clamp(dst, i8_min, i8_max); + dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst; return static_cast(dst); #else // CUDA path @@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { static_cast(std::numeric_limits::max()); // saturate - int32_t dst = std::clamp(x, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // int32_t dst = std::clamp(x, i8_min, i8_max); + int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x; return static_cast(dst); #else // CUDA path diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu new file mode 100644 index 00000000000..4a8a5ed02d6 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu @@ -0,0 +1,23 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh new file mode 100644 index 00000000000..c841125dbb7 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -0,0 +1,279 @@ +#pragma once + +#include "cuda_utils.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +// clang-format off +template +struct cutlass_3x_gemm_fp8_blockwise { + static constexpr bool swap_ab = swap_ab_; + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementD = OutType; + using LayoutD = cutlass::layout::RowMajor; + using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose::type; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; // TODO: support bias + using LayoutC = LayoutD; + using LayoutC_Transpose = LayoutD_Transpose; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBlockScale = float; + + using ScaleConfig = conditional_t, + cutlass::detail::Sm100BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::UMMA::Major::MN, cute::UMMA::Major::K>>; + + // layout_SFA and layout_SFB cannot be swapped since they are deduced. + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + conditional_t, + AlignmentC, + ElementD, + conditional_t, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using CollectiveMainloop = conditional_t, + AlignmentB, + ElementA, + cute::tuple, + AlignmentA, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp, + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp>; + + using KernelType = enable_sm100_only, CollectiveMainloop, CollectiveEpilogue>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + static constexpr bool swap_ab = Gemm::swap_ab; + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1)); + + LayoutSFA layout_SFA = swap_ab ? + ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) : + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = swap_ab ? + ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) : + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + auto mainloop_args = [&](){ + // layout_SFA and layout_SFB cannot be swapped since they are deduced. + if (swap_ab) { + return typename GemmKernel::MainloopArguments{ + b_ptr, b_stride, a_ptr, a_stride, + b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB + }; + } + else { + return typename GemmKernel::MainloopArguments{ + a_ptr, a_stride, b_ptr, b_stride, + a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB + }; + } + }(); + auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + constexpr int TILE_K = 128; + // TODO: better heuristics + bool swap_ab = (m < 16) || (m % 4 != 0); + bool use_tma_epilogue = (m * n) % 4 == 0; + if (!swap_ab) { + constexpr int TILE_N = 128; + int tile_m = 256; + if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) { + tile_m = 64; + } + else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) { + tile_m = 128; + } + if (tile_m == 64) { + if (use_tma_epilogue) { + cutlass_gemm_caller_blockwise, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } + } else if (tile_m == 128) { + if (use_tma_epilogue) { + cutlass_gemm_caller_blockwise, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } + } else { // tile_m == 256 + if (use_tma_epilogue) { + cutlass_gemm_caller_blockwise, Int>, + Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Int>, + Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( + out, a, b, a_scales, b_scales); + } + } + } else { + // TODO: Test more tile N configs + constexpr int TILE_M = 128; + constexpr int TILE_N = 16; + // TMA epilogue isn't compatible with Swap A/B + cutlass_gemm_caller_blockwise, Int, Int>, + Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp new file mode 100644 index 00000000000..2ee6a19407f --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -0,0 +1,75 @@ +#include +#include "cuda_utils.h" +#include "cutlass_extensions/common.hpp" + +template +void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias, + Fp8Func fp8_func, Int8Func int8_func, + BlockwiseFunc blockwise_func) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + int M = a.size(0), N = b.size(1), K = a.size(1); + + if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && + (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { + // Standard per-tensor/per-token/per-channel scaling + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.dtype() == torch::kFloat8_e4m3fn) { + fp8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(a.dtype() == torch::kInt8); + if constexpr (!std::is_same_v) { + int8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(false, "Int8 not supported for this architecture"); + } + } + } else { + TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); + TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); + int32_t version_num = get_sm_version_num(); + if (version_num >= 100) { + TORCH_CHECK( + a.size(0) == a_scales.size(0) && + cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), + "a_scale_group_shape must be [1, 128]."); + TORCH_CHECK( + cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && + cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), + "b_scale_group_shape must be [128, 128]."); + } else { + // TODO: Remove this after using cutlass sm90 blockwise scaling gemm + // kernel, or introducing ceil_div to the load_init() of mainloop. + using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + + // 1x128 per-token group scales for activations + // 128x128 blockwise scales for weights + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" + "a_scale_group_shape must be [1, 128]. Got: [", + a_scale_group_shape[0], ", ", a_scale_group_shape[1], + "]\n" + "b_scale_group_shape must be [128, 128]. Got: [", + b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); + } + + TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); + blockwise_func(c, a, b, a_scales, b_scales); + } +} diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index 85272804774..c1242fdb39d 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 468b77d9593..6da2da63407 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -15,6 +15,7 @@ using c3x::cutlass_gemm_caller; template typename Epilogue> struct sm100_fp8_config_default { + // M in (128, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; @@ -25,6 +26,34 @@ struct sm100_fp8_config_default { KernelSchedule, EpilogueSchedule>; }; +template typename Epilogue> +struct sm100_fp8_config_M128 { + // M in (64, 128] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _128, _64>; + using ClusterShape = Shape<_2, _2, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue> +struct sm100_fp8_config_M64 { + // M in [1, 64] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _8, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + template typename Epilogue, typename... EpilogueArgs> @@ -39,8 +68,28 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm100_fp8_config_default::Cutlass3xGemm; - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); + using Cutlass3xGemmM64 = + typename sm100_fp8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm100_fp8_config_M128::Cutlass3xGemm; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + // m in [1, 64] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } } template