diff --git a/requirements.txt b/requirements.txt index d6d51859..d904ae43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,3 +32,5 @@ invisible-watermark torchmetrics<1.0.0 kornia open-clip-torch<2.26.1 +diffusers +accelerate diff --git a/text_to_image/stable_diffusion/run_hf.py b/text_to_image/stable_diffusion/run_hf.py index be51f051..91a62091 100644 --- a/text_to_image/stable_diffusion/run_hf.py +++ b/text_to_image/stable_diffusion/run_hf.py @@ -40,17 +40,52 @@ def single_pass_pytorch(_runner, _stablediffusion): _stablediffusion.submit_count(batch_size, x_samples) runner = PyTorchRunnerV2(model) - stablediffusion = StableDiffusion() - return run_model(single_pass_pytorch, runner, stablediffusion, batch_size, num_runs, timeout) + stable_diffusion_dataset = StableDiffusion() + return run_model(single_pass_pytorch, runner, stable_diffusion_dataset, batch_size, num_runs, timeout) + + +def run_pytorch_fp32(model_name, steps, batch_size, num_runs, timeout, **kwargs): + import torch._dynamo + from diffusers import DiffusionPipeline + torch._dynamo.config.suppress_errors = True + + from utils.benchmark import run_model + from utils.pytorch import apply_compile + from utils.pytorch import PyTorchRunnerV2 + from utils.text_to_image.stable_diffusion import StableDiffusion + + model = DiffusionPipeline.from_pretrained(model_name, + use_safetensors=True).to("cpu") + + model.unet = apply_compile(model.unet) + + def single_pass_pytorch(_runner, _stablediffusion): + prompts = [_stablediffusion.get_input() for _ in range(batch_size)] + x_samples = _runner.run(batch_size * steps, prompt=prompts, num_inference_steps=steps) + _stablediffusion.submit_count(batch_size, x_samples) + + runner = PyTorchRunnerV2(model) + stable_diffusion_dataset = StableDiffusion() + return run_model(single_pass_pytorch, runner, stable_diffusion_dataset, batch_size, num_runs, timeout) if __name__ == "__main__": from utils.helpers import DefaultArgParser + from utils.misc import print_goodbye_message_and_die stablediffusion_variants = ["stabilityai/stable-diffusion-xl-base-1.0"] parser = DefaultArgParser(["pytorch"]) parser.require_model_name(stablediffusion_variants) parser.ask_for_batch_size() parser.add_argument("--steps", type=int, default=25, help="steps through which the model processes the input") + parser.add_argument("-p", "--precision", type=str, choices=["fp32", "bf16"], required=True, + help="precision in which to run the model") - run_pytorch_bf16(**vars(parser.parse())) + args = parser.parse() + if args.precision == "fp32": + run_pytorch_fp32(**vars(parser.parse())) + elif args.precision == "bf16": + run_pytorch_bf16(**vars(parser.parse())) + else: + print_goodbye_message_and_die( + "this model seems to be unsupported in a specified precision: " + args.precision)