
Commit b1f20f1

Fix numerics, add some features to VAE runner, add cpu scheduling options
1 parent fd2a2ba commit b1f20f1

7 files changed: +141 additions, -37 deletions

models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py

Lines changed: 6 additions & 0 deletions

@@ -346,6 +346,12 @@ def is_valid_file(arg):
     action="store_true",
     help="Just compile attention reproducer for mmdit.",
 )
+p.add_argument(
+    "--vae_input_path",
+    type=str,
+    default=None,
+    help="Path to input latents for VAE inference numerics validation.",
+)


 ##############################################################################
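For reference, a latents file for the new --vae_input_path option can be produced with plain numpy. This is a hedged sketch, not part of the commit: the shape mirrors the decode-path example input built in sd3_vae_runner.py (batch, 16, height/8, width/8), and the file name is illustrative.

import numpy as np

# Hypothetical helper: save fp16 latents for a 1024x1024 decode (128x128 latent grid).
latents = np.random.randn(1, 16, 128, 128).astype(np.float16)
np.save("sd3_vae_input_latents.npy", latents)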

models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py

Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ def export_mmdit_model(
         torch.empty(hidden_states_shape, dtype=dtype),
         torch.empty(encoder_hidden_states_shape, dtype=dtype),
         torch.empty(pooled_projections_shape, dtype=dtype),
-        torch.empty(1, dtype=dtype),
+        torch.empty(init_batch_dim, dtype=dtype),
     ]

     decomp_list = []
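The timestep example input now carries a batch dimension (init_batch_dim) instead of a single element. The commit does not show how init_batch_dim is derived; a plausible reading, sketched below purely as an assumption, is that it tracks the classifier-free-guidance-doubled batch so one timestep value is passed per batch element.

import torch

batch_size = 1
do_classifier_free_guidance = True
# Assumption for illustration only; the derivation of init_batch_dim is not shown in this diff.
init_batch_dim = batch_size * 2 if do_classifier_free_guidance else batch_size
timestep = torch.empty(init_batch_dim, dtype=torch.float16)  # one timestep value per batch element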

models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py

Lines changed: 2 additions & 1 deletion

@@ -154,7 +154,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
         (batch_size, args.max_length * 2, 4096), dtype=dtype
     )
     pooled_projections = torch.randn((batch_size, 2048), dtype=dtype)
-    timestep = torch.tensor([0], dtype=dtype)
+    timestep = torch.tensor([0, 0], dtype=dtype)

     turbine_output = run_mmdit_turbine(
         hidden_states,
@@ -180,6 +180,7 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
         timestep,
         args,
     )
+    np.save("torch_mmdit_output.npy", torch_output.astype(np.float16))
     print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)

     print("\n(torch (comfy) image latents to iree image latents): ")

models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py

Lines changed: 79 additions & 30 deletions

@@ -17,6 +17,7 @@
 from turbine_models.custom_models.sd_inference import utils
 from turbine_models.model_runner import vmfbRunner
 from transformers import CLIPTokenizer
+from diffusers import FlowMatchEulerDiscreteScheduler

 from PIL import Image
 import os
@@ -426,10 +427,16 @@ def load_pipeline(
         unet_loaded = time.time()
         print("\n[LOG] MMDiT loaded in ", unet_loaded - load_start, "sec")

-        runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
-            self.devices["mmdit"]["driver"],
-            vmfbs["scheduler"],
-        )
+        if not self.cpu_scheduling:
+            runners["scheduler"] = sd3_schedulers.SharkSchedulerWrapper(
+                self.devices["mmdit"]["driver"],
+                vmfbs["scheduler"],
+            )
+        else:
+            print("Using torch CPU scheduler.")
+            runners["scheduler"] = FlowMatchEulerDiscreteScheduler.from_pretrained(
+                self.hf_model_name, subfolder="scheduler"
+            )

         sched_loaded = time.time()
         print("\n[LOG] Scheduler loaded in ", sched_loaded - unet_loaded, "sec")
@@ -502,11 +509,12 @@ def generate_images(
             )
         )

-        guidance_scale = ireert.asdevicearray(
-            self.runners["pipe"].config.device,
-            np.asarray([guidance_scale]),
-            dtype=iree_dtype,
-        )
+        if not self.cpu_scheduling:
+            guidance_scale = ireert.asdevicearray(
+                self.runners["pipe"].config.device,
+                np.asarray([guidance_scale]),
+                dtype=iree_dtype,
+            )

         tokenize_start = time.time()
         text_input_ids_dict = self.tokenizer.tokenize_with_weights(prompt)
@@ -540,12 +548,23 @@ def generate_images(
             "clip"
         ].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
         encode_prompts_end = time.time()
+        if self.cpu_scheduling:
+            timesteps, num_inference_steps = sd3_schedulers.retrieve_timesteps(
+                self.runners["scheduler"],
+                num_inference_steps=self.num_inference_steps,
+                timesteps=None,
+            )
+            steps = num_inference_steps
+

         for i in range(batch_count):
             unet_start = time.time()
-            sample, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
+            if not self.cpu_scheduling:
+                latents, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
+            else:
+                latents = torch.tensor(samples[i].to_host(), dtype=self.torch_dtype)
             iree_inputs = [
-                sample,
+                latents,
                 ireert.asdevicearray(
                     self.runners["pipe"].config.device, prompt_embeds, dtype=iree_dtype
                 ),
@@ -560,41 +579,71 @@ def generate_images(
                 # print(f"step {s}")
                 if self.cpu_scheduling:
                     step_index = s
+                    t = timesteps[s]
+                    if self.do_classifier_free_guidance:
+                        latent_model_input = torch.cat([latents] * 2)
+                    timestep = ireert.asdevicearray(
+                        self.runners["pipe"].config.device,
+                        t.expand(latent_model_input.shape[0]),
+                        dtype=iree_dtype,
+                    )
+                    latent_model_input = ireert.asdevicearray(
+                        self.runners["pipe"].config.device,
+                        latent_model_input,
+                        dtype=iree_dtype,
+                    )
                 else:
                     step_index = ireert.asdevicearray(
                         self.runners["scheduler"].runner.config.device,
                         torch.tensor([s]),
                         "int64",
                     )
-                latents, t = self.runners["scheduler"].prep(
-                    sample,
-                    step_index,
-                    timesteps,
-                )
+                    latent_model_input, timestep = self.runners["scheduler"].prep(
+                        latents,
+                        step_index,
+                        timesteps,
+                    )
+                    t = ireert.asdevicearray(
+                        self.runners["scheduler"].runner.config.device,
+                        timestep.to_host()[0]
+                    )
                 noise_pred = self.runners["pipe"].ctx.modules.compiled_mmdit[
                     "run_forward"
                 ](
-                    latents,
+                    latent_model_input,
                     iree_inputs[1],
                     iree_inputs[2],
-                    t,
-                )
-                sample = self.runners["scheduler"].step(
-                    noise_pred,
-                    t,
-                    sample,
-                    guidance_scale,
-                    step_index,
+                    timestep,
                 )
-            if isinstance(sample, torch.Tensor):
+                if not self.cpu_scheduling:
+                    latents = self.runners["scheduler"].step(
+                        noise_pred,
+                        t,
+                        latents,
+                        guidance_scale,
+                        step_index,
+                    )
+                else:
+                    noise_pred = torch.tensor(noise_pred.to_host(), dtype=self.torch_dtype)
+                    if self.do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    latents = self.runners["scheduler"].step(
+                        noise_pred,
+                        t,
+                        latents,
+                        return_dict=False,
+                    )[0]
+
+            if isinstance(latents, torch.Tensor):
+                latents = latents.type(self.vae_dtype)
                 latents = ireert.asdevicearray(
                     self.runners["vae"].config.device,
-                    sample,
-                    dtype=self.vae_dtype,
+                    latents,
                 )
             else:
                 vae_numpy_dtype = np.float32 if self.vae_precision == "fp32" else np.float16
-                latents = sample.astype(vae_numpy_dtype)
+                latents = latents.astype(vae_numpy_dtype)

             vae_start = time.time()
             vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents)
@@ -791,10 +840,10 @@ def run_diffusers_cpu(
         cpu_scheduling=args.cpu_scheduling,
         vae_precision=args.vae_precision,
     )
-    vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
     if args.cpu_scheduling:
         vmfbs.pop("scheduler")
         weights.pop("scheduler")
+    vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
     if args.npu_delegate_path:
         extra_device_args = {"npu_delegate_path": args.npu_delegate_path}
     else:
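To summarize the new CPU-scheduling path: the FlowMatchEulerDiscreteScheduler runs in PyTorch on the host, classifier-free guidance is applied there as well, and only the MMDiT forward pass stays on the compiled side. The following is a minimal, self-contained sketch of that loop with the compiled MMDiT call stubbed out by random noise; the model name, resolution, guidance scale, and step count are illustrative.

import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler"
)
scheduler.set_timesteps(28)

latents = torch.randn(1, 16, 128, 128)  # 1024x1024 image -> 128x128 latent grid
guidance_scale = 5.0

def mmdit_forward(latent_model_input, timestep):
    # Stand-in for the compiled MMDiT "run_forward" call used in the pipeline above.
    return torch.randn_like(latent_model_input)

for t in scheduler.timesteps:
    latent_model_input = torch.cat([latents] * 2)  # classifier-free guidance doubles the batch
    noise_pred = mmdit_forward(latent_model_input, t.expand(latent_model_input.shape[0]))
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]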

models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py

Lines changed: 40 additions & 1 deletion

@@ -5,9 +5,11 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import os
+import inspect
 from typing import List

 import torch
+from typing import Any, Callable, Dict, List, Optional, Union
 from shark_turbine.aot import *
 import shark_turbine.ops.iree as ops
 from iree.compiler.ir import Context
@@ -75,11 +77,12 @@ def initialize(self, sample):

     def prepare_model_input(self, sample, t, timesteps):
         t = timesteps[t]
-        t = t.expand(sample.shape[0])
+
         if self.do_classifier_free_guidance:
             latent_model_input = torch.cat([sample] * 2)
         else:
             latent_model_input = sample
+        t = t.expand(sample.shape[0])
         return latent_model_input.type(self.dtype), t.type(self.dtype)

     def step(self, noise_pred, t, sample, guidance_scale, i):
@@ -146,6 +149,42 @@ def step(self, noise_pred, t, latents, guidance_scale, i):
             return_dict=False,
         )[0]

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+# Only used for cpu scheduling.
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps

 @torch.no_grad()
 def export_scheduler_model(
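retrieve_timesteps dispatches to scheduler.set_timesteps based on which of num_inference_steps, timesteps, or sigmas was supplied, then returns the resolved schedule and its length. A hedged usage example follows; the model name and step count are placeholders.

from diffusers import FlowMatchEulerDiscreteScheduler
from turbine_models.custom_models.sd3_inference.sd3_schedulers import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler"
)
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=28, timesteps=None)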

models/turbine_models/custom_models/sd3_inference/sd3_vae.py

Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@ def __init__(
         )

     def decode(self, inp):
+        inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor
         image = self.vae.decode(inp, return_dict=False)[0]
         image = image.float()
         image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
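The added line undoes the encoder-side latent normalization used by the SD3 VAE in diffusers, where latents are stored as (x - shift_factor) * scaling_factor; decoding has to invert that first. A minimal standalone sketch, with the model name illustrative and the factors read from the VAE config:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
)
latents = torch.randn(1, 16, 128, 128)
latents = latents / vae.config.scaling_factor + vae.config.shift_factor  # undo encode-time scaling
image = vae.decode(latents, return_dict=False)[0]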

models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py

Lines changed: 12 additions & 4 deletions

@@ -45,12 +45,17 @@ def imagearray_from_vae_out(image):
 if __name__ == "__main__":
     from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
     import numpy as np
+    from PIL import Image

     dtype = torch.float16 if args.precision == "fp16" else torch.float32
     if args.vae_variant == "decode":
         example_input = torch.rand(
             args.batch_size, 16, args.height // 8, args.width // 8, dtype=dtype
         )
+        if args.vae_input_path:
+            example_input = np.load(args.vae_input_path)
+            if example_input.shape[0] == 2:
+                example_input = np.split(example_input, 2)[0]
     elif args.vae_variant == "encode":
         example_input = torch.rand(
             args.batch_size, 3, args.height, args.width, dtype=dtype
@@ -74,13 +79,16 @@ def imagearray_from_vae_out(image):
     from turbine_models.custom_models.sd_inference import utils

     torch_output = run_torch_vae(
-        args.hf_model_name, args.vae_variant, example_input.float()
+        args.hf_model_name, args.vae_variant, torch.tensor(example_input).float()
     )
     print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
+    if args.vae_input_path:
+        out_image_torch = Image.fromarray(torch_output)
+        out_image_torch.save("vae_test_output_torch.png")
+        out_image_turbine = Image.fromarray(turbine_results)
+        out_image_turbine.save("vae_test_output_turbine.png")
     # Allow a small amount of wiggle room for rounding errors (1)
+
     np.testing.assert_allclose(
         turbine_results, torch_output, rtol=1, atol=1
     )
-
-    # TODO: Figure out why we occasionally segfault without unlinking output variables
-    turbine_results = None