Skip to content

Commit d7c709e

Browse files
committed
SD3 small tweaks for numerics
1 parent 6c9d96d commit d7c709e

File tree

3 files changed

+53
-12
lines changed

3 files changed

+53
-12
lines changed

models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -445,16 +445,18 @@ def generate_images(
445445
numpy_images = []
446446

447447
for i in range(batch_count):
448-
generator = torch.random.manual_seed(seed + i)
448+
generator = torch.Generator().manual_seed(int(seed))
449+
shape = (
450+
self.batch_size,
451+
16,
452+
self.height // 8,
453+
self.width // 8,
454+
)
449455
rand_sample = torch.randn(
450-
(
451-
self.batch_size,
452-
16,
453-
self.height // 8,
454-
self.width // 8,
455-
),
456+
shape,
456457
generator=generator,
457-
dtype=torch_dtype,
458+
dtype=torch.float32,
459+
layout=torch.strided,
458460
)
459461
samples.append(
460462
ireert.asdevicearray(
@@ -499,7 +501,6 @@ def generate_images(
499501
prompt_embeds, pooled_prompt_embeds = self.runners[
500502
"text_encoders"
501503
].ctx.modules.compiled_text_encoder["encode_tokens"](*text_encoders_inputs)
502-
503504
encode_prompts_end = time.time()
504505

505506
for i in range(batch_count):
@@ -617,11 +618,51 @@ def generate_images(
617618
image.save(img_path)
618619
print(img_path, "saved")
619620
return
621+
622+
def run_diffusers_cpu(
    hf_model_name,
    prompt,
    negative_prompt,
    guidance_scale,
    seed,
    height,
    width,
    num_inference_steps,
):
    """Generate one reference image with the diffusers SD3 pipeline on CPU.

    Loads ``StableDiffusion3Pipeline`` from ``hf_model_name`` with float32
    weights, moves it to the CPU, runs a single generation with a
    ``torch.Generator`` seeded from ``int(seed)``, and saves the first
    returned image as ``diffusers_reference_output_<timestamp>.png`` (a
    relative path, so it lands in the current working directory).

    Args:
        hf_model_name: Model identifier forwarded to ``from_pretrained``.
        prompt: Positive prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        guidance_scale: Forwarded unchanged as ``guidance_scale``.
        seed: Converted with ``int()`` and used to seed the generator.
        height: Forwarded unchanged as ``height``.
        width: Forwarded unchanged as ``width``.
        num_inference_steps: Forwarded unchanged as ``num_inference_steps``.

    Returns:
        None. The only result is the image file written to disk.
    """
    # Imported lazily so diffusers is only required when this path is used.
    from diffusers import StableDiffusion3Pipeline

    reference_pipe = StableDiffusion3Pipeline.from_pretrained(
        hf_model_name, torch_dtype=torch.float32
    ).to("cpu")
    rng = torch.Generator().manual_seed(int(seed))

    pipe_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
        "generator": rng,
    }
    image = reference_pipe(**pipe_kwargs).images[0]

    # NOTE(review): `dt` is presumably `datetime.datetime` imported at file
    # level -- confirm. The stamp is taken after generation has finished.
    timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
    image.save(f"diffusers_reference_output_{timestamp}.png")
620649

621650

622651
if __name__ == "__main__":
623652
from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
624653

654+
if args.compare_vs_torch:
655+
run_diffusers_cpu(
656+
args.hf_model_name,
657+
args.prompt,
658+
args.negative_prompt,
659+
args.guidance_scale,
660+
args.seed,
661+
args.height,
662+
args.width,
663+
args.num_inference_steps,
664+
)
665+
exit()
625666
map = empty_pipe_dict
626667
mlirs = copy.deepcopy(map)
627668
vmfbs = copy.deepcopy(map)

models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def __init__(
6666
def initialize(self, sample):
    """Return the initial sample together with the step count and timesteps.

    Args:
        sample: Input tensor. It is indexed as ``sample[:, :, 0, 0]`` for
            tracing, so it must have at least four dimensions.

    Returns:
        A 3-tuple ``(sample, step_count, timesteps)`` where ``sample`` is
        passed through without any dtype conversion, ``step_count`` is a
        tensor holding ``len(self.timesteps)``, and ``timesteps`` is
        ``self.model.timesteps`` converted to float32.
    """
    # NOTE(review): presumably a debug trace of one spatial position of the
    # incoming sample -- confirm it is meant to stay enabled.
    ops.trace_tensor("sample", sample[:, :, 0, 0])

    num_steps = torch.tensor(len(self.timesteps))
    schedule = self.model.timesteps.type(torch.float32)
    return (sample, num_steps, schedule)

models/turbine_models/custom_models/sd3_inference/sd3_text_encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def forward(self, tokens_g, tokens_l, tokens_t5xxl, neg_g, neg_l, neg_t5):
105105
neg_cond, neg_cond_pool = self.get_cond(neg_l, neg_g, neg_t5)
106106

107107
prompt_embeds = torch.cat([neg_cond, conditioning], dim=0)
108-
pooled_prompt_embeds = torch.cat([cond_pool, neg_cond_pool], dim=0)
108+
pooled_prompt_embeds = torch.cat([neg_cond_pool, cond_pool], dim=0)
109109

110110
return prompt_embeds, pooled_prompt_embeds
111111

0 commit comments

Comments
 (0)