
Commit 838b15b

Merge branch 'main' into fix_build_win
2 parents a3d79ab + 3577306

36 files changed: +1338 -384 lines changed


.github/workflows/regression_test_aarch64.yml

Lines changed: 2 additions & 2 deletions
@@ -37,15 +37,15 @@ jobs:
           # Install executorch first because it installs its own version
           # of torch and torchao, which we do not want to use
           pip install executorch
-          pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall
+          pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall
           pip install -r dev-requirements.txt
           USE_CPP=1 TORCHAO_BUILD_KLEIDIAI=1 pip install . --no-build-isolation
       - name: Install requirements linux
         if: runner.os == 'Linux'
         run: |
           conda activate venv
           pip install coremltools
-          pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall
+          pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cpu --force-reinstall
           pip install -r dev-requirements.txt
           BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install . --no-build-isolation
       - name: Run python tests

README.md

Lines changed: 17 additions & 17 deletions
@@ -24,7 +24,8 @@
 
 ## 📣 Latest News
 
-- [Oct 20] MXFP8 MoE training prototype achieved **~1.45x speedup** for MoE layer in Llama4 Scout, and **~1.25x** speedup for MoE layer in DeepSeekV3 671b - with comparable numerics to bfloat16! Check out the [docs](./torchao/prototype/moe_training/) to try it out.
+- [Oct 25] QAT is now integrated into [Unsloth](https://docs.unsloth.ai/new/quantization-aware-training-qat) for both full and LoRA fine-tuning! Try it out using [this notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_%284B%29_Instruct-QAT.ipynb).
+- [Oct 25] MXFP8 MoE training prototype achieved **~1.45x speedup** for MoE layer in Llama4 Scout, and **~1.25x** speedup for MoE layer in DeepSeekV3 671b - with comparable numerics to bfloat16! Check out the [docs](./torchao/prototype/moe_training/) to try it out.
 - [Sept 25] MXFP8 training achieved [1.28x speedup on Crusoe B200 cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) with virtually identical loss curve to bfloat16!
 - [Sept 19] [TorchAO Quantized Model and Quantization Recipes Now Available on Huggingface Hub](https://pytorch.org/blog/torchao-quantized-models-and-quantization-recipes-now-available-on-huggingface-hub/)!
 - [Jun 25] Our [TorchAO paper](https://openreview.net/attachment?id=HpqH0JakHf&name=pdf) was accepted to CodeML @ ICML 2025!
@@ -103,22 +104,6 @@ pip install torchao
 
 Please see the [torchao compability table](https://github.com/pytorch/ao/issues/2919) for version requirements for dependencies.
 
-## 🔗 Integrations
-
-TorchAO is integrated into some of the leading open-source libraries including:
-
-* Unsloth for QAT, blog post coming soon!
-* HuggingFace transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
-* HuggingFace diffusers best practices with `torch.compile` and TorchAO in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
-* vLLM for LLM serving: [usage](https://docs.vllm.ai/en/latest/features/quantization/torchao.html), [detailed docs](https://docs.pytorch.org/ao/main/torchao_vllm_integration.html)
-* Integration with [FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai) for SOTA kernels on server GPUs
-* Integration with [ExecuTorch](https://github.com/pytorch/executorch/) for edge device deployment
-* Axolotl for [QAT](https://docs.axolotl.ai/docs/qat.html) and [PTQ](https://docs.axolotl.ai/docs/quantize.html)
-* TorchTitan for [float8 pre-training](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md)
-* HuggingFace PEFT for LoRA using TorchAO as their [quantization backend](https://huggingface.co/docs/peft/en/developer_guides/quantization#torchao-pytorch-architecture-optimization)
-* TorchTune for our NF4 [QLoRA](https://docs.pytorch.org/torchtune/main/tutorials/qlora_finetune.html), [QAT](https://docs.pytorch.org/torchtune/main/recipes/qat_distributed.html), and [float8 quantized fine-tuning](https://github.com/pytorch/torchtune/pull/2546) recipes
-* SGLang for LLM serving: [usage](https://docs.sglang.ai/advanced_features/quantization.html#online-quantization)
-
 ## 🔎 Inference
 
 TorchAO delivers substantial performance gains with minimal code changes:
@@ -265,6 +250,21 @@ We've added support for authoring and releasing [custom ops](./torchao/csrc/) th
 If you believe there's other CUDA kernels we should be taking a closer look at please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697) or feel free to contribute directly to the repo.
 -->
 
+## 🔗 Integrations
+
+TorchAO is integrated into some of the leading open-source libraries including:
+
+* Unsloth for QAT, blog post coming soon!
+* HuggingFace transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
+* HuggingFace diffusers best practices with `torch.compile` and TorchAO in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
+* vLLM for LLM serving: [usage](https://docs.vllm.ai/en/latest/features/quantization/torchao.html), [detailed docs](https://docs.pytorch.org/ao/main/torchao_vllm_integration.html)
+* Integration with [FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai) for SOTA kernels on server GPUs
+* Integration with [ExecuTorch](https://github.com/pytorch/executorch/) for edge device deployment
+* Axolotl for [QAT](https://docs.axolotl.ai/docs/qat.html) and [PTQ](https://docs.axolotl.ai/docs/quantize.html)
+* TorchTitan for [float8 pre-training](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md)
+* HuggingFace PEFT for LoRA using TorchAO as their [quantization backend](https://huggingface.co/docs/peft/en/developer_guides/quantization#torchao-pytorch-architecture-optimization)
+* TorchTune for our NF4 [QLoRA](https://docs.pytorch.org/torchtune/main/tutorials/qlora_finetune.html), [QAT](https://docs.pytorch.org/torchtune/main/recipes/qat_distributed.html), and [float8 quantized fine-tuning](https://github.com/pytorch/torchtune/pull/2546) recipes
+* SGLang for LLM serving: [usage](https://docs.sglang.ai/advanced_features/quantization.html#online-quantization)
 
 ## 🎥 Videos
 * [Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009)
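
The "minimal code changes" line in the README context above refers to TorchAO's one-call quantization API. A minimal sketch of that workflow follows; the model and the `Int8WeightOnlyConfig` choice are illustrative placeholders, not part of this commit:

```python
# Illustrative sketch of the README's "minimal code changes" inference flow;
# model and quantization config here are placeholders, not taken from this commit.
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().bfloat16()
quantize_(model, Int8WeightOnlyConfig())  # quantizes eligible Linear weights in place
```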

benchmarks/float8/bench_matmul.py

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 from torchao.ops import mx_fp4_bf16
 from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.testing.training.roofline_utils import get_specs
 from torchao.utils import is_MI300
 
@@ -125,10 +126,16 @@ def run(
     elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        # pad if needed
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe == "nvfp4":
         # Use the blockwise scales from nvfp4_quantize
         scale_a = A_scales.view(torch.float8_e4m3fn)
         scale_b = B_scales.view(torch.float8_e4m3fn)
+        # pad if needed
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     else:
         assert False, f"unknown recipe {recipe}"
 
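
The new `to_blocked` calls above convert the per-block scales into the padded, blocked layout that the mx/nvfp4 scaled-mm kernels expect. Below is a standalone sketch of what the benchmark now does for the mxfp8 recipe; the shape values are placeholders and the padding behaviour is inferred from the "pad if needed" comments in this diff:

```python
# Standalone sketch mirroring the mxfp8_cublas scale preparation in bench_matmul.py;
# M, K, N are placeholder shapes and the exact blocked layout is an internal detail.
import torch
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 1024, 2048, 4096
device = "cuda"

# one e8m0 scale per 32-element block along K, as in the mxfp8_cublas recipe above
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)

# pad/rearrange the scales into the blocked layout the scaled-mm kernels require
scale_a = to_blocked(scale_a)
scale_b = to_blocked(scale_b)
```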

benchmarks/float8/float8_inference_roofline.py

Lines changed: 26 additions & 5 deletions
@@ -46,6 +46,7 @@
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
+from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     PerRow,
@@ -134,12 +135,18 @@ def get_gemm_times(
     elif recipe_name == "mxfp8_cublas":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe_name == "mxfp4_cutlass":
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
     elif recipe_name == "nvfp4":
         scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
+        scale_a = to_blocked(scale_a)
+        scale_b = to_blocked(scale_b)
 
     else:
         assert False, "unsupported"
@@ -166,16 +173,22 @@ def run(
     recipe_name: str,
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
+    M: Optional[int] = None,
+    K: Optional[int] = None,
+    N: Optional[int] = None,
     n_limit: Optional[int] = None,
     save_profile_traces: bool = False,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
     * `recipe_name`: quantization recipe (tensorwise, rowwise, mxfp8*, mxfp4*, nvfp4*)
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, `sweep`, or `custom`
+    * `M|K|N`: if shape_gen_name is `custom`, then these values are used for MKN
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     # `save_profile_traces (optional)`: if True, saves profiling traces
+    # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm
     """
     config_table = [
         ["GPU", torch.cuda.get_device_name(0)],
@@ -184,16 +197,22 @@
         ["recipe_name", recipe_name],
         ["do_benchmarks", do_benchmarks],
         ["shape_gen_name", shape_gen_name],
+        ["enable_fusion_modeling", enable_fusion_modeling],
+        ["MKN", f"{M} {K} {N}"],
     ]
     print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple"))
 
+    # reassign user specified MKN, so we can use them for sympy
+    user_M, user_K, user_N = M, K, N
+
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_inference_float8_mem_sympy(
         M,
         K,
         N,
         recipe_name,
+        # TODO(future): also enable fusion modeling here
     )
     bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
 
@@ -241,7 +260,7 @@
     ]
     results = []
 
-    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None)
+    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, user_M, user_K, user_N)
 
     for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)):
         if n_limit is not None and idx >= n_limit:
@@ -287,9 +306,11 @@
         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = (
-                nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
-            )
+            if not enable_fusion_modeling:
+                m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False))
+            else:
+                m_orig = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False))
+            m_orig = m_orig.cuda().bfloat16()
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()
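
The new `custom` shape mode and `enable_fusion_modeling` flag added above can be exercised roughly as follows. This is an illustrative direct call into `run()` based on the signature shown in this diff; the script is normally driven from the command line, and the import path and shape values are placeholders:

```python
# Hypothetical invocation of the roofline benchmark's run() function, mirroring the
# updated signature in this diff; import path and argument values are illustrative.
from float8_inference_roofline import run

run(
    recipe_name="mxfp8_cublas",    # one of the recipes handled in get_gemm_times()
    shape_gen_name="custom",       # new mode: use explicit M/K/N instead of a preset sweep
    M=4096, K=4096, N=4096,        # forwarded to get_name_to_shapes_iter via user_M/K/N
    enable_fusion_modeling=True,   # benchmark ReLU -> Linear instead of a bare Linear
)
```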

benchmarks/prototype/moe_training/bench_moe_layer.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def warmup(model, input, labels):
     parser.add_argument(
         "--local_batch_size",
         type=int,
-        default=8,
+        default=12,
     )
     parser.add_argument(
         "--hidden_dim",

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 4 additions & 4 deletions
@@ -19,7 +19,7 @@
     bench_fwd_microseconds,
     profile_fwd_bwd,
 )
-from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
 
@@ -158,7 +158,7 @@ def run_experiment(
 
     # fwd_bwd scaled benchmark + profiling
     scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
-        _scaled_grouped_mm,
+        _quantize_then_scaled_grouped_mm,
         A,
         B_t,
         offs,
@@ -169,7 +169,7 @@
     )
     if args.profile:
         profile_fwd_bwd(
-            _scaled_grouped_mm,
+            _quantize_then_scaled_grouped_mm,
            A,
            B_t,
            offs,
@@ -190,7 +190,7 @@
         fullgraph=True,
     )
     scaled_fwd_us = bench_fwd_microseconds(
-        _scaled_grouped_mm,
+        _quantize_then_scaled_grouped_mm,
         A,
         B_t,
         offs,
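
The only change in this file is the rename of the private helper; call sites keep their arguments and just swap the imported name, as in this sketch:

```python
# Sketch of the rename applied in this diff; the function is a private torchao
# prototype API, so only the imported name changes at the call sites.
# Before:
#   from torchao.prototype.moe_training import _scaled_grouped_mm
# After:
from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm

# existing call sites keep passing the same arguments, e.g. (A, B_t, offs, ...)
```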

0 commit comments
