Commit 0291d43

Small fixes to SDXL inference pipeline/exports/compile
1 parent 05fa32d · commit 0291d43

5 files changed (+6 −7 lines)


models/turbine_models/custom_models/pipeline_base.py

Lines changed: 2 additions & 0 deletions
@@ -368,6 +368,8 @@ def __init__(
                 target, dict
             ), "Device and target triple must be both dicts or both strings."
             for submodel in self.map.keys():
+                if self.map[submodel].get("load") == False:
+                    continue
                 assert submodel in device.keys(), f"Device for {submodel} not found."
                 assert (
                     submodel in target.keys()
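
For context, a minimal sketch (not from the commit) of what the added guard buys: submodels flagged "load": False, such as the precompiled pipeline modules introduced in sd_pipeline.py below, no longer need matching entries in the device/target dicts. The map, device, and target values here are made up for illustration.

# Hypothetical standalone illustration of the new skip logic.
model_map = {
    "unet": {"load": True},
    "unetloop": {"load": False},  # compiled-pipeline module, never loaded on its own
}
device = {"unet": "rocm"}    # note: no "unetloop" entry
target = {"unet": "gfx942"}  # note: no "unetloop" entry

for submodel in model_map.keys():
    if model_map[submodel].get("load") == False:
        continue  # skip device/target validation for non-loaded submodels
    assert submodel in device.keys(), f"Device for {submodel} not found."
    assert submodel in target.keys(), f"Target for {submodel} not found."
print("validation passed")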

models/turbine_models/custom_models/sd_inference/sd_pipeline.py

Lines changed: 3 additions & 1 deletion
@@ -120,6 +120,8 @@
             "decomp_attn": None,
         },
     },
+}
+sdxl_compiled_pipeline_map = {
     "unetloop": {
         "module_name": "sdxl_compiled_pipeline",
         "load": False,

@@ -434,7 +436,7 @@ def load_scheduler(
         if self.is_sd3:
             export_fn = sd3_schedulers.export_scheduler_model
         else:
-            export_fn = scheduler.export_scheduler_model
+            export_fn = schedulers.export_scheduler_model
         self.map["scheduler"] = {
             "module_name": "compiled_scheduler",
             "export_fn": export_fn,

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 1 addition & 1 deletion
@@ -476,7 +476,7 @@ def get_mfma_spec_path(target_chip, save_dir, masked_attention=False, use_punet=
         url = "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/specs/attention_and_matmul_spec.mlir"
     elif not masked_attention:
         suffix = ""
-        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/no_pad/attention_and_matmul_spec_mfma.mlir"
+        url = "https://raw.githubusercontent.com/iree-org/iree/refs/heads/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
     else:
         suffix = "_pad"
         url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"

models/turbine_models/custom_models/sd_inference/vae.py

Lines changed: 0 additions & 1 deletion
@@ -171,7 +171,6 @@ def export_vae_model(
         vae_model,
         external_weights,
         external_weight_path,
-        vae_harness=vae_harness,
     )
     if weights_only:
         return external_weight_path

models/turbine_models/custom_models/sdxl_inference/unet.py

Lines changed: 0 additions & 4 deletions
@@ -205,10 +205,6 @@ def export_unet_model(
     if not attn_spec:
         if (not decomp_attn) and use_punet:
             attn_spec = "punet"
-        elif (not decomp_attn) and "gfx9" in target:
-            attn_spec = "mfma"
-        elif (not decomp_attn) and "gfx11" in target:
-            attn_spec = "wmma"
     safe_name = utils.create_safe_name(
         hf_model_name,
         f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_{submodel_name}",
