-
Notifications
You must be signed in to change notification settings - Fork 305
Tutorial for benchmarking #2499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
e7b20cc
a6a2ae0
cde732c
9f994f1
ed6f659
41b9986
b8564ca
98eef30
8c05b7f
17a11a0
7df5b33
24eeb0f
5c94cee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
# Benchmarking Overview | ||
|
||
This tutorial will guide you through using the TorchAO benchmarking framework. The tutorial contains integrating new APIs with the framework and dashboard. | ||
|
||
1. [Add an API to benchmarking recipes](#add-an-api-to-benchmarking-recipes) | ||
2. [Add a model architecture for benchmarking recipes](#add-a-model-to-benchmarking-recipes) | ||
3. [Add an HF model to benchmarking recipes](#add-an-hf-model-to-benchmarking-recipes) | ||
4. [Add an API to micro-benchmarking CI dashboard](#add-an-api-to-benchmarking-ci-dashboard) | ||
|
||
## Add an API to Benchmarking Recipes | ||
|
||
The framework currently supports quantization and sparsity recipes, which can be run using the quantize_() or sparsity_() functions: | ||
|
||
To add a new recipe, add the corresponding string configuration to the function `string_to_config()` in `benchmarks/microbenchmarks/utils.py`. | ||
|
||
```python | ||
def string_to_config( | ||
quantization: Optional[str], sparsity: Optional[str], **kwargs | ||
) -> AOBaseConfig: | ||
|
||
# ... existing code ... | ||
|
||
elif quantization == "my_new_quantization": | ||
# If additional information needs to be passed as kwargs, process it here | ||
return MyNewQuantizationConfig(**kwargs) | ||
elif sparsity == "my_new_sparsity": | ||
return MyNewSparsityConfig(**kwargs) | ||
|
||
# ... rest of existing code ... | ||
``` | ||
|
||
Now we can use this recipe throughout the benchmarking framework. | ||
|
||
> **Note:** If the `AOBaseConfig` uses input parameters, like bit-width, group-size etc, you can pass them appended to the string config in input. For example, for `GemliteUIntXWeightOnlyConfig` we can pass bit-width and group-size as `gemlitewo-<bit_width>-<group_size>` | ||
|
||
## Add a Model to Benchmarking Recipes | ||
|
||
To add a new model architecture to the benchmarking system, you need to modify `torchao/testing/model_architectures.py`. | ||
|
||
1. To add a new model type, define your model class in `torchao/testing/model_architectures.py`: | ||
|
||
```python | ||
class MyCustomModel(torch.nn.Module): | ||
def __init__(self, input_dim, output_dim, dtype=torch.bfloat16): | ||
super().__init__() | ||
# Define your model architecture | ||
self.layer1 = torch.nn.Linear(input_dim, 512, bias=False).to(dtype) | ||
self.activation = torch.nn.ReLU() | ||
self.layer2 = torch.nn.Linear(512, output_dim, bias=False).to(dtype) | ||
|
||
def forward(self, x): | ||
x = self.layer1(x) | ||
x = self.activation(x) | ||
x = self.layer2(x) | ||
return x | ||
``` | ||
|
||
2. Update the `create_model_and_input_data` function to handle your new model type: | ||
|
||
```python | ||
def create_model_and_input_data( | ||
model_type: str, | ||
m: int, | ||
k: int, | ||
n: int, | ||
high_precision_dtype: torch.dtype = torch.bfloat16, | ||
device: str = "cuda", | ||
activation: str = "relu", | ||
): | ||
# ... existing code ... | ||
|
||
elif model_type == "my_custom_model": | ||
model = MyCustomModel(k, n, high_precision_dtype).to(device) | ||
input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) | ||
|
||
# ... rest of existing code ... | ||
``` | ||
|
||
### Model Design Considerations | ||
|
||
When adding new models: | ||
|
||
- **Input/Output Dimensions**: Ensure your model handles the (m, k, n) dimension convention where: | ||
- `m`: Batch size or sequence length | ||
- `k`: Input feature dimension | ||
- `n`: Output feature dimension | ||
|
||
- **Data Types**: Support the `high_precision_dtype` parameter (typically `torch.bfloat16`) | ||
|
||
- **Device Compatibility**: Ensure your model works on CUDA, CPU, and other target devices | ||
|
||
- **Quantization Compatibility**: Design your model to work with TorchAO quantization methods | ||
|
||
## Add an HF model to benchmarking recipes | ||
(Coming soon!!!) | ||
|
||
## Add an API to Benchmarking CI Dashboard | ||
|
||
To integrate your API with the CI [dashboard](https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fao&benchmarkName=micro-benchmark+api): | ||
|
||
### 1. Modify Existing CI Configuration | ||
|
||
Add your quantization method to the existing CI configuration file at `benchmarks/dashboard/microbenchmark_quantization_config.yml`: | ||
|
||
```yaml | ||
# benchmarks/dashboard/microbenchmark_quantization_config.yml | ||
benchmark_mode: "inference" | ||
quantization_config_recipe_names: | ||
- "int8wo" | ||
- "int8dq" | ||
- "float8dq-tensor" | ||
- "float8dq-row" | ||
- "float8wo" | ||
- "my_new_quantization" # Add your method here | ||
|
||
output_dir: "benchmarks/microbenchmarks/results" | ||
|
||
model_params: | ||
- name: "small_bf16_linear" | ||
matrix_shapes: | ||
- name: "small_sweep" | ||
min_power: 10 | ||
max_power: 15 | ||
high_precision_dtype: "torch.bfloat16" | ||
use_torch_compile: true | ||
torch_compile_mode: "max-autotune" | ||
device: "cuda" | ||
model_type: "linear" | ||
``` | ||
|
||
### 2. Run CI Benchmarks | ||
|
||
Use the CI runner to generate results in PyTorch OSS benchmark database format: | ||
|
||
```bash | ||
python benchmarks/dashboard/ci_microbenchmark_runner.py \ | ||
--config benchmarks/dashboard/microbenchmark_quantization_config.yml \ | ||
--output benchmark_results.json | ||
``` | ||
|
||
### 3. CI Output Format | ||
|
||
The CI runner outputs results in a specific JSON format required by the PyTorch OSS benchmark database: | ||
|
||
```json | ||
[ | ||
{ | ||
"benchmark": { | ||
"name": "micro-benchmark api", | ||
"mode": "inference", | ||
"dtype": "int8wo", | ||
"extra_info": { | ||
"device": "cuda", | ||
"arch": "NVIDIA A100-SXM4-80GB" | ||
} | ||
}, | ||
"model": { | ||
"name": "1024-1024-1024", | ||
"type": "micro-benchmark custom layer", | ||
"origins": ["torchao"] | ||
}, | ||
"metric": { | ||
"name": "speedup(wrt bf16)", | ||
"benchmark_values": [1.25], | ||
"target_value": 0.0 | ||
}, | ||
"runners": [], | ||
"dependencies": {} | ||
} | ||
] | ||
``` | ||
|
||
### 4. Integration with CI Pipeline | ||
|
||
To integrate with your CI pipeline, add the benchmark step to your workflow: | ||
|
||
```yaml | ||
# Example GitHub Actions step | ||
- name: Run Microbenchmarks | ||
run: | | ||
python benchmarks/dashboard/ci_microbenchmark_runner.py \ | ||
--config benchmarks/dashboard/microbenchmark_quantization_config.yml \ | ||
--output benchmark_results.json | ||
|
||
- name: Upload Results | ||
# Upload benchmark_results.json to your dashboard system | ||
``` | ||
|
||
## Troubleshooting | ||
|
||
### Running Tests | ||
|
||
To verify your setup and run the test suite: | ||
|
||
```bash | ||
python -m unittest discover benchmarks/microbenchmarks/test | ||
``` | ||
|
||
### Common Issues | ||
|
||
1. **CUDA Out of Memory**: Reduce batch size or matrix dimensions | ||
2. **Compilation Errors**: Set `use_torch_compile: false` for debugging | ||
3. **Missing Quantization Methods**: Ensure TorchAO is properly installed | ||
4. **Device Not Available**: Check device availability and drivers | ||
|
||
### Best Practices | ||
|
||
1. Use `small_sweep` for basic testing, `custom shapes` for comprehensive or model specific analysis | ||
2. Enable profiling only when needed (adds overhead) | ||
3. Test on multiple devices when possible | ||
4. Use consistent naming conventions for reproducibility | ||
|
||
For information on different use-cases for benchmarking, refer to [Benchmarking Use-Case FAQs](benchmarking_user_faq.md) | ||
|
||
For more detailed information about the framework components, see the README files in the `benchmarks/microbenchmarks/` directory. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# Benchmarking Use-Case FAQs | ||
|
||
This guide is intended to provide instructions for the most fequent benchmarking use-case. If you have any use-case that is not answered here, please create an issue here: [TorchAO Issues](https://github.com/pytorch/ao/issues) | ||
|
||
## Table of Contents | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is closer to use case but still not really end use cases yet I feel, I think it might be helpful to describe scenarios like: (2) kernel optimizations (3) performance regression tracking (4) end users There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will be addressed in PR #2512 |
||
- [Run the performance benchmarking on your PR](#run-the-performance-benchmarking-on-your-pr) | ||
- [Benchmark Your API Locally](#benchmark-your-api-locally) | ||
- [Generate evaluation metrics for your quantized model](#generate-evaluation-metrics-for-your-quantized-model) | ||
- [Advanced Usage](#advanced-usage) | ||
|
||
## Run the performance benchmarking on your PR | ||
|
||
### 1. Add label to your PR | ||
To trigger the benchmarking CI workflow on your pull request, you need to add a specific label to your PR. Follow these steps: | ||
|
||
1. Go to your pull request on GitHub. | ||
2. On the right sidebar, find the "Labels" section. | ||
3. Click on the "Labels" dropdown and select "ciflow/benchmark" from the list of available labels. | ||
|
||
Adding this label will automatically trigger the benchmarking CI workflow for your pull request. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add where can we see the results as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also will it run after each new commit is added or if the commit is updated? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we add label, it'll run for every commit we add to the PR |
||
|
||
### 2. Manually trigger benchmarking workflow on your github branch | ||
To manually trigger the benchmarking workflow for your branch, follow these steps: | ||
|
||
1. Navigate to the "Actions" tab in your GitHub repository. | ||
2. Select the benchmarking workflow from the list of available workflows. For microbenchmarks, it's `Microbenchmarks-Perf-Nightly`. | ||
3. Click on the "Run workflow" button. | ||
4. In the dropdown menu, select the branch. | ||
5. Click the "Run workflow" button to start the benchmarking process. | ||
|
||
This will execute the benchmarking workflow on the specified branch, allowing you to evaluate the performance of your changes. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens when people: (1). push a new commit ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also when do we use (1) and when do we use (2)? seems like they are doing the same thing, if so maybe just keep one is enough There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will run only for the latest change on the branch, it won't trigger automatically on every commit, for that we'll need to add label to the PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in PR: #2512 |
||
|
||
## Benchmark Your API Locally | ||
|
||
For local development and testing: | ||
|
||
### 1. Quick Start | ||
|
||
Create a minimal configuration for local testing: | ||
|
||
```yaml | ||
# local_test.yml | ||
benchmark_mode: "inference" | ||
quantization_config_recipe_names: | ||
- "baseline" | ||
- "int8wo" | ||
# Add your recipe here | ||
|
||
output_dir: "local_results" # Add your output directory here | ||
|
||
model_params: | ||
# Add your model configurations here | ||
- name: "quick_test" | ||
matrix_shapes: | ||
# Define a custom shape, or use one of the predefined shape generators | ||
- name: "custom" | ||
shapes: [[1024, 1024, 1024]] | ||
- name: "small_sweep" | ||
high_precision_dtype: "torch.bfloat16" | ||
use_torch_compile: true | ||
torch_compile_mode: "max-autotune" | ||
device: "cuda" | ||
model_type: "linear" | ||
enable_profiler: true # Enable profiling for this model | ||
enable_memory_profiler: true # Enable memory profiling for this model | ||
``` | ||
|
||
> **Note:** | ||
> - For a list of latest supported config recipes for quantization or sparsity, please refer to `benchmarks/microbenchmarks/README.md`. | ||
> - For a list of all model types, please refer to `torchao/testing/model_architectures.py`. | ||
|
||
### 2. Run Local Benchmark | ||
|
||
```bash | ||
python -m benchmarks.microbenchmarks.benchmark_runner --config local_test.yml | ||
``` | ||
|
||
### 3. Analysing the Output | ||
|
||
The output generated after running the benchmarking script, is the form of a csv. It'll contain some of the following: | ||
- time for inference for running baseline model and quantized model | ||
- speedup in inference time in quantized model | ||
- compile or eager mode | ||
- if enabled, memory snapshot and gpu chrome trace | ||
|
||
|
||
## Generate evaluation metrics for your quantized model | ||
(Coming soon!!!) | ||
|
||
## Advanced Usage | ||
|
||
### Multiple Model Configurations | ||
|
||
You can benchmark multiple model configurations in a single run: | ||
|
||
```yaml | ||
model_params: | ||
- name: "small_models" | ||
matrix_shapes: | ||
- name: "pow2" | ||
min_power: 10 | ||
max_power: 12 | ||
model_type: "linear" | ||
device: "cuda" | ||
|
||
- name: "transformer_models" | ||
matrix_shapes: | ||
- name: "llama" | ||
model_type: "transformer_block" | ||
device: "cuda" | ||
|
||
- name: "cpu_models" | ||
matrix_shapes: | ||
- name: "custom" | ||
shapes: [[512, 512, 512]] | ||
model_type: "linear" | ||
device: "cpu" | ||
``` | ||
|
||
### Interpreting Results | ||
|
||
The benchmark results include: | ||
|
||
- **Speedup**: Performance improvement compared to baseline (bfloat16) | ||
- **Memory Usage**: Peak memory consumption during inference | ||
- **Latency**: Time taken for inference operations | ||
- **Profiling Data**: Detailed performance traces (when enabled) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there some duplicationg between these and L80-L84 |
||
|
||
Results are saved in CSV format with columns for: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I think you can just show a small example output here |
||
|
||
- Model configuration | ||
- Quantization method | ||
- Shape dimensions (M, K, N) | ||
- Performance metrics | ||
- Memory metrics | ||
- Device information | ||
|
||
### Best Practices | ||
|
||
1. Use `small_sweep` for initial testing, `sweep` for comprehensive analysis | ||
2. Enable profiling only when needed (adds overhead) | ||
3. Test on multiple devices when possible |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also
API Guide
might be more accurate