Skip to content

Commit 6f6585b

Browse files
committed
Fix an issue with hardcoded iterations
1 parent d23b528 commit 6f6585b

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

models/turbine_models/custom_models/pipeline_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def prepare_all(
437437
vmfbs: dict = {},
438438
weights: dict = {},
439439
interactive: bool = False,
440+
num_steps: int = 20,
440441
):
441442
ready = self.is_prepared(vmfbs, weights)
442443
match ready:
@@ -463,7 +464,7 @@ def prepare_all(
463464
if not self.map[submodel].get("weights") and self.map[submodel][
464465
"export_args"
465466
].get("external_weights"):
466-
self.export_submodel(submodel, weights_only=True)
467+
self.export_submodel(submodel, weights_only=True, num_steps=num_steps)
467468
return self.prepare_all(mlirs, vmfbs, weights, interactive)
468469

469470
def is_prepared(self, vmfbs, weights):
@@ -581,6 +582,7 @@ def export_submodel(
581582
submodel: str,
582583
input_mlir: str = None,
583584
weights_only: bool = False,
585+
num_steps: int = 20,
584586
):
585587
if not os.path.exists(self.pipeline_dir):
586588
os.makedirs(self.pipeline_dir)
@@ -672,6 +674,7 @@ def export_submodel(
672674
self.map[submodel]["export_args"]["max_length"],
673675
"produce_img_split",
674676
unet_module_name=self.map["unet"]["module_name"],
677+
num_steps=num_steps,
675678
)
676679
dims = [
677680
self.map[submodel]["export_args"]["width"],

models/turbine_models/custom_models/sd_inference/sd_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ def numpy_to_pil_image(images):
831831
False,
832832
args.compiled_pipeline,
833833
)
834-
sd_pipe.prepare_all()
834+
sd_pipe.prepare_all(num_steps=args.num_inference_steps)
835835
sd_pipe.load_map()
836836
sd_pipe.generate_images(
837837
args.prompt,

0 commit comments

Comments
 (0)