Thanks for the great work!

I tried to enable AutoQuant on top of the latest gpt-fast repository, since the gpt-fast version that the ao repo provides as an example is outdated. Here is the diff that enables AutoQuant on top of the latest gpt-fast codebase. However, I'm getting the error "CUDA generator expects graph capture to be underway, but the current stream is not capturing".
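For context, the relevant part of that diff is roughly the following. This is a minimal sketch rather than the full change: `quantization` stands in for the new `--quantization` CLI argument wired into generate.py, and the flow assumes torchao's manual autoquant API (which matches the `finalize_autoquant()` call in the traceback below).

```python
# Sketch of the autoquant hook added to gpt-fast's generate.py (not the full diff).
import torchao

if quantization == "autoquant":
    # Wrap nn.Linear weights so autoquant can record activation shapes at runtime.
    model = torchao.autoquant(model, manual=True)

# ... gpt-fast then compiles the decode step (mode="reduce-overhead") and runs a
# warm-up generation, during which each wrapped weight logs its shapes ...

if quantization == "autoquant":
    # Benchmark the candidate quantized kernels per recorded shape and swap in
    # the fastest implementation.
    model.finalize_autoquant()
```

Full command and output: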
```
(/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin) yejinlee@a100-st-p4de24xlarge-47:/fsx-checkpoints/yejinlee/gpt-fast$ python generate.py --compile --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --prompt "Hello, my name is" --quantization autoquant
uintx feature need torch 2.3+, please upgrade pytorch
Using device=cuda
Loading model ...
Time to load model: 39.35 seconds
activation_shapes: torch.Size([6, 4096]), times_seen: 1
activation_shapes: torch.Size([1, 4096]), times_seen: 199
weight_shape: torch.Size([12288, 4096]), dtype: torch.bfloat16, bias_shape: None
AUTOTUNE mm(6x4096, 4096x12288)
mm 0.0832 ms 100.0%
triton_mm_8 0.0840 ms 99.0%
triton_mm_6 0.0842 ms 98.8%
triton_mm_4 0.0857 ms 97.1%
triton_mm_3 0.0861 ms 96.6%
triton_mm_9 0.0879 ms 94.7%
triton_mm_5 0.0887 ms 93.8%
triton_mm_2 0.0944 ms 88.1%
triton_mm_1 0.0962 ms 86.5%
triton_mm_0 0.1044 ms 79.7%
SingleProcess AUTOTUNE takes 2.7937 seconds
warning: failed to autoquant AQFloatLinearWeight for shape: (torch.Size([6, 4096]), torch.Size([12288, 4096]), None, torch.bfloat16) due to CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Traceback (most recent call last):
  File "/opt/hpcaas/.mounts/fs-08829104cb559c481/yejinlee/gpt-fast/generate.py", line 480, in <module>
    main(
  File "/opt/hpcaas/.mounts/fs-08829104cb559c481/yejinlee/gpt-fast/generate.py", line 354, in main
    model.finalize_autoquant()
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 620, in finalize_autoquant
    _change_autoquantizable_to_quantized(
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 494, in _change_autoquantizable_to_quantized
    _replace_with_custom_fn_if_matches_filter(
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 187, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
  [Previous line repeated 1 more time]
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 183, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model)
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/quant_api.py", line 238, in insert_subclass
    getattr(cls, from_float)(lin.weight, **kwargs), requires_grad=False
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 146, in to_quantized
    self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
  File "/fsx-checkpoints/yejinlee/condaenv/gptfast_yejin/lib/python3.9/site-packages/torchao-0.4.0+gitc2f44608-py3.9-linux-x86_64.egg/torchao/quantization/autoquant.py", line 94, in tune_autoquant
    act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
RuntimeError: CUDA generator expects graph capture to be underway, but the current stream is not capturing.
```
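In case it helps narrow things down: my (unconfirmed) suspicion is that gpt-fast's `--compile` path captures the decode step with CUDA graphs (inductor's reduce-overhead mode), and `finalize_autoquant()` then hits the `torch.randn(..., device=self.device)` call in `tune_autoquant` (autoquant.py line 94 above) while the CUDA RNG generator is still registered to a graph capture. A workaround sketch I'm planning to try, untested and based purely on that assumption:

```python
# Untested workaround sketch: disable inductor's CUDA graphs so that
# finalize_autoquant()'s torch.randn benchmarking does not run with the RNG
# generator tied to a pending graph capture. This trades away the CUDA-graph
# speedup, so it is only meant to isolate the failure.
import torch._inductor.config as inductor_config

inductor_config.triton.cudagraphs = False  # assumption: this is the conflicting feature
```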
Attaching the environment info here:
```
torch        2.3.1+cu121
torchao      0.4.0+gitc2f44608
torchaudio   2.3.1+cu121
torchvision  0.18.1+cu121
```
Thanks in advance for the help!