Description
Compiling ao from source using pip install git+https://github.com/pytorch/ao.git and then running FPx weights with the script below results in a very fun throw:

NotImplementedError: Could not run 'torchao::quant_llm_linear' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'torchao::quant_llm_linear' is only available for these backends: [Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Repro script:
import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from torchao.quantization import fpx_weight_only, quantize_

@torch.no_grad()
def main():
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    # FP6: 3 exponent bits + 2 mantissa bits (+ sign)
    quantize_(pipe.unet, fpx_weight_only(3, 2))
    pipe(
        prompt="high resolution dslr photograph of a kitten in a field of flowers",
        negative_prompt="blurry, noisy, cropped",
        num_inference_steps=20,
        guidance_scale=5,
        # the pipeline takes a generator rather than a seed kwarg
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0].save("fp6.png")

if __name__ == "__main__":
    main()
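For what it's worth, dumping the dispatcher state for the op shows what the error message already hints at: only the generic kernels are registered, no CUDA/ROCm one. A quick check (assuming torch._C._dispatch_dump behaves the same under ROCm):

import torch
import torchao  # importing torchao registers its custom ops

# Print every kernel registered for the failing op; on this setup the
# dump contains no CUDA entry, matching the backend list in the error.
print(torch._C._dispatch_dump("torchao::quant_llm_linear"))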
Setup is 1x 7900 XTX on torch 2.5+rocm6.2. All other quantizations work just fine, with the exception of float8_dynamic_activation_float8_weight, because gfx11 currently does not implement torch's _scaled_mm() function.
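For reference, that quantization bottoms out in torch._scaled_mm, the fp8 GEMM primitive, so the float8 failure can be reproduced without torchao at all. A minimal sketch against the torch 2.5 signature (assuming the CUDA-style e4m3 dtype; ROCm builds may expect the fnuz variants instead):

import torch

# float8_dynamic_activation_float8_weight ultimately calls torch._scaled_mm,
# which has no gfx11 implementation, so this fails the same way.
a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn).t()  # mat2 must be column-major
scale = torch.ones((), device="cuda")  # tensorwise fp32 scales
out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.float16)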
Using bfloat16 as the base dtype instead does actually run, but it's wicked slow because of the dtype conversions. The floatx README states to use float16, so I assume that's the correct way.
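For completeness, this is the bfloat16 variant that does run: the same script with only torch_dtype changed.

import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from torchao.quantization import fpx_weight_only, quantize_

# Loading the pipeline in bfloat16 sidesteps the missing fp16 kernel,
# but every forward pass pays for dtype conversions, hence the slowdown.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
quantize_(pipe.unet, fpx_weight_only(3, 2))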
Python traceback: traceback.txt (attached)