# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

torchao is an architecture optimization library that accelerates PyTorch models through quantization and sparsification. It optimizes weights, gradients, activations, and more, for both inference and training, with minimal code changes.

## Prerequisites

### Required Dependencies

```bash
# Install PyTorch (required before installing torchao)
pip install torch torchvision torchaudio

# For development, you may need a specific PyTorch version;
# check requirements.txt or setup.py for version constraints.
```

## Development Commands

### Installation & Build

```bash
# Development install (Python-only mode, fastest for development)
USE_CPP=0 python setup.py develop

# Full build with C++/CUDA extensions
python setup.py develop

# Install specific version of ruff for linting
pip install ruff==0.11.6
```

### Testing

```bash
# Run specific test files
pytest test/float8/test_base.py
pytest test/quantization/test_quant_api.py
pytest test/dtypes/test_affine_quantized.py

# Run comprehensive float8 tests
./test/float8/test_everything.sh

# Run all tutorials
./tutorials/run_all.sh
```

### Linting & Formatting

```bash
# Install pre-commit hooks (one-time setup)
pre-commit install

# Run all pre-commit checks
pre-commit run --all-files

# Run pre-commit on staged files only
pre-commit run
```

## Architecture Overview

### Workflow-Based Structure (2025H1 Refresh)

TorchAO is transitioning from an AQT-centered structure to a **workflow-based organization** that embraces vertical workflows for optimal performance and maintainability.

### Current Structure

**torchao/quantization/** - User-facing APIs
- `quantize_()` - Main quantization function with workflow-specific configs
- `autoquant.py` - Automatic quantization selection
- Configuration classes for different workflows

**torchao/sparsity/** - User-facing sparsity APIs
- `sparsify_()` - Main sparsification function
- Sparsity configuration classes

**Vertical Workflows** (in transition):
- **int8_weight_only** - INT8 weight-only quantization workflows
- **float8** - High-performance float8 training (1.5x speedup vs. FP16)
- **nf4** - NF4 (4-bit NormalFloat) for QLoRA
- **pt2e** - PT2E graph-mode quantization (migrating from PyTorch core)
- **executorch** - ExecuTorch workflows (moving out of experimental)

**torchao/csrc/** - Custom kernels
- CUTLASS-based implementations for maximum performance
- ROCm support for AMD GPUs
- CPU kernels with AVX512 optimizations

**torchao/experimental/** - Experimental features
- MPS acceleration for Apple Silicon
- Low-bit quantization research (1-7 bit weights)
- Prototype workflows before graduation

### Design Philosophy

**Vertical Workflows Over Horizontal Abstractions:**
- Self-contained workflows can move fast on SOTA performance/accuracy
- Workflows choose the abstractions that fit their needs rather than being forced into repo-wide patterns
- Well-fitting abstractions > no abstractions > poorly fitting abstractions
- Duplicated, easy-to-understand code is preferred over highly abstracted, hard-to-understand code

## Build Configuration

The build system uses environment variables for configuration:

**Core Controls:**
- `USE_CPP=0|1` - Build C++/CUDA extensions (default: 1; set to 0 to skip them for the fastest dev setup)
- `USE_CPU_KERNELS=0|1` - Enable optimized CPU kernels (Linux only, default: 0)
- `DEBUG=0|1` - Debug build mode

**Experimental Features:**
- `BUILD_TORCHAO_EXPERIMENTAL=1` - Enable experimental cmake builds
- `TORCHAO_BUILD_CPU_AARCH64=1` - ARM64 CPU kernels (auto-enabled on Apple Silicon)
- `TORCHAO_BUILD_KLEIDIAI=1` - KleidiAI library integration (experimental, accuracy issues)
- `TORCHAO_BUILD_EXPERIMENTAL_MPS=1` - MPS acceleration (macOS only, disabled by default)
- `USE_AVX512=1` - Enable AVX512 optimizations for x86 CPUs (default on Linux)

## Experimental Features (Alpha)

### Stability Levels

**Alpha Features**: Early-stage features that still need refinement. They remain prototypes because of:
- Ongoing hardware support development
- Memory benchmarks that are not yet compelling
- The need for further compiler/kernel investment

### Current Experimental Features

**MX Training and Inference:**
- **Status**: Prototype (hardware support not yet available)
- **Description**: Tensors based on the OCP MX specification for group-wise scaled float8/float6/float4/int8
- **Usage**: Group-wise quantization with MX data types

**Int8 Quantized Training:**
- **Status**: Prototype (memory benchmarks not yet compelling)
- **Description**: Full int8 training support
- **Usage**: `quantize_(model, int8_weight_only_quantized_training())`, as in the sketch below
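
A minimal end-to-end version of that usage line, assuming the prototype API is exported from `torchao.prototype.quantized_training` (the import path may differ across versions):

```python
import torch
from torchao.quantization import quantize_

# Assumed import path for the prototype API; verify against your torchao version.
from torchao.prototype.quantized_training import int8_weight_only_quantized_training

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())
quantize_(model, int8_weight_only_quantized_training())
# Train as usual from here; the prototype keeps the weights stored in int8.
```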

**IntX (Low-bit integers):**
- **Status**: Prototype (needs compiler/kernel investment)
- **Description**: Various integer types implemented through bit-packing in pure PyTorch
- **Note**: Int4 currently remains more compelling than the smaller data types

**Bitnet:**
- **Status**: Experimental (dependent on better hardware/kernel support)
- **Description**: Bitnet quantization technique
- **Limitation**: Usefulness is highly dependent on hardware improvements

### Hardware-Specific Experimental Features

**MPS Kernels (Apple Silicon):**
- **Status**: Experimental (disabled by default)
- **Requirements**: macOS on ARM64 with MPS available
- **Build**: `export TORCHAO_BUILD_EXPERIMENTAL_MPS=1`
- **Features**: Metal shaders for int1mm, int2mm, int3mm, int4mm, int5mm, int6mm, and int7mm

**ARM64/AArch64 CPU Kernels:**
- **Status**: Auto-enabled on ARM64 Macs; enable manually elsewhere
- **Build**: `export TORCHAO_BUILD_CPU_AARCH64=1`
- **Features**:
  - Quantized matrix operations with NEON intrinsics
  - Bit-packing operations for low-bit quantization
  - Lookup table (LUT) operations for weight compression
  - KleidiAI integration (experimental, accuracy issues in CI)

**KleidiAI Integration:**
- **Status**: Experimental (disabled by default)
- **Build**: `export TORCHAO_BUILD_KLEIDIAI=1`
- **Requirements**: ARM64 architecture
- **Note**: Increases build time and has shown BF16 accuracy issues in CI tests

## Development Patterns and Workflows

### Common Development Patterns

**One-line optimizations:** Use `quantize_(model, config)` and `sparsify_(model, config)` for quick model optimization, as in the sketch below:
- `quantize_(m, Int4WeightOnlyConfig())` applies 4-bit weight-only quantization
- `sparsify_(model, BlockSparseWeightConfig())` applies block-sparse weights
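
A runnable sketch of both calls, assuming a CUDA device and that these config classes are exported from the top-level `torchao.quantization` and `torchao.sparsity` namespaces:

```python
import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.sparsity import sparsify_, BlockSparseWeightConfig  # assumed export path

# 4-bit weight-only quantization (modifies the model in place)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().bfloat16()
quantize_(model, Int4WeightOnlyConfig())

# Block-sparse weights (also applied in place)
sparse_model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda()
sparsify_(sparse_model, BlockSparseWeightConfig())
```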

**Model-specific optimizations:** Specialized patterns for different model types
- **SAM2**: Use "Fast Mode" with `torch.compile` and "Furious Mode" with FP16 precision
- **LLMs**: Common patterns include KV cache quantization and Sparse-Marlin integration

**Composability focus:** Design optimizations to work with `torch.compile()` and FSDP2 without graph breaks

### Workflow Development Approach

**Workflow-First Development:**
- Focus on vertical workflows rather than horizontal tensor-subclass abstractions
- Each workflow is self-contained and optimized for its specific use case
- Workflows can choose their own abstractions and implementation patterns

**Workflow Implementation Patterns:**
- **Config-based**: Each workflow provides configuration classes for `quantize_()`
- **Kernel integration**: Workflows integrate with `torchao/csrc/` kernels as needed
- **Composability**: Workflows maintain compatibility with `torch.compile` and FSDP2
- **Independence**: Workflows avoid dependencies on repo-wide abstractions unless beneficial

**Abstraction Selection:**
- Workflows choose abstractions that make their implementation cleaner and more maintainable
- No repo-wide abstraction is enforced without clear benefits
- A many-to-many mapping between abstractions and workflows is acceptable

### Development Tasks

**Adding a new workflow** (see the sketch after this list):
1. Create a new workflow directory in the appropriate location
2. Implement a configuration class for the workflow
3. Add the config to `torchao/quantization/__init__.py` for `quantize_()` integration
4. Implement the workflow using the patterns that fit your use case (tensor subclass, module swap, etc.)
5. Add any required kernels to `torchao/csrc/` or `torchao/kernel/`
6. Choose helpful abstractions from the common utilities as needed
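
A toy sketch of steps 2 and 4 under the module-swap pattern. `MyWorkflowConfig`, `Int8Linear`, and `apply_my_workflow` are all hypothetical names, and the actual `quantize_()` registration hook (step 3) is not shown:

```python
from dataclasses import dataclass

import torch

@dataclass
class MyWorkflowConfig:
    """Step 2: a config object that quantize_() would dispatch on."""
    group_size: int = 128  # illustrative knob

class Int8Linear(torch.nn.Module):
    """Step 4 (module swap): toy weight-only int8 linear."""

    def __init__(self, weight: torch.Tensor, bias):
        super().__init__()
        # Per-output-channel symmetric scales
        scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("w_int8", torch.round(weight / scale).to(torch.int8))
        self.register_buffer("scale", scale)
        self.bias = bias

    def forward(self, x):
        # Dequantize on the fly; a real workflow would call a fused kernel instead.
        w = self.w_int8.to(x.dtype) * self.scale.to(x.dtype)
        return torch.nn.functional.linear(x, w, self.bias)

def apply_my_workflow(model: torch.nn.Module, config: MyWorkflowConfig) -> torch.nn.Module:
    """Recursively swap nn.Linear modules for the quantized version."""
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(model, name, Int8Linear(child.weight.detach(), child.bias))
        else:
            apply_my_workflow(child, config)
    return model
```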

**Current workflow examples:**
- `int8_weight_only` - Uses AQT patterns where beneficial
- `float8` - Uses `Float8Tensor` and specialized training patterns
- `nf4` - Uses an NF4-specific tensor subclass for QLoRA
- `pt2e` - Uses graph-mode quantization patterns

**Performance optimization:**
1. Custom kernels go in `torchao/csrc/` with architecture-specific builds (SM90a, SM100a)
2. Use `opcheck()` in tests to ensure `torch.compile` compatibility (see the sketch below)
3. Implement fallback paths for unsupported configurations
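
A hedged sketch of step 2: `torch.library.opcheck` exercises a custom op's registrations (fake tensor, autograd, etc.) against sample inputs. The op name here is a placeholder, not a real torchao op:

```python
import torch
from torch.library import opcheck

def test_my_custom_op_compiles():
    x = torch.randn(16, 16)
    # opcheck runs a suite of checks that catch most sources of
    # torch.compile graph breaks in custom op registrations.
    # torch.ops.torchao.my_custom_op is a hypothetical op name.
    opcheck(torch.ops.torchao.my_custom_op.default, (x,))
```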

**Testing:**
1. Follow the patterns in the `test/` directory and use `pytest` for individual tests
2. Use `TorchAOBasicTestCase` and `TorchAOCompileTestCase` for tensor subclass tests
3. Include SQNR assertions to verify quantization accuracy (see the sketch below)
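
An illustrative standalone SQNR check for step 3; torchao's test utilities provide an equivalent helper, but the computation itself is just:

```python
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    """Signal-to-quantization-noise ratio in decibels."""
    return 20 * torch.log10(ref.norm() / (ref - actual).norm())

ref = torch.randn(1024)
dequantized = ref + 0.01 * torch.randn(1024)  # stand-in for a quantize->dequantize round trip
assert sqnr_db(ref, dequantized) > 20, "quantization error too large"
```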

**Experimental workflows:**
1. Develop in `torchao/experimental/` or `torchao/prototype/` as appropriate
2. Use `_check_torchao_ops_loaded()` to verify that experimental kernels are loaded
3. Follow the Alpha feature guidelines for prototype development
4. Graduate to the main workflow structure when ready for production use
5. Focus on vertical workflow patterns rather than forcing horizontal abstractions

## Common Issues and Debugging

### Frequent Issues and Solutions

**torch.compile() graph breaks:**
- **Issue**: Custom kernels cause graph breaks under `torch.compile()`
- **Debug**: Run with `fullgraph=True` and `TORCH_LOGS="output_code"` to inspect the generated code
- **Solution**: Ensure tensor subclasses implement `__tensor_flatten__` and `__tensor_unflatten__` (see the sketch below)
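
A minimal sketch of such a traceable subclass. `MyQuantTensor` is hypothetical, and `_make_wrapper_subclass` is a private PyTorch API; real torchao subclasses also implement `__torch_dispatch__` to define op behavior:

```python
import torch

class MyQuantTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, int_data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype, device=int_data.device
        )

    def __init__(self, int_data, scale):
        self.int_data = int_data
        self.scale = scale

    def __tensor_flatten__(self):
        # Names of inner tensor attributes, plus non-tensor context (none here)
        return ["int_data", "scale"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        # Rebuild the subclass from its flattened inner tensors
        return cls(inner_tensors["int_data"], inner_tensors["scale"])
```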

**Device and data type compatibility:**
- **Issue**: Some experimental features only support specific devices (e.g., CPU-only embedding quantization)
- **Solution**: Check the feature's documentation for supported devices and data types
- **Example**: MPS quantization requires macOS on ARM64

**Performance analysis:**
- **Issue**: Need to benchmark optimized models
- **Tools**: Use `print_op_and_shapes.py` to identify the relevant shapes for microbenchmarking
- **Profiling**: Add `--profile=profile_path` to benchmark scripts to produce Chrome traces

**Accuracy degradation:**
- **Issue**: Quantization/sparsity causing accuracy loss
- **Analysis**: Check scale/zero_point for quantization and the mask for sparsity
- **Solution**: Consider Quantization-Aware Training (QAT) to recover accuracy (see the sketch below)
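
A hedged sketch of the QAT flow. The quantizer class name and import path come from torchao's QAT prototype and may differ across versions:

```python
import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer  # assumed path

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
qat = Int8DynActInt4WeightQATQuantizer()

model = qat.prepare(model)   # insert fake-quant ops for training
# ... fine-tune the model here to recover accuracy ...
model = qat.convert(model)   # swap in actually-quantized modules
```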

### Common Debugging Commands

```bash
# Check CUDA availability and version
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}, version: {torch.version.cuda}')"

# Check build configuration
python -c "import torchao; print(torchao.__file__)"

# Debug torch.compile issues
TORCH_LOGS="output_code" python your_script.py

# Run specific test with verbose output
pytest -v -s test/quantization/test_quant_api.py::test_specific_function

# Check for CUDA kernel compilation issues
USE_CPP=1 python setup.py develop --verbose

# Verify experimental kernels are loaded
python -c "from torchao.experimental import _check_torchao_ops_loaded; _check_torchao_ops_loaded()"

# Profile model performance
python benchmark_script.py --profile=profile_output.json
```

## Testing and Benchmarking

### Testing Infrastructure

**Test organization:**
- Unit tests: `test_base.py` for core components like `Float8Tensor`
- Integration tests: `test_integration.py` for AOTI compilation with tensor subclasses
- Numerical accuracy: `test_numerics_integration.py` for float8 operations

**Test utilities:**
- `TorchAOBasicTestCase`: Basic tensor subclass testing
- `TorchAOCompileTestCase`: `torch.compile` compatibility testing
- SQNR assertions for minimum signal-to-quantization-noise ratio

### Performance Benchmarking

**Microbenchmarks:**
- `bench_matmul.py`: Benchmarks the `torch._scaled_mm` function
- `bench_linear_float8.py`: Benchmarks `nn.Linear` vs. `Float8Linear`
- `benchmark_aq.py`: Benchmarks various quantized tensor subclasses

**Model-level benchmarks:**
- **Llama**: `generate.py` for generation performance, `eval.py` for evaluation
- **SAM**: `eval_combo.py` for SAM model benchmarking
- Enable profiling with `generate_model_profile` for detailed analysis

**Continuous Integration:**
- `dashboard_perf_test.yml`: Nightly A100 benchmarks with dashboard visualization
- `torchao_experimental_test.yml`: Experimental feature validation

## Important Notes

- Always run `pre-commit run --all-files` before committing
- Use `USE_CPP=0` for faster iteration during Python-only development
- CUTLASS kernels have architecture-specific builds (SM90a, SM100a) selected by CUDA version
- Git submodules (CUTLASS) are initialized automatically during the build
