[ROCm] Unable to Run FPX Weights #967

Open
@Beinsezii

Description

Compiling ao from source using pip install git+https://github.com/pytorch/ao.git results in a very fun throw

NotImplementedError: Could not run 'torchao::quant_llm_linear' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'torchao::quant_llm_linear' is only available for these backends: [Meta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

when running FPX weights using the script below

import torch
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from torchao.quantization import fpx_weight_only, quantize_

@torch.no_grad()
def main():
    pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
    # fp6 (1 sign + 3 exponent + 2 mantissa bits) weight-only quantization of the UNet
    quantize_(pipe.unet, fpx_weight_only(3, 2))
    pipe(
        prompt="high resolution dslr photograph of a kitten in a field of flowers",
        negative_prompt="blurry, noisy, cropped",
        num_inference_steps=20,
        guidance_scale=5,
        generator=torch.Generator("cuda").manual_seed(0),  # SDXL takes a generator rather than a seed kwarg
    ).images[0].save("fp6.png")

if __name__ == "__main__":
    main()
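
The missing registration can also be confirmed straight from the dispatcher. Rough sketch only, since _dispatch_has_kernel_for_dispatch_key is an internal PyTorch API:

import torch
import torchao  # importing torchao registers the torchao:: ops

# Expected True on a build where the CUDA/HIP extension kernels were compiled,
# but False here, matching the backend list in the error above.
print(torch._C._dispatch_has_kernel_for_dispatch_key("torchao::quant_llm_linear", "CUDA"))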

Setup is 1x 7900XTX on torch 2.5+rocm6.2. All other quantizations work just fine, with the exception of float8_dynamic_activation_float8_weight, because gfx11 currently does not implement torch's _scaled_mm() function.
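
For context, the _scaled_mm gap is easy to reproduce on its own. Minimal probe, assuming the torch 2.5 _scaled_mm signature and the ROCm float8_e4m3fnuz dtype:

import torch

a = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fnuz)      # ROCm fp8 variant
b = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fnuz).t()  # mat2 must be column-major
scale = torch.tensor(1.0, device="cuda")
# Expected to throw on gfx11, which is the same reason
# float8_dynamic_activation_float8_weight falls over.
out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.float16)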

Using bfloat16 as the base dtype instead actually does run, but it's wicked slow from the conversions. The floatx README says to use float16, so I assume that's the correct way.
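
For anyone who just needs output in the meantime, the slow bfloat16 path mentioned above is a one-line change to the repro script:

# Slow workaround: load the pipeline in bfloat16 before quantizing;
# the extra dtype conversions are where the time goes.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
quantize_(pipe.unet, fpx_weight_only(3, 2))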

Python traceback
traceback.txt
