
Commit 8401e91

Add benchmark numbers to dashboard (#2260)
1 parent 0da65f8 · commit 8401e91

File tree: 3 files changed (+250, -0 lines)

New file: GitHub Actions workflow (Microbenchmarks-Perf-Nightly) — 70 additions & 0 deletions

name: Microbenchmarks-Perf-Nightly
# Dashboard: https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fao&benchmarkName=micro-benchmark+api

on:
  pull_request:
  push:
    tags:
      - ciflow/benchmark/*
  workflow_dispatch:
  schedule:
    - cron: '0 3 * * *' # Run daily at 3 AM UTC

jobs:
  benchmark:
    runs-on: linux.aws.h100
    strategy:
      matrix:
        torch-spec:
          - '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126'
    steps:
      - uses: actions/checkout@v4

      - name: Setup miniconda
        uses: pytorch/test-infra/.github/actions/setup-miniconda@main
        with:
          python-version: "3.9"

      - name: Run benchmark
        shell: bash
        run: |
          set -eux

          # Upgrade pip
          ${CONDA_RUN} python -m pip install --upgrade pip

          ${CONDA_RUN} ls
          ${CONDA_RUN} bash -c 'pwd'
          ${CONDA_RUN} bash -c 'echo $PYTHONPATH'

          # Install dependencies
          ${CONDA_RUN} pip install ${{ matrix.torch-spec }}
          ${CONDA_RUN} pip install -r dev-requirements.txt
          ${CONDA_RUN} pip install .

          ${CONDA_RUN} ls
          ${CONDA_RUN} bash -c 'pwd'
          ${CONDA_RUN} bash -c 'echo $PYTHONPATH'

          # Set PYTHONPATH to the current directory (.) if not already set, and include the benchmarks directory
          ${CONDA_RUN} export PYTHONPATH="${PYTHONPATH:-$(pwd)}:$(pwd)/benchmarks"

          # Create benchmark results directory
          mkdir -p ${{ runner.temp }}/benchmark-results

          # Run microbenchmarks for dashboard
          ${CONDA_RUN} bash -c '
            export PYTHONPATH="${PYTHONPATH:-$(pwd)}:$(pwd)/benchmarks"
            echo "PYTHONPATH is: $PYTHONPATH"
            echo "Current directory is: $(pwd)"
            python benchmarks/dashboard/ci_microbenchmark_runner.py \
              --config benchmarks/dashboard/microbenchmark_quantization_config.yml \
              --output "$RUNNER_TEMP/benchmark-results/microbenchmark-results.json"'

      - name: Upload the benchmark results to the OSS benchmark database for the dashboard
        uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
        with:
          benchmark-results-dir: ${{ runner.temp }}/benchmark-results
          dry-run: false
          schema-version: v3
          github-token: ${{ secrets.GITHUB_TOKEN }}
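
For local testing, the "Run benchmark" step above can be reproduced outside CI. The following is a minimal sketch (hypothetical, not part of this commit) that assumes it is run from the torchao repository root with the package and dev-requirements already installed and a CUDA device available:

# Hypothetical local reproduction of the CI "Run benchmark" step; it mirrors
# the command and PYTHONPATH setup from the workflow above.
import os
import subprocess
import sys

os.makedirs("benchmark-results", exist_ok=True)
env = dict(os.environ)
env["PYTHONPATH"] = f"{os.getcwd()}:{os.getcwd()}/benchmarks"
subprocess.run(
    [
        sys.executable,
        "benchmarks/dashboard/ci_microbenchmark_runner.py",
        "--config", "benchmarks/dashboard/microbenchmark_quantization_config.yml",
        "--output", "benchmark-results/microbenchmark-results.json",
    ],
    check=True,
    env=env,
)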
New file: ci_microbenchmark_runner.py — 160 additions & 0 deletions

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
CI Microbenchmark Runner for PyTorch OSS Benchmark Database

This script runs microbenchmarks for a given config file
and outputs results in the format required by the PyTorch OSS benchmark database.
It reuses functionality from benchmark_runner.py and only adds CI-specific code.

Usage:
    python ci_microbenchmark_runner.py --config benchmark_config.yml

The YAML file should contain all necessary configuration parameters for the benchmarks.
"""

import argparse
import json
import platform
from typing import Any, Dict, List

import torch

from benchmarks.microbenchmarks.benchmark_inference import run as run_inference
from benchmarks.microbenchmarks.benchmark_runner import (
    load_benchmark_configs,
)
from benchmarks.microbenchmarks.utils import clean_caches


def create_benchmark_result(
    benchmark_name: str,
    shape: List[int],
    metric_name: str,
    metric_values: List[float],
    quant_type: str,
    device: str,
) -> Dict[str, Any]:
    """Create a benchmark result in the PyTorch OSS benchmark database format.

    Args:
        benchmark_name: Name of the benchmark
        shape: List of shape dimensions [M, K, N]
        metric_name: Name of the metric
        metric_values: List of metric values
        quant_type: Quantization type
        device: Device type (cuda/cpu)

    Returns:
        Dictionary containing the benchmark result in the required format
    """
    print(
        f"Creating benchmark result for {benchmark_name} with shape {shape} and metric {metric_name}"
    )

    # Map device to benchmark device name
    benchmark_device = (
        torch.cuda.get_device_name(0)
        if device == "cuda"
        else platform.processor()
        if device == "cpu"
        else "unknown"
    )

    # Format shape as M-K-N
    mkn_name = f"{shape[0]}-{shape[1]}-{shape[2]}" if len(shape) == 3 else "unknown"

    return {
        "benchmark": {
            "name": "micro-benchmark api",
            "mode": "inference",
            "dtype": quant_type,
            "extra_info": {
                "device": device,
                "arch": benchmark_device,
            },
        },
        "model": {
            "name": mkn_name,  # name in M-K-N format
            "type": "micro-benchmark custom layer",  # type
            "origins": ["torchao"],
        },
        "metric": {
            "name": f"{metric_name}(wrt bf16)",  # name with unit
            "benchmark_values": metric_values,  # benchmark_values
            "target_value": 0.0,  # TODO: Will need to define the target value
        },
        "runners": [],
        "dependencies": {},
    }


def run_ci_benchmarks(config_path: str) -> List[Dict[str, Any]]:
    """Run benchmarks using configurations from YAML file and return results in OSS format.

    Args:
        config_path: Path to the benchmark configuration file

    Returns:
        List of benchmark results in the PyTorch OSS benchmark database format
    """
    # Load configuration using existing function
    configs = load_benchmark_configs(argparse.Namespace(config=config_path))
    results = []

    # Run benchmarks for each config
    for config in configs:
        # Run benchmark using existing function
        clean_caches()
        result = run_inference(config)

        if result is not None:
            # Create benchmark result in OSS format
            benchmark_result = create_benchmark_result(
                benchmark_name="TorchAO Quantization Benchmark",
                shape=[config.m, config.k, config.n],
                metric_name="speedup",
                metric_values=[result.speedup],
                quant_type=config.quantization,
                device=config.device,
            )
            results.append(benchmark_result)

    return results


def main():
    parser = argparse.ArgumentParser(
        description="Run microbenchmarks and output results in PyTorch OSS benchmark database format"
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to benchmark configuration file",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="benchmark_results.json",
        help="Path to output JSON file",
    )
    args = parser.parse_args()

    # Run benchmarks
    results = run_ci_benchmarks(args.config)

    # Save results to JSON file
    with open(args.output, "w") as f:
        json.dump(results, f, indent=2)

    print(f"Benchmark results saved to {args.output}")


if __name__ == "__main__":
    main()
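
Each element returned by run_ci_benchmarks() follows the dictionary layout built in create_benchmark_result(), so the JSON written by main() can be inspected with standard tooling. A minimal sketch (illustrative only, not part of this commit; the results path is a placeholder) that prints the recorded speedup per M-K-N shape:

# Illustrative reader for the results JSON written by main(); field names
# follow create_benchmark_result() above. The input path is a placeholder.
import json

with open("microbenchmark-results.json") as f:
    results = json.load(f)

for entry in results:
    shape = entry["model"]["name"]                # "M-K-N", e.g. "1024-1024-1024"
    quant = entry["benchmark"]["dtype"]           # quantization recipe, e.g. "int8wo"
    metric = entry["metric"]["name"]              # "speedup(wrt bf16)"
    values = entry["metric"]["benchmark_values"]
    print(f"{shape} [{quant}] {metric}: {values}")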
New file: microbenchmark_quantization_config.yml — 20 additions & 0 deletions

# Benchmark configuration for microbenchmarks
benchmark_mode: "inference"
quantization_config_recipe_names: # Will run a baseline inference for model by default, without quantization for comparison
  - "int8wo"
  - "int8dq"
  - "float8dq-tensor"
  - "float8dq-row"
  - "float8wo"
output_dir: "benchmarks/microbenchmarks/results"
model_params:
  - name: "small_bf16_linear"
    matrix_shapes:
      - name: "small_sweep"
        min_power: 10
        max_power: 15
    high_precision_dtype: "torch.bfloat16"
    use_torch_compile: true
    torch_compile_mode: "max-autotune"
    device: "cuda"
    model_type: "linear"
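
The runner consumes this file through load_benchmark_configs(), but the raw keys can also be read directly; a minimal sketch (assumes PyYAML is installed; illustrative only, not part of this commit):

# Illustrative: inspect the quantization recipes and shape-sweep bounds
# straight from the YAML above (assumes PyYAML is available).
import yaml

with open("benchmarks/dashboard/microbenchmark_quantization_config.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["benchmark_mode"])                    # "inference"
print(cfg["quantization_config_recipe_names"])  # ["int8wo", "int8dq", ...]
sweep = cfg["model_params"][0]["matrix_shapes"][0]
print(sweep["name"], sweep["min_power"], sweep["max_power"])  # small_sweep 10 15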
