|
# Benchmarking API Guide

This tutorial will guide you through using the TorchAO benchmarking framework. It covers how to integrate new APIs with the benchmarking framework and the CI dashboard.

1. [Add an API to benchmarking recipes](#add-an-api-to-benchmarking-recipes)
2. [Add a model architecture for benchmarking recipes](#add-a-model-to-benchmarking-recipes)
3. [Add an HF model to benchmarking recipes](#add-an-hf-model-to-benchmarking-recipes)
4. [Add an API to micro-benchmarking CI dashboard](#add-an-api-to-benchmarking-ci-dashboard)

## Add an API to Benchmarking Recipes

The framework currently supports quantization and sparsity recipes, which are applied with the `quantize_()` or `sparsify_()` functions.
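
For context, a recipe is just an `AOBaseConfig` passed to one of these functions. Below is a minimal sketch of applying an int8 weight-only recipe directly, outside the benchmarking framework; it assumes a TorchAO version that exposes the config-based API (e.g. `Int8WeightOnlyConfig`).

```python
# Minimal sketch (assumes the config-based TorchAO API is available):
# apply an int8 weight-only recipe to a toy model with quantize_().
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=False)
).to(torch.bfloat16)
quantize_(model, Int8WeightOnlyConfig())
```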

To add a new recipe, add the corresponding string configuration to the function `string_to_config()` in `benchmarks/microbenchmarks/utils.py`.

```python
def string_to_config(
    quantization: Optional[str], sparsity: Optional[str], **kwargs
) -> AOBaseConfig:
    # ... existing code ...

    elif quantization == "my_new_quantization":
        # If additional information needs to be passed as kwargs, process it here
        return MyNewQuantizationConfig(**kwargs)
    elif sparsity == "my_new_sparsity":
        return MyNewSparsityConfig(**kwargs)

    # ... rest of existing code ...
```

Now we can use this recipe throughout the benchmarking framework.

**Note:** If the `AOBaseConfig` takes input parameters, such as bit width or group size, you can pass them appended to the string config. For example, for `GemliteUIntXWeightOnlyConfig` we can pass the bit width and group size as `gemlitewo-<bit_width>-<group_size>`.
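
If you adopt such a suffixed form, the parameters need to be split off the recipe string before the config is constructed. The snippet below is a hypothetical sketch of that parsing step; the helper name `parse_suffixed_recipe` and the `<name>-<bit_width>-<group_size>` layout are illustrative assumptions, not the framework's actual parsing logic.

```python
# Hypothetical helper (not part of the framework): split a suffixed recipe
# string such as "gemlitewo-4-128" into its name and integer parameters.
def parse_suffixed_recipe(recipe: str) -> tuple[str, list[int]]:
    name, *params = recipe.split("-")
    return name, [int(p) for p in params]


name, params = parse_suffixed_recipe("gemlitewo-4-128")
print(name, params)  # gemlitewo [4, 128]
```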

## Add a Model to Benchmarking Recipes

To add a new model architecture to the benchmarking system, you need to modify `torchao/testing/model_architectures.py`.

1. To add a new model type, define your model class in `torchao/testing/model_architectures.py`:

| 42 | +```python |
| 43 | +class MyCustomModel(torch.nn.Module): |
| 44 | + def __init__(self, input_dim, output_dim, dtype=torch.bfloat16): |
| 45 | + super().__init__() |
| 46 | + # Define your model architecture |
| 47 | + self.layer1 = torch.nn.Linear(input_dim, 512, bias=False).to(dtype) |
| 48 | + self.activation = torch.nn.ReLU() |
| 49 | + self.layer2 = torch.nn.Linear(512, output_dim, bias=False).to(dtype) |
| 50 | + |
| 51 | + def forward(self, x): |
| 52 | + x = self.layer1(x) |
| 53 | + x = self.activation(x) |
| 54 | + x = self.layer2(x) |
| 55 | + return x |
| 56 | +``` |

2. Update the `create_model_and_input_data` function to handle your new model type:

| 60 | +```python |
| 61 | +def create_model_and_input_data( |
| 62 | + model_type: str, |
| 63 | + m: int, |
| 64 | + k: int, |
| 65 | + n: int, |
| 66 | + high_precision_dtype: torch.dtype = torch.bfloat16, |
| 67 | + device: str = "cuda", |
| 68 | + activation: str = "relu", |
| 69 | +): |
| 70 | + # ... existing code ... |
| 71 | + |
| 72 | + elif model_type == "my_custom_model": |
| 73 | + model = MyCustomModel(k, n, high_precision_dtype).to(device) |
| 74 | + input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) |
| 75 | + |
| 76 | + # ... rest of existing code ... |
| 77 | +``` |
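
After adding the branch, a quick smoke test can confirm it is wired up correctly. The snippet below assumes `create_model_and_input_data` returns a `(model, input_data)` pair, as the branches above suggest; adjust it if the actual return signature differs.

```python
# Hedged smoke test: build the new model type and run one forward pass.
import torch

from torchao.testing.model_architectures import create_model_and_input_data

model, input_data = create_model_and_input_data(
    "my_custom_model", m=32, k=1024, n=512, device="cpu"
)
with torch.no_grad():
    out = model(input_data)
print(out.shape)  # expected: torch.Size([32, 512])
```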

### Model Design Considerations

When adding new models:

- **Input/Output Dimensions**: Ensure your model handles the (m, k, n) dimension convention (see the sketch after this list), where:
  - `m`: Batch size or sequence length
  - `k`: Input feature dimension
  - `n`: Output feature dimension

- **Data Types**: Support the `high_precision_dtype` parameter (typically `torch.bfloat16`)

- **Device Compatibility**: Ensure your model works on CUDA, CPU, and other target devices

- **Quantization Compatibility**: Design your model to work with TorchAO quantization methods
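
As a quick sanity check of the (m, k, n) convention, instantiate the model with `(k, n)` and feed it an `(m, k)` input; the output should be `(m, n)`. This sketch reuses the `MyCustomModel` class defined above, with arbitrary example sizes.

```python
# Shape check for the (m, k, n) convention, using the MyCustomModel class
# defined above (sizes are arbitrary examples).
import torch

m, k, n = 32, 1024, 512
model = MyCustomModel(k, n, dtype=torch.bfloat16)
x = torch.randn(m, k, dtype=torch.bfloat16)
assert model(x).shape == (m, n)
```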

## Add an HF Model to Benchmarking Recipes
(Coming soon!)

## Add an API to Benchmarking CI Dashboard

To integrate your API with the CI [dashboard](https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fao&benchmarkName=micro-benchmark+api):

### 1. Modify Existing CI Configuration

Add your quantization method to the existing CI configuration file at `benchmarks/dashboard/microbenchmark_quantization_config.yml`:

```yaml
# benchmarks/dashboard/microbenchmark_quantization_config.yml
benchmark_mode: "inference"
quantization_config_recipe_names:
  - "int8wo"
  - "int8dq"
  - "float8dq-tensor"
  - "float8dq-row"
  - "float8wo"
  - "my_new_quantization" # Add your method here

output_dir: "benchmarks/microbenchmarks/results"

model_params:
  - name: "small_bf16_linear"
    matrix_shapes:
      - name: "small_sweep"
        min_power: 10
        max_power: 15
    high_precision_dtype: "torch.bfloat16"
    use_torch_compile: true
    torch_compile_mode: "max-autotune"
    device: "cuda"
    model_type: "linear"
```

### 2. Run CI Benchmarks

Use the CI runner to generate results in the PyTorch OSS benchmark database format:

```bash
python benchmarks/dashboard/ci_microbenchmark_runner.py \
    --config benchmarks/dashboard/microbenchmark_quantization_config.yml \
    --output benchmark_results.json
```

### 3. CI Output Format

The CI runner outputs results in a specific JSON format required by the PyTorch OSS benchmark database:

```json
[
  {
    "benchmark": {
      "name": "micro-benchmark api",
      "mode": "inference",
      "dtype": "int8wo",
      "extra_info": {
        "device": "cuda",
        "arch": "NVIDIA A100-SXM4-80GB"
      }
    },
    "model": {
      "name": "1024-1024-1024",
      "type": "micro-benchmark custom layer",
      "origins": ["torchao"]
    },
    "metric": {
      "name": "speedup(wrt bf16)",
      "benchmark_values": [1.25],
      "target_value": 0.0
    },
    "runners": [],
    "dependencies": {}
  }
]
```
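
To sanity-check the generated file locally before uploading, you can load it with a few lines of Python. This helper is purely illustrative and not part of the framework; it only relies on the format shown above.

```python
# Illustrative only: print one line per benchmark entry in benchmark_results.json,
# assuming the JSON format shown above.
import json

with open("benchmark_results.json") as f:
    results = json.load(f)

for entry in results:
    print(
        entry["benchmark"]["dtype"],
        entry["model"]["name"],
        entry["metric"]["name"],
        entry["metric"]["benchmark_values"],
    )
```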

### 4. Integration with CI Pipeline

To integrate with your CI pipeline, add the benchmark step to your workflow:

```yaml
# Example GitHub Actions step
- name: Run Microbenchmarks
  run: |
    python benchmarks/dashboard/ci_microbenchmark_runner.py \
      --config benchmarks/dashboard/microbenchmark_quantization_config.yml \
      --output benchmark_results.json

- name: Upload Results
  # Upload benchmark_results.json to your dashboard system
```
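
How results get uploaded depends on your dashboard setup. One possibility is to persist the JSON as a workflow artifact with the standard `actions/upload-artifact` action, as sketched below; this is just an example, not the project's actual upload mechanism.

```yaml
# Hypothetical upload step; dashboard ingestion itself is project-specific
# and not shown here.
- name: Upload Results
  uses: actions/upload-artifact@v4
  with:
    name: microbenchmark-results
    path: benchmark_results.json
```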

## Troubleshooting

### Running Tests

To verify your setup and run the test suite:

```bash
python -m unittest discover benchmarks/microbenchmarks/test
```

### Common Issues

1. **CUDA Out of Memory**: Reduce batch size or matrix dimensions
2. **Compilation Errors**: Set `use_torch_compile: false` for debugging
3. **Missing Quantization Methods**: Ensure TorchAO is properly installed
4. **Device Not Available**: Check device availability and drivers

### Best Practices

1. Use `small_sweep` for basic testing and custom shapes for comprehensive or model-specific analysis (see the sketch after this list)
2. Enable profiling only when needed (it adds overhead)
3. Test on multiple devices when possible
4. Use consistent naming conventions for reproducibility
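
For model-specific analysis, you can replace the `small_sweep` entry with explicit shapes. The field names below (`name: "custom"` with a `shapes` list of `[m, k, n]` triples) are assumptions based on the config shown earlier; verify them against the README in `benchmarks/microbenchmarks/` before use.

```yaml
# Hypothetical custom-shape entry; check the microbenchmarks README for the
# exact schema before using.
matrix_shapes:
  - name: "custom"
    shapes:
      - [1024, 1024, 1024]  # [m, k, n]
      - [2048, 4096, 1024]
```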

For information on different use cases for benchmarking, refer to the [Benchmarking User Guide](benchmarking_user_guide.md).

For more detailed information about the framework components, see the README files in the `benchmarks/microbenchmarks/` directory.