
Commit 7f13c3f

Add Tutorial on E2E integration into VLLM and minimal Subclass
stack-info: PR: #2346, branch: drisspg/stack/75
1 parent 83663b8 commit 7f13c3f

File tree

1 file changed: +364 −0 lines changed


tutorials/torchao_vllm_integration.md

Lines changed: 364 additions & 0 deletions
@@ -0,0 +1,364 @@
# TorchAO Integration with VLLM: Architecture and Usage Guide

This tutorial provides a comprehensive overview of how TorchAO integrates with VLLM, and what needs to be implemented to make a new quantization technique work end-to-end (E2E).

## Table of Contents
* [Configuration System](#configuration-system)
* [Usage Examples](#usage-examples)
* [Adding New Quantization Methods to VLLM](#adding-new-quantization-methods-to-vllm)
* [Step-by-Step Guide to Add a New Quantization Method](#step-by-step-guide-to-add-a-new-quantization-method)
* [Serialization and Model Sharing](#serialization-and-model-sharing)
* [Integration Architecture Diagrams](#integration-architecture-diagrams)


## Configuration System

### 1. HuggingFace Model Configuration

TorchAO quantization is configured through the model's `config.json` file:

```json
{
  "model_type": "llama",
  "quant_type": {
    "default": {
      "_type": "Int4WeightOnlyConfig",
      "_data": {
        "group_size": 128,
        "use_hqq": true
      }
    }
  }
}
```

### 2. TorchAO Configuration Classes

All quantization methods inherit from `AOBaseConfig`:

```python
from torchao.core.config import AOBaseConfig
from torchao.quantization import Int4WeightOnlyConfig

# Example configuration
config = Int4WeightOnlyConfig(
    group_size=128,
    use_hqq=True,
)
assert isinstance(config, AOBaseConfig)
```
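
These config objects are what ends up in `config.json`: the `_type`/`_data` layout shown above is just the dictionary form of an `AOBaseConfig`. As a rough sketch of the round trip (this assumes the `config_to_dict` / `config_from_dict` helpers in `torchao.core.config`, the latter of which VLLM uses when parsing the HF config):

```python
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Int4WeightOnlyConfig

config = Int4WeightOnlyConfig(group_size=128, use_hqq=True)

# Serialize to the same {"_type": ..., "_data": ...} layout stored in config.json
config_dict = config_to_dict(config)

# Reconstruct the config object from that dictionary (this is roughly what
# VLLM's TorchAOConfig.from_config path does with the HF quant_type entry)
restored = config_from_dict(config_dict)
assert isinstance(restored, Int4WeightOnlyConfig)
```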

### 3. Module-Level Configuration

For granular control, use `ModuleFqnToConfig`:

```python
from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig

config = ModuleFqnToConfig({
    "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64),
    "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64),
    "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(),
    "_default": Int4WeightOnlyConfig(group_size=128)  # Default for other modules
})
```
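
Any of these configs can also be applied directly to an eager PyTorch model with torchao's `quantize_` API, independent of the HuggingFace path shown in the next section. A minimal sketch, assuming a CUDA device for the int4 weight-only path:

```python
import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig

# Toy model; the int4 weight-only path targets bfloat16 weights on CUDA
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=False, dtype=torch.bfloat16, device="cuda")
)

# Swaps each matching linear's weight for a quantized tensor subclass, in place
quantize_(model, Int4WeightOnlyConfig(group_size=128))
```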

## Usage Examples

### 1. Quantizing Models with HuggingFace Integration

```python
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Int4WeightOnlyConfig

# Create quantization configuration
quantization_config = TorchAoConfig(
    quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True)
)

# Load and automatically quantize the model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

# Save quantized model (see Serialization section below for safe_serialization details)
model.push_to_hub("your-username/Llama-3.2-1B-int4", safe_serialization=False)
```

### 2. Serving with VLLM

```bash
# Start VLLM server with TorchAO quantized model
vllm serve your-username/Llama-3.2-1B-int4 \
    --quantization torchao \
    --dtype bfloat16
```
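
Once the server is up it exposes vLLM's OpenAI-compatible API, so a quick smoke test can be as simple as the sketch below (this assumes the default port 8000 and the hypothetical repo name used above):

```python
import requests

# Smoke test against the OpenAI-compatible completions endpoint started above
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "your-username/Llama-3.2-1B-int4",
        "prompt": "Quantization is useful because",
        "max_tokens": 32,
    },
)
print(response.json()["choices"][0]["text"])
```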

## Adding New Quantization Methods to VLLM

### Minimal Requirements for VLLM Compatibility

To make a new TorchAO quantization method work with VLLM, you need to implement a minimal set of tensor subclass operations that support **tensor parallelism**. VLLM uses `narrow()` and `copy_()` to move the weights loaded into the state dict on the host CPU onto the device, which requires the following aten operations:

### Why these ops?

VLLM's tensor parallelism works by:
1. **`narrow()`** - Slicing weight tensors across different dimensions
2. **Sharding** - Distributing tensor chunks across multiple GPUs
3. **`copy_()`** - Moving tensor data between devices
4. **`detach()`** - Creating views of the weights without autograd tracking


A helpful pattern for implementing these is `_apply_fn_to_data`, a method that applies a given function to all attributes on your class with Tensor types. Below is a generic implementation that should work for most subclasses. We make heavy use of this pattern; you can see an example in [mx_ops.py](../torchao/prototype/mx_formats/mx_ops.py):

```python
from typing import Callable

def _apply_fn_to_data(self, fn: Callable):
    """Applies a fn to all tensor components stored on this class"""
    tensor_names, ctx = self.__tensor_flatten__()

    # Apply the function to each tensor component
    new_tensors = {}
    for name in tensor_names:
        new_tensors[name] = fn(getattr(self, name))

    return self.__class__.__tensor_unflatten__(
        new_tensors,
        ctx,
        None,  # outer_size parameter
        None,  # outer_stride parameter
    )
```

## Step-by-Step Guide to Add a New Quantization Method

#### 1. Create Your Tensor Subclass

```python
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional, Tuple

import torch

from torchao.core.config import AOBaseConfig
from torchao.utils import TorchAOBaseTensor

@dataclass
class MyNewQuantConfig(AOBaseConfig):
    """Configuration for your new quantization method"""
    bits: int = 8
    VERSION: ClassVar[int] = 1

class MyQuantizedTensor(TorchAOBaseTensor):
    """Example based on FbgemmFp8Tensor - stores quantized data + scale"""

    tensor_data_attrs = ["quantized_data", "scale"]
    tensor_attributes = ["dtype"]

    def __new__(cls, quantized_data, scale, dtype):
        shape = quantized_data.shape
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, device=quantized_data.device, dtype=dtype, requires_grad=False
        )

    def __init__(self, quantized_data, scale, dtype):
        self.quantized_data = quantized_data
        self.scale = scale

    def __tensor_flatten__(self) -> Tuple[List[str], List]:
        """Serialize tensor subclass into plain tensors and metadata"""
        return self.tensor_data_attrs, [getattr(self, attr) for attr in self.tensor_attributes]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict: Dict[str, torch.Tensor], tensor_attributes: List,
        outer_size: Optional[torch.Size], outer_stride: Optional[Tuple]
    ) -> "MyQuantizedTensor":
        """Reconstruct tensor subclass from serialized data"""
        return cls(
            *[tensor_data_dict[name] for name in cls.tensor_data_attrs],
            *tensor_attributes,
        )
```
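
To connect this subclass to the transform in step 3 below, you also need a function that produces it from a plain weight. The following is only a toy per-tensor sketch of the `my_quantization_function` placeholder used there; a real method would typically use group-wise scales and packed layouts:

```python
def my_quantization_function(weight: torch.Tensor, config: MyNewQuantConfig) -> MyQuantizedTensor:
    """Toy symmetric per-tensor int8 quantization that returns the subclass above."""
    qmax = 2 ** (config.bits - 1) - 1
    scale = weight.abs().amax().clamp(min=1e-12) / qmax
    quantized_data = torch.clamp(torch.round(weight / scale), -qmax - 1, qmax).to(torch.int8)
    return MyQuantizedTensor(quantized_data, scale, weight.dtype)
```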

#### 2. Implement Required VLLM Operations

```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import fill_defaults

aten = torch.ops.aten

@MyQuantizedTensor.implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(func)
    )

@MyQuantizedTensor.implements([aten._to_copy.default])
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )

@MyQuantizedTensor.implements([aten.slice.Tensor])
def _(func, types, args, kwargs):
    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
    if dim == 0 or dim == 1:
        # NOTE the slicing here will likely be different for different quant techniques
        return return_and_correct_aliasing(
            func, args, kwargs,
            args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
        )
    else:
        raise NotImplementedError(f"Slicing along dim={dim} not supported")
```

#### 3. Register with TorchAO's Quantization System

```python
import torch

from torchao.quantization.transform_module import register_quantize_module_handler

@register_quantize_module_handler(MyNewQuantConfig)
def _my_quant_transform(module: torch.nn.Module, config: MyNewQuantConfig):
    """Transform function that applies your quantization to a module"""
    weight = module.weight

    # Your quantization logic here
    quantized_weight = my_quantization_function(weight, config)

    # Replace the weight with your quantized tensor
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module
```
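
With the handler registered, the new config plugs into the same `quantize_` entry point as the built-in methods. A quick sanity check, assuming the toy `my_quantization_function` sketched under step 1:

```python
import torch
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 256, bias=False))
quantize_(model, MyNewQuantConfig(bits=8))

# The registered transform replaced the weight with the tensor subclass
assert isinstance(model[0].weight, MyQuantizedTensor)
```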

### Key Implementation Details

#### Hardware-Specific Linear Operations
Your quantized tensor's linear implementation determines which hardware is supported and what actually gets called when `F.linear` is invoked.

```python
@MyQuantizedTensor.implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = args[0], args[1], args[2] if len(args) > 2 else None

    # This is where you define what hardware your method supports
    if hasattr(weight_tensor, 'use_cutlass_kernel'):
        return my_cutlass_linear(input_tensor, weight_tensor, bias)
    elif hasattr(weight_tensor, 'use_triton_kernel'):
        return my_triton_linear(input_tensor, weight_tensor, bias)
    else:
        # Fallback - dequantize and use standard linear
        return torch.nn.functional.linear(
            input_tensor, weight_tensor.dequantize(), bias
        )
```

#### Compilation Benefits
The overhead of tensor subclasses disappears with `torch.compile()`, which is on by default in VLLM.
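
For standalone benchmarking outside VLLM, the same effect can be sketched roughly as follows (assumes a CUDA device and one of the existing configs from above):

```python
import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig

linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig(group_size=128))

compiled = torch.compile(linear)  # traces subclass dispatch out of the hot path
x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
with torch.inference_mode():
    y = compiled(x)
```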

### Trade-offs of Tensor Subclasses
1. **Compilation** is essential for removing subclass overhead. Without it, unless your model is extremely GPU-bound, the overhead of dispatch on the CPU can severely impact performance.
2. **The checkpoint defines the behavior of the model.** You might be saying "don't all checkpoints do this?" That is true, however people typically think of a `torch.Tensor` solely as its data, when in actuality it is a full class that brings along the dispatcher and all the kernels ATen has registered for it. When you define your tensor subclass, you are building a separate little world: one with a different representation of the data, but also one where you need to explicitly define which ops you support and provide implementations for all the hardware you want to support. This can feel a little like spooky action at a distance at first, but it can be very powerful. Case in point: tensor parallelism can be supported with only three op definitions.

## Serialization and Model Sharing

### SafeTensors Support

**Current Status**: TorchAO quantized models cannot yet be serialized with safetensors due to tensor subclass limitations. When saving quantized models, you must use `safe_serialization=False`.

**Workaround**: For production use, save models with `safe_serialization=False` when pushing to HuggingFace Hub.
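
The same flag applies to local checkpoints via `save_pretrained`; a sketch of the save/reload round trip, reusing the quantized `model` from the usage example above:

```python
from transformers import AutoModelForCausalLM

# Save locally without safetensors (weights go through torch.save under the hood)
model.save_pretrained("Llama-3.2-1B-int4-local", safe_serialization=False)

# Reload: the quantization config travels with the checkpoint, so the tensor
# subclasses are reconstructed on load
reloaded = AutoModelForCausalLM.from_pretrained(
    "Llama-3.2-1B-int4-local", torch_dtype="auto", device_map="auto"
)
```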

**Future Work**: The TorchAO team is actively working on safetensors support for tensor subclasses. Track progress at: https://github.com/pytorch/ao/issues/2338

## Integration Architecture Diagrams

### 1. High-Level Model Flow: Transformers → VLLM + TorchAO

This diagram shows the end-to-end flow from model creation to serving:

```mermaid
graph LR
    A[HuggingFace Model] --> B[Transformers AutoModel]
    B --> C{Quantization Config?}
    C -->|TorchAO Config| D[Apply TorchAO Quantization]
    C -->|No Config| E[Standard Model]

    D --> F[Quantized Model w/ Tensor Subclasses]
    E --> G[Standard PyTorch Model]

    F --> H[VLLM Model Loading]
    G --> H

    H --> I[VLLM Distributed Engine]
    I --> J[Tensor Parallel Sharding]
    J --> K[Optimized Inference]

    style D fill:#e1f5fe
    style F fill:#f3e5f5
    style J fill:#e8f5e8
```

### 2. TorchAO Integration Points in VLLM

This shows how VLLM detects and applies TorchAO quantization:

```mermaid
graph LR
    A[Model Config Detection] --> B{quantization=torchao?}
    B -->|Yes| C[TorchAOConfig.from_config]
    B -->|No| D[Other Quantization Methods]

    C --> E[Parse HF quant_type]
    E --> F[config_from_dict]
    F --> G[AOBaseConfig Instance]

    G --> H[get_quant_method per layer]
    H --> I{Layer Type?}
    I -->|LinearBase| J[TorchAOLinearMethod]
    I -->|Other| K[UnquantizedLinearMethod]

    J --> L[create_weights]
    L --> M[torchao_quantize_param_data]
    M --> N[Quantized Tensor Subclass]

    style C fill:#e1f5fe
    style G fill:#f3e5f5
    style N fill:#e8f5e8
```

### 3. Kernel Dispatch: Bringing External Kernels to VLLM

This illustrates how tensor subclasses enable custom kernel dispatch within VLLM:

```mermaid
graph LR
    A[F.linear Call in VLLM] --> B[MyQuantTensor torch_function]
    B --> C[Custom implements Handler]
    C --> D{Hardware Check}

    D --> E[Dispatch to External Kernel]
    E --> F[Execute Optimized Kernel]
    F --> G[Return Result to VLLM]

    subgraph "External Libraries"
        H[TorchAO CUTLASS]
        I[TorchAO Triton]
        J[FBGEMM-GPU]
        K[Custom Libraries]
    end

    subgraph "Tensor Subclass Code"
        L[implements F.linear]
        M[custom_linear_impl]
        N[call external kernel]
    end

    E --> H
    E --> I
    E --> J
    E --> K

    C --> L
    L --> M
    M --> N
    N --> E

    style B fill:#e8f6ff,color:#000
    style C fill:#fff3e0,color:#000
    style E fill:#e8f5e8,color:#000
    style L fill:#f3e5f5,color:#000
```
