
Commit b793686

Attn debugging, piping for multi-device in sd3
1 parent 81ee093 commit b793686

4 files changed: +52 −19 lines

models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py

Lines changed: 5 additions & 3 deletions
@@ -110,15 +110,17 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
     if args.precision == "fp16":
         dtype = torch.float16
+        np_dtype = np.float16
     else:
         dtype = torch.float32
+        np_dtype = np.float32

     if args.attn_repro:
         qkv_shape = (2, 24, 4250, 64)
         example_qkv = [
-            np.load("q.npy").astype(np.float16),
-            np.load("k.npy").astype(np.float16),
-            np.load("v.npy").astype(np.float16),
+            np.load("q.npy").astype(np_dtype),
+            np.load("k.npy").astype(np_dtype),
+            np.load("v.npy").astype(np_dtype),
         ]
         turbine_output = run_attn_turbine(
             *example_qkv,
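
The attn_repro path loads q.npy, k.npy, and v.npy from the working directory. A minimal sketch for producing compatible inputs, assuming random data is good enough for debugging (the shape matches qkv_shape above):

import numpy as np

qkv_shape = (2, 24, 4250, 64)  # matches the reproducer above
rng = np.random.default_rng(seed=0)
for name in ("q", "k", "v"):
    # Saved as fp16; the runner re-casts to np_dtype per --precision on load.
    np.save(f"{name}.npy", rng.standard_normal(qkv_shape).astype(np.float16))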

models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py

Lines changed: 36 additions & 13 deletions
@@ -68,8 +68,8 @@ def __init__(
         max_length: int,
         batch_size: int,
         num_inference_steps: int,
-        device: str,
-        iree_target_triple: str,
+        device: str | dict[str, str],
+        iree_target_triple: str | dict[str, str],
         ireec_flags: dict = EMPTY_FLAGS,
         attn_spec: str = None,
         decomp_attn: bool = False,

@@ -89,7 +89,25 @@ def __init__(
         self.max_length = max_length
         self.batch_size = batch_size
         self.num_inference_steps = num_inference_steps
-        self.device = device
+        self.devices = {}
+        if isinstance(device, dict):
+            assert isinstance(
+                iree_target_triple, dict
+            ), "Device and target triple must be both dicts or both strings."
+            self.devices["clip"] = {
+                "device": device["clip"],
+                "target": iree_target_triple["clip"],
+            }
+            self.devices["mmdit"] = {
+                "device": device["mmdit"],
+                "target": iree_target_triple["mmdit"],
+            }
+            self.devices["vae"] = {
+                "device": device["vae"],
+                "target": iree_target_triple["vae"],
+            }
+        else:
+            for key in ("clip", "mmdit", "vae"):
+                self.devices[key] = {"device": device, "target": iree_target_triple}
         self.iree_target_triple = iree_target_triple
         self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS
         self.attn_spec = attn_spec
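
With this change the constructor accepts either a single device string or a per-submodel mapping. A minimal sketch of the multi-device form; the driver names and targets are illustrative, and the pipeline class name is assumed from context rather than shown in this diff:

device = {"clip": "local-task", "mmdit": "rocm", "vae": "rocm"}
iree_target_triple = {"clip": "x86_64-linux-gnu", "mmdit": "gfx942", "vae": "gfx942"}
# pipeline = SharkSD3Pipeline(..., device=device, iree_target_triple=iree_target_triple)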
@@ -291,8 +309,8 @@ def export_submodel(
             "vmfb",
             self.external_weights,
             mmdit_external_weight_path,
-            self.device,
-            self.iree_target_triple,
+            self.devices["mmdit"]["device"],
+            self.devices["mmdit"]["target"],
             self.ireec_flags["mmdit"],
             self.decomp_attn,
             exit_on_vmfb=False,
@@ -313,8 +331,8 @@ def export_submodel(
             self.num_inference_steps,
             self.precision,
             "vmfb",
-            self.device,
-            self.iree_target_triple,
+            self.devices["mmdit"]["device"],
+            self.devices["mmdit"]["target"],
             self.ireec_flags["scheduler"],
             exit_on_vmfb=False,
             pipeline_dir=self.pipeline_dir,
@@ -336,8 +354,8 @@ def export_submodel(
             "vmfb",
             self.external_weights,
             vae_external_weight_path,
-            self.device,
-            self.iree_target_triple,
+            self.devices["vae"]["device"],
+            self.devices["vae"]["target"],
             self.ireec_flags["vae"],
             self.vae_decomp_attn,
             exit_on_vmfb=False,
@@ -357,8 +375,8 @@ def export_submodel(
             "vmfb",
             self.external_weights,
             text_encoders_external_weight_path,
-            self.device,
-            self.iree_target_triple,
+            self.devices["clip"]["device"],
+            self.devices["clip"]["target"],
             self.ireec_flags["clip"],
             exit_on_vmfb=False,
             pipeline_dir=self.pipeline_dir,
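
Each export call now repeats the same two-key lookup. A hypothetical helper, not part of this commit, that would factor out the pattern:

def _submodel_device(self, key: str) -> tuple[str, str]:
    # self.devices maps each submodel to {"device": ..., "target": ...}
    entry = self.devices[key]
    return entry["device"], entry["target"]

# usage inside export_submodel: device, target = self._submodel_device("mmdit")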
@@ -374,10 +392,15 @@ def load_pipeline(
         self,
         vmfbs: dict,
         weights: dict,
-        rt_device: str = "local-task",
+        rt_device: str | dict[str, str],
         compiled_pipeline: bool = False,
         split_scheduler: bool = True,
+        extra_device_args: dict = {},
     ):
+        if "npu_delegate_path" in extra_device_args:
+            delegate = extra_device_args["npu_delegate_path"]
+        else:
+            delegate = None
         self.runners = {}
         runners = {}
         load_start = time.time()
@@ -399,7 +422,7 @@ def load_pipeline(
         runners["vae"] = vmfbRunner(
             rt_device,
             vmfbs["vae"],
-            weights["vae"],
+            weights["vae"],
         )
         vae_loaded = time.time()
         print("\n[LOG] VAE Decode loaded in ", vae_loaded - sched_loaded, "sec")

(The weights["vae"] change appears whitespace-only: the removed and added lines are textually identical.)

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 9 additions & 2 deletions
@@ -66,14 +66,19 @@
 }
 znver4_flags = {
     "all": [
-        # "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))",
         "--iree-llvmcpu-target-cpu=znver4",
         "--iree-opt-const-eval=false",
         "--iree-llvmcpu-enable-ukernels=mmt4d,pack,unpack",
         "--iree-flow-collapse-reduction-dims",
         "--iree-opt-const-expr-max-size-increase-threshold=1000000000000000",
         "--iree-flow-enable-fuse-padding-into-linalg-consumer-ops",
     ],
+    "bf16": [
+        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-demote-contraction-inputs-to-bf16))",
+    ],
+    "winograd": [
+        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-linalg-ext-convert-conv2d-to-winograd{replace-all-convs=true},iree-global-opt-demote-contraction-inputs-to-bf16))",
+    ],
 }
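
Both new groups set --iree-preprocessing-pass-pipeline, and the winograd pipeline already includes the bf16 demotion pass, so a caller would pick at most one of them. A sketch of caller-side composition, with use_winograd as a hypothetical switch:

use_winograd = False  # hypothetical caller-side switch
flags = list(znver4_flags["all"])
flags += znver4_flags["winograd"] if use_winograd else znver4_flags["bf16"]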

@@ -182,10 +187,12 @@ def compile_to_vmfb(
     if attn_spec in ["default", "mfma"]:
         attn_spec = get_mfma_spec_path(target_triple, os.path.dirname(safe_name))
         flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
-    elif attn_spec in ["wmma"] or "gfx11" in target_triple:
+    elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec):
         attn_spec = get_wmma_spec_path(target_triple, os.path.dirname(safe_name))
         if attn_spec:
             flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
+    elif attn_spec and attn_spec != "None":
+        flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])

     for i, flag in enumerate(ireec_flags):
         k = flag.strip().split("=")[0]
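
The added "and not attn_spec" guard keeps an explicitly supplied spec path on gfx11 targets from being rerouted into the wmma branch, and the new trailing elif passes such a path straight through. A standalone sketch of the resulting branch order (labels stand in for the real spec-path lookups):

def classify_attn_spec(attn_spec, target_triple):
    # Mirrors the branch order in compile_to_vmfb above.
    if attn_spec in ["default", "mfma"]:
        return "mfma"
    if attn_spec == "wmma" or ("gfx11" in target_triple and not attn_spec):
        return "wmma"
    if attn_spec and attn_spec != "None":
        return "explicit path"
    return None

# A user-supplied spec file on a gfx11 target now reaches the passthrough branch:
assert classify_attn_spec("my_spec.mlir", "gfx1100") == "explicit path"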

models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -773,14 +773,15 @@ def generate_images(
             ](samples[i], prompt_embeds, add_text_embeds, guidance_scale)

             vae_start = time.time()
+            np.save("latents_winter_cat.npy", latents.to_host().astype(np.float32))
             vae_out = self.runners["vae_decode"].ctx.modules.compiled_vae["main"](
                 latents
             )

             pipe_end = time.time()

             image = vae_out.to_host()
-
+            np.save("image_winter_cat.npy", image.astype(np.float32))
             numpy_images.append(image)
             print("Batch #", i + 1, "\n")
             print(
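
The two np.save calls dump per-batch debug tensors into the working directory; because the filenames are fixed, each batch overwrites the previous dump. A quick way to inspect the arrays offline:

import numpy as np

latents = np.load("latents_winter_cat.npy")
image = np.load("image_winter_cat.npy")
print("latents:", latents.shape, latents.dtype)
print("image:", image.shape, image.dtype)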
