Commit 618d01f

Point to azure links for specs and fix timesteps dim in gpu scheduler.

1 parent b1f20f1 · commit 618d01f
2 files changed: +5 -5 lines changed
models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py

Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ def prepare_model_input(self, sample, t, timesteps):
             latent_model_input = torch.cat([sample] * 2)
         else:
             latent_model_input = sample
-        t = t.expand(sample.shape[0])
+        t = t.expand(latent_model_input.shape[0])
         return latent_model_input.type(self.dtype), t.type(self.dtype)
 
     def step(self, noise_pred, t, sample, guidance_scale, i):
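The fix above matters when classifier-free guidance is active: the latent batch is doubled via torch.cat, so expanding t to sample.shape[0] leaves the timestep tensor at half the batch size of the model input. A minimal sketch of the shape mismatch, with illustrative latent dimensions rather than the real SD3 config:

import torch

sample = torch.randn(1, 16, 64, 64)            # one latent in the batch
t = torch.tensor([999.0])                      # current timestep

# Classifier-free guidance concatenates conditional and unconditional
# latents, doubling the batch dimension:
latent_model_input = torch.cat([sample] * 2)   # shape (2, 16, 64, 64)

t_before = t.expand(sample.shape[0])              # shape (1,): too short
t_after = t.expand(latent_model_input.shape[0])   # shape (2,): matches batch

assert t_after.shape[0] == latent_model_input.shape[0]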

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 4 additions & 4 deletions

@@ -35,7 +35,7 @@
         "--iree-codegen-gpu-native-math-precision=true",
         "--iree-rocm-waves-per-eu=2",
         "--iree-flow-inline-constants-max-byte-length=1",
-        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
+        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,128,0,32,0}))",
     ],
     "unet": [
         "--iree-flow-enable-aggressive-fusion",
@@ -275,7 +275,7 @@ def create_safe_name(hf_model_name, model_name_str):
 
 
 def get_mfma_spec_path(target_chip, save_dir):
-    url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
+    url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir"
     attn_spec = urlopen(url).read().decode("utf-8")
     spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
     if os.path.exists(spec_path):
@@ -287,9 +287,9 @@ def get_mfma_spec_path(target_chip, save_dir):
 
 def get_wmma_spec_path(target_chip, save_dir):
     if target_chip == "gfx1100":
-        url = "https://github.com/iree-org/iree/raw/shared/tresleches-united/scripts/attention_gfx1100.spec.mlir"
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1100.mlir"
     elif target_chip in ["gfx1103", "gfx1150"]:
-        url = "https://github.com/iree-org/iree/raw/shared/tresleches-united/scripts/attention_gfx1103.spec.mlir"
+        url = "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx1150.mlir"
     else:
         return None
     attn_spec = urlopen(url).read().decode("utf-8")
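Both helpers now fetch tuned attention/matmul transform specs from Azure blob storage instead of pinned IREE branches. A minimal sketch of the download-and-cache pattern, assuming a cache-first check (the helpers' truncated bodies above may order these steps differently):

import os
from urllib.request import urlopen

def fetch_spec(url, save_dir, filename):
    # Reuse a previously downloaded spec if one is already on disk.
    spec_path = os.path.join(save_dir, filename)
    if os.path.exists(spec_path):
        return spec_path
    attn_spec = urlopen(url).read().decode("utf-8")
    os.makedirs(save_dir, exist_ok=True)
    with open(spec_path, "w") as f:
        f.write(attn_spec)
    return spec_path

# Example: the gfx942 (MFMA) spec URL from this commit, cached locally.
spec = fetch_spec(
    "https://sharkpublic.blob.core.windows.net/sharkpublic/specs/latest/attention_and_matmul_spec_gfx942.mlir",
    "./specs",
    "attention_and_matmul_spec_mfma.mlir",
)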
