|
| 1 | +# TorchAO Integration with VLLM: Architecture and Usage Guide |
| 2 | + |
| 3 | +This tutorial provides a comprehensive overview of how TorchAO integrates with VLLM, and what needs to be implemented to have a new technique work E2E. |
| 4 | + |
| 5 | +## Table of Contents |
| 6 | +* [Configuration System](#configuration-system) |
| 7 | +* [Usage Examples](#usage-examples) |
| 8 | +* [Adding New Quantization Methods to VLLM](#adding-new-quantization-methods-to-vllm) |
| 9 | +* [Step-by-Step Guide to Add a New Quantization Method](#step-by-step-guide-to-add-a-new-quantization-method) |
| 10 | +* [Serialization and Model Sharing](#serialization-and-model-sharing) |
| 11 | +* [Integration Architecture Diagrams](#integration-architecture-diagrams) |
| 12 | + |
| 13 | + |
| 14 | +## Configuration System |
| 15 | + |
| 16 | +### 1. HuggingFace Model Configuration |
| 17 | + |
| 18 | +TorchAO quantization is configured through the model's `config.json` file: |
| 19 | + |
| 20 | +```json |
| 21 | +{ |
| 22 | + "model_type": "llama", |
| 23 | + "quant_type": { |
| 24 | + "default": { |
| 25 | + "_type": "Int4WeightOnlyConfig", |
| 26 | + "_data": { |
| 27 | + "group_size": 128, |
| 28 | + "use_hqq": true |
| 29 | + } |
| 30 | + } |
| 31 | + } |
| 32 | +} |
| 33 | +``` |
| 34 | + |
| 35 | +### 2. TorchAO Configuration Classes |
| 36 | + |
| 37 | +All quantization methods inherit from `AOBaseConfig`: |
| 38 | + |
| 39 | +```python |
| 40 | +from torchao.core.config import AOBaseConfig |
| 41 | +from torchao.quantization import Int4WeightOnlyConfig |
| 42 | + |
| 43 | +# Example configuration |
| 44 | +config = Int4WeightOnlyConfig( |
| 45 | + group_size=128, |
| 46 | + use_hqq=True, |
| 47 | +) |
| 48 | +assert isinstance(config, AOBaseConfig) |
| 49 | +``` |
| 50 | + |
| 51 | +### 3. Module-Level Configuration |
| 52 | + |
| 53 | +For granular control, use `ModuleFqnToConfig`: |
| 54 | + |
| 55 | +```python |
| 56 | +from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig |
| 57 | + |
| 58 | +config = ModuleFqnToConfig({ |
| 59 | + "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64), |
| 60 | + "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64), |
| 61 | + "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(), |
| 62 | + "_default": Int4WeightOnlyConfig(group_size=128) # Default for other modules |
| 63 | +}) |
| 64 | +``` |
| 65 | + |
| 66 | +## Usage Examples |
| 67 | + |
| 68 | +### 1. Quantizing Models with HuggingFace Integration |
| 69 | + |
| 70 | +```python |
| 71 | +from transformers import TorchAoConfig, AutoModelForCausalLM |
| 72 | +from torchao.quantization import Int4WeightOnlyConfig |
| 73 | + |
| 74 | +# Create quantization configuration |
| 75 | +quantization_config = TorchAoConfig( |
| 76 | + quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True) |
| 77 | +) |
| 78 | + |
| 79 | +# Load and automatically quantize the model |
| 80 | +model = AutoModelForCausalLM.from_pretrained( |
| 81 | + "meta-llama/Llama-3.2-1B", |
| 82 | + torch_dtype="auto", |
| 83 | + device_map="auto", |
| 84 | + quantization_config=quantization_config |
| 85 | +) |
| 86 | + |
| 87 | +# Save quantized model (see Serialization section below for safe_serialization details) |
| 88 | +model.push_to_hub("your-username/Llama-3.2-1B-int4", safe_serialization=False) |
| 89 | +``` |
| 90 | + |
| 91 | +### 2. Serving with VLLM |
| 92 | + |
| 93 | +```bash |
| 94 | +# Start VLLM server with TorchAO quantized model |
| 95 | +vllm serve your-username/Llama-3.2-1B-int4 \ |
| 96 | + --quantization torchao \ |
| 97 | + --dtype bfloat16 \ |
| 98 | +``` |
| 99 | + |
| 100 | + |
| 101 | +## Adding New Quantization Methods to VLLM |
| 102 | + |
| 103 | +### Minimal Requirements for VLLM Compatibility |
| 104 | + |
| 105 | +To make a new TorchAO quantization method work with VLLM, you need to implement minimal tensor subclass operations that support **tensor parallelism**. VLLM uses `narrow()` and copy_ to move data from host cpu loaded in a state dict to the device, these require these specific aten operations: |
| 106 | + |
| 107 | +### Why these ? |
| 108 | + |
| 109 | +VLLM's tensor parallelism works by: |
| 110 | +1. **`narrow()`** - Slicing weight tensors across different dimensions |
| 111 | +2. **Sharding** - Distributing tensor chunks across multiple GPUs |
| 112 | +3. **`copy_()`** - Moving tensor data between devices |
| 113 | +4. **`detach()`** |
| 114 | + |
| 115 | + |
| 116 | +A helpful pattern for doing this is `_apply_fn_to_data`, a method that applies a given function to all the attributes on your class w/ Tensor types. Below is a generic implementation that should work for most subclasses. We make heavy use of this, you can see an examples in [mx_ops.py](../torchao/prototype/mx_formats/mx_ops.py): |
| 117 | + |
| 118 | +```python |
| 119 | +def _apply_fn_to_data(self, fn: Callable): |
| 120 | + """Applies a fn to all tensor components stored on this class""" |
| 121 | + tensor_names, ctx = self.__tensor_flatten__() |
| 122 | + |
| 123 | + # Apply the function to each tensor component |
| 124 | + new_tensors = {} |
| 125 | + for name in tensor_names: |
| 126 | + new_tensors[name] = fn(getattr(self, name)) |
| 127 | + |
| 128 | + return self.__class__.__tensor_unflatten__( |
| 129 | + new_tensors, |
| 130 | + ctx, |
| 131 | + None, # outer_size parameter |
| 132 | + None, # outer_stride parameter |
| 133 | + ) |
| 134 | +``` |
| 135 | + |
| 136 | +## Step-by-Step Guide to Add a New Quantization Method |
| 137 | + |
| 138 | +#### 1. Create Your Tensor Subclass |
| 139 | + |
| 140 | +```python |
| 141 | +from torchao.core.config import AOBaseConfig |
| 142 | +from torchao.utils import TorchAOBaseTensor |
| 143 | + |
| 144 | +@dataclass |
| 145 | +class MyNewQuantConfig(AOBaseConfig): |
| 146 | + """Configuration for your new quantization method""" |
| 147 | + bits: int = 8 |
| 148 | + VERSION: ClassVar[int] = 1 |
| 149 | + |
| 150 | +class MyQuantizedTensor(TorchAOBaseTensor): |
| 151 | + """Example based on FbgemmFp8Tensor - stores quantized data + scale""" |
| 152 | + |
| 153 | + tensor_data_attrs = ["quantized_data", "scale"] |
| 154 | + tensor_attributes = ["dtype"] |
| 155 | + |
| 156 | + def __new__(cls, quantized_data, scale, dtype): |
| 157 | + shape = quantized_data.shape |
| 158 | + return torch.Tensor._make_wrapper_subclass( |
| 159 | + cls, shape, device=quantized_data.device, dtype=dtype, requires_grad=False |
| 160 | + ) |
| 161 | + |
| 162 | + def __init__(self, quantized_data, scale, dtype): |
| 163 | + self.quantized_data = quantized_data |
| 164 | + self.scale = scale |
| 165 | + |
| 166 | + def __tensor_flatten__(self) -> Tuple[List[str], List]: |
| 167 | + """Serialize tensor subclass into plain tensors and metadata""" |
| 168 | + return self.tensor_data_attrs, [getattr(self, attr) for attr in self.tensor_attributes] |
| 169 | + |
| 170 | + @classmethod |
| 171 | + def __tensor_unflatten__(cls, tensor_data_dict: Dict[str, torch.Tensor], tensor_attributes: List, |
| 172 | + outer_size: Optional[torch.Size], outer_stride: Optional[Tuple]) -> "MyQuantizedTensor": |
| 173 | + """Reconstruct tensor subclass from serialized data""" |
| 174 | + return cls( |
| 175 | + *[tensor_data_dict[name] for name in cls.tensor_data_attrs], |
| 176 | + *tensor_attributes, |
| 177 | + ) |
| 178 | +``` |
| 179 | + |
| 180 | +#### 2. Implement Required VLLM Operations |
| 181 | + |
| 182 | +```python |
| 183 | +from torch.utils._python_dispatch import return_and_correct_aliasing |
| 184 | + |
| 185 | +@MyQuantizedTensor.implements([aten.detach.default, aten.alias.default]) |
| 186 | +def _(func, types, args, kwargs): |
| 187 | + return return_and_correct_aliasing( |
| 188 | + func, args, kwargs, args[0]._apply_fn_to_data(func) |
| 189 | + ) |
| 190 | + |
| 191 | +@MyQuantizedTensor.implements([aten._to_copy.default]) |
| 192 | +def _(func, types, args, kwargs): |
| 193 | + return return_and_correct_aliasing( |
| 194 | + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) |
| 195 | + ) |
| 196 | + |
| 197 | +@MyQuantizedTensor.implements([aten.slice.Tensor]) |
| 198 | +def _(func, types, args, kwargs): |
| 199 | + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) |
| 200 | + if dim == 0 or dim == 1: |
| 201 | + # NOTE the slicing here will likely be different for different quant techniques |
| 202 | + return return_and_correct_aliasing( |
| 203 | + func, args, kwargs, |
| 204 | + args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) |
| 205 | + ) |
| 206 | + else: |
| 207 | + raise NotImplementedError(f"Slicing along dim={dim} not supported") |
| 208 | +``` |
| 209 | + |
| 210 | +#### 3. Register with TorchAO's Quantization System |
| 211 | + |
| 212 | +```python |
| 213 | +from torchao.quantization.transform_module import register_quantize_module_handler |
| 214 | + |
| 215 | +@register_quantize_module_handler(MyNewQuantConfig) |
| 216 | +def _my_quant_transform(module: torch.nn.Module, config: MyNewQuantConfig): |
| 217 | + """Transform function that applies your quantization to a module""" |
| 218 | + weight = module.weight |
| 219 | + |
| 220 | + # Your quantization logic here |
| 221 | + quantized_weight = my_quantization_function(weight, config) |
| 222 | + |
| 223 | + # Replace the weight with your quantized tensor |
| 224 | + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) |
| 225 | + return module |
| 226 | +``` |
| 227 | + |
| 228 | +### Key Implementation Details |
| 229 | + |
| 230 | +#### Hardware-Specific Linear Operations |
| 231 | +Your quantized tensor's forward pass determines hardware support and what actually gets called when F.Linear is called. |
| 232 | + |
| 233 | +```python |
| 234 | +@MyQuantizedTensor.implements(torch.nn.functional.linear) |
| 235 | +def _(func, types, args, kwargs): |
| 236 | + input_tensor, weight_tensor, bias = args[0], args[1], args[2] if len(args) > 2 else None |
| 237 | + |
| 238 | + # This is where you define what hardware your method supports |
| 239 | + if hasattr(weight_tensor, 'use_cutlass_kernel'): |
| 240 | + return my_cutlass_linear(input_tensor, weight_tensor, bias) |
| 241 | + elif hasattr(weight_tensor, 'use_triton_kernel'): |
| 242 | + return my_triton_linear(input_tensor, weight_tensor, bias) |
| 243 | + else: |
| 244 | + # Fallback - dequantize and use standard linear |
| 245 | + return torch.nn.functional.linear( |
| 246 | + input_tensor, weight_tensor.dequantize(), bias |
| 247 | + ) |
| 248 | +``` |
| 249 | + |
| 250 | +#### Compilation Benefits |
| 251 | +The overhead of tensor subclasses disappears with `torch.compile()`, this is on by default in VLLM. |
| 252 | + |
| 253 | +### Trade Off of Tensor Subclasses |
| 254 | +1. **Compilation**: is essential for removing subclass overhead. Without it unless your model is extremely gpu bound the overhead of dispatch on the CPU can severely impact performance. |
| 255 | +2. The checkpoint defines the behavior of the model. You might be saying "don't all checkpoints do this". This is true, however people typically solely think of a torch.Tensor as its data. When in actuality its a true class where it brings the Dispatcher and all the kernels ATen has registered to it. When you define your tensor subclass, you are building a separate little world. One w/ a different representation of data, but also one where you need to explicitly define what ops you support and have implementations for all the hardware you want to support. This can feel a little like spooky action at a distance at first. But it can be very powerful. Case in point is being able to support TP with only 3 definitions. |
| 256 | + |
| 257 | +## Serialization and Model Sharing |
| 258 | + |
| 259 | +### SafeTensors Support |
| 260 | + |
| 261 | +**Current Status**: TorchAO quantized models cannot yet be serialized with safetensors due to tensor subclass limitations. When saving quantized models, you must use `safe_serialization=False`. |
| 262 | + |
| 263 | +**Workaround**: For production use, save models with `safe_serialization=False` when pushing to HuggingFace Hub. |
| 264 | + |
| 265 | +**Future Work**: The TorchAO team is actively working on safetensors support for tensor subclasses. Track progress at: https://github.com/pytorch/ao/issues/2338 |
| 266 | + |
| 267 | +## Integration Architecture Diagrams |
| 268 | + |
| 269 | +### 1. High-Level Model Flow: Transformers → VLLM + TorchAO |
| 270 | + |
| 271 | +This diagram shows the end-to-end flow from model creation to serving: |
| 272 | + |
| 273 | +```mermaid |
| 274 | +graph LR |
| 275 | + A[HuggingFace Model] --> B[Transformers AutoModel] |
| 276 | + B --> C{Quantization Config?} |
| 277 | + C -->|TorchAO Config| D[Apply TorchAO Quantization] |
| 278 | + C -->|No Config| E[Standard Model] |
| 279 | +
|
| 280 | + D --> F[Quantized Model w/ Tensor Subclasses] |
| 281 | + E --> G[Standard PyTorch Model] |
| 282 | +
|
| 283 | + F --> H[VLLM Model Loading] |
| 284 | + G --> H |
| 285 | +
|
| 286 | + H --> I[VLLM Distributed Engine] |
| 287 | + I --> J[Tensor Parallel Sharding] |
| 288 | + J --> K[Optimized Inference] |
| 289 | +
|
| 290 | + style D fill:#e1f5fe |
| 291 | + style F fill:#f3e5f5 |
| 292 | + style J fill:#e8f5e8 |
| 293 | +``` |
| 294 | + |
| 295 | +### 2. TorchAO Integration Points in VLLM |
| 296 | + |
| 297 | +This shows how VLLM detects and applies TorchAO quantization: |
| 298 | + |
| 299 | +```mermaid |
| 300 | +graph LR |
| 301 | + A[Model Config Detection] --> B{quantization=torchao?} |
| 302 | + B -->|Yes| C[TorchAOConfig.from_config] |
| 303 | + B -->|No| D[Other Quantization Methods] |
| 304 | +
|
| 305 | + C --> E[Parse HF quant_type] |
| 306 | + E --> F[config_from_dict] |
| 307 | + F --> G[AOBaseConfig Instance] |
| 308 | +
|
| 309 | + G --> H[get_quant_method per layer] |
| 310 | + H --> I{Layer Type?} |
| 311 | + I -->|LinearBase| J[TorchAOLinearMethod] |
| 312 | + I -->|Other| K[UnquantizedLinearMethod] |
| 313 | +
|
| 314 | + J --> L[create_weights] |
| 315 | + L --> M[torchao_quantize_param_data] |
| 316 | + M --> N[Quantized Tensor Subclass] |
| 317 | +
|
| 318 | + style C fill:#e1f5fe |
| 319 | + style G fill:#f3e5f5 |
| 320 | + style N fill:#e8f5e8 |
| 321 | +``` |
| 322 | + |
| 323 | +### 3. Kernel Dispatch: Bringing External Kernels to VLLM |
| 324 | + |
| 325 | +This illustrates how tensor subclasses enable custom kernel dispatch within VLLM: |
| 326 | + |
| 327 | +```mermaid |
| 328 | +graph LR |
| 329 | + A[F.linear Call in VLLM] --> B[MyQuantTensor torch_function] |
| 330 | + B --> C[Custom implements Handler] |
| 331 | + C --> D{Hardware Check} |
| 332 | +
|
| 333 | + D --> E[Dispatch to External Kernel] |
| 334 | + E --> F[Execute Optimized Kernel] |
| 335 | + F --> G[Return Result to VLLM] |
| 336 | +
|
| 337 | + subgraph "External Libraries" |
| 338 | + H[TorchAO CUTLASS] |
| 339 | + I[TorchAO Triton] |
| 340 | + J[FBGEMM-GPU] |
| 341 | + K[Custom Libraries] |
| 342 | + end |
| 343 | +
|
| 344 | + subgraph "Tensor Subclass Code" |
| 345 | + L[implements F.linear] |
| 346 | + M[custom_linear_impl] |
| 347 | + N[call external kernel] |
| 348 | + end |
| 349 | +
|
| 350 | + E --> H |
| 351 | + E --> I |
| 352 | + E --> J |
| 353 | + E --> K |
| 354 | +
|
| 355 | + C --> L |
| 356 | + L --> M |
| 357 | + M --> N |
| 358 | + N --> E |
| 359 | +
|
| 360 | + style B fill:#e8f6ff,color:#000 |
| 361 | + style C fill:#fff3e0,color:#000 |
| 362 | + style E fill:#e8f5e8,color:#000 |
| 363 | + style L fill:#f3e5f5,color:#000 |
| 364 | +``` |
0 commit comments