
Commit 90844af

Address comments
1 parent 5a9aaa0 commit 90844af

File tree

3 files changed: +59 -220 lines changed

Lines changed: 9 additions & 111 deletions

@@ -1,6 +1,5 @@
 import argparse
 import os
-from pathlib import Path


 def path_expand(s):
@@ -28,7 +27,7 @@ def is_valid_file(arg):
 )

 ##############################################################################
-# SDXL Huggingface Options
+# Huggingface Options
 ##############################################################################

 p.add_argument(
@@ -41,13 +40,7 @@ def is_valid_file(arg):
     "--hf_model_name",
     type=str,
     help="HF model name",
-    default="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
-)
-p.add_argument(
-    "--scheduler_id",
-    type=str,
-    help="Scheduler ID",
-    default="Euler",
+    default="meta-llama/Llama-2-7b-chat-hf",
 )

 ##############################################################################
@@ -56,39 +49,14 @@ def is_valid_file(arg):
 ##############################################################################

 p.add_argument(
-    "--prompt",
-    type=str,
-    default=" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
-    help="Prompt input to stable diffusion.",
+    "--seed", type=float, default=0, help="Seed for random number/latents generation."
 )

 p.add_argument(
-    "--negative_prompt",
+    "--pipeline_dir",
     type=str,
-    default="Watermark, blurry, oversaturated, low resolution, pollution",
-    help="Negative prompt input to stable diffusion.",
-)
-
-p.add_argument(
-    "--num_inference_steps", type=int, default=30, help="Number of UNet inference steps"
-)
-
-p.add_argument(
-    "--batch_count",
-    type=int,
-    default=1,
-    help="Number of batches to run for a single prompt",
-)
-
-p.add_argument(
-    "--guidance_scale",
-    type=float,
-    default=7.5,
-    help="Scale by which to adjust prompt guidance to the unconditional noise prediction output of UNet after each iteration.",
-)
-
-p.add_argument(
-    "--seed", type=float, default=0, help="Seed for random number/latents generation."
+    default="",
+    help="Path to location of vmfb files.",
 )

 p.add_argument(
@@ -109,73 +77,30 @@ def is_valid_file(arg):
     "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module"
 )

-p.add_argument(
-    "--pipeline_vmfb_path",
-    type=str,
-    default="",
-    help="path to vmfb containing compiled meta-module",
-)
-
 p.add_argument(
     "--external_weight_file",
     type=str,
     default=None,
     help="Path to external weights, used in benchmark scripts.",
 )

-p.add_argument(
-    "--pipeline_dir",
-    type=str,
-    default=None,
-    help="Directory to save pipeline artifacts",
-)
-
-p.add_argument(
-    "--compiled_pipeline",
-    default=False,
-    action="store_true",
-    help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.",
-)

 ##############################################################################
-# SDXL Modelling Options
-# These options are used to control model defining parameters for SDXL.
+# Modelling Options
+# These options are used to control model defining parameters.
 # These are MLIR - changing variables! If you change them, you will need
 # to import/download and recompile the model.
 ##############################################################################

-p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
-p.add_argument(
-    "--height", type=int, default=1024, help="Height of Stable Diffusion output image."
-)
-p.add_argument(
-    "--width", type=int, default=1024, help="Width of Stable Diffusion output image"
-)
 p.add_argument(
     "--precision",
     type=str,
     default="fp16",
     help="Precision of Stable Diffusion weights and graph.",
 )
-p.add_argument(
-    "--max_length", type=int, default=64, help="Sequence Length of Stable Diffusion"
-)
-p.add_argument("--vae_variant", type=str, default="decode", help="encode, decode")
-p.add_argument(
-    "--return_index",
-    action="store_true",
-    help="Make scheduled unet compiled module return the step index.",
-)
-
-p.add_argument(
-    "--vae_decomp_attn",
-    type=bool,
-    default=False,
-    help="Decompose attention for VAE decode only at fx graph level",
-)

 ##############################################################################
-# SDXL script general options.
+# Script general options.
 ##############################################################################

 p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb")
@@ -195,12 +120,6 @@ def is_valid_file(arg):
     action="store_true",
     help="Runs both turbine vmfb and a torch model to compare results",
 )
-p.add_argument(
-    "--decomp_attn",
-    default=False,
-    action="store_true",
-    help="Decompose attention at fx graph level",
-)
 p.add_argument(
     "--exit_on_vmfb",
     default=True,
@@ -264,26 +183,5 @@ def is_valid_file(arg):
     help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.",
 )

-p.add_argument(
-    "--clip_flags",
-    type=str,
-    default="",
-    help="extra iree-compile options to send for compiling CLIP/prompt_encoder. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py",
-)
-
-p.add_argument(
-    "--vae_flags",
-    type=str,
-    default="",
-    help="extra iree-compile options to send for compiling VAE. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py",
-)
-
-p.add_argument(
-    "--unet_flags",
-    type=str,
-    default="",
-    help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py",
-)
-

 args, unknown = p.parse_known_args()
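Note: because the script ends with p.parse_known_args() rather than parse_args(), callers that still pass the flags deleted above (--scheduler_id, --guidance_scale, and so on) do not crash; the leftovers are collected into `unknown`. A minimal, self-contained sketch of that behavior (the flag values here are illustrative, not from the commit):

import argparse

# Mirror of the pruned parser above. parse_known_args() tolerates flags that
# were removed from the parser: they land in `unknown` instead of raising.
p = argparse.ArgumentParser()
p.add_argument("--hf_model_name", type=str, default="meta-llama/Llama-2-7b-chat-hf")
p.add_argument("--seed", type=float, default=0)

args, unknown = p.parse_known_args(["--seed", "42", "--scheduler_id", "Euler"])
print(args.seed)  # 42.0 -- a float, matching the parser definition above
print(unknown)    # ['--scheduler_id', 'Euler'] -- ignored rather than an error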

models/turbine_models/custom_models/stateless_llama.py

Lines changed: 18 additions & 99 deletions

@@ -22,55 +22,6 @@

 BATCH_SIZE = 1

-import argparse
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--hf_auth_token", type=str, help="The Hugging Face auth token, required"
-)
-parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
-parser.add_argument(
-    "--hf_model_name",
-    type=str,
-    help="HF model name",
-    default="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
-)
-parser.add_argument("--quantization", type=str, default="unquantized")
-parser.add_argument("--external_weight_file", type=str, default="")
-parser.add_argument(
-    "--vmfb_path", type=str, default=None, help="Path/name to store compiled vmfb."
-)
-parser.add_argument(
-    "--external_weights",
-    type=str,
-    default=None,
-    help="saves ir/vmfb without global weights for size and readability, options [gguf, safetensors]",
-)
-parser.add_argument(
-    "--precision", type=str, default="fp16", help="dtype of model [f16, f32]"
-)
-parser.add_argument(
-    "--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm"
-)
-# TODO: Bring in detection for target triple
-parser.add_argument(
-    "--iree_target_triple",
-    type=str,
-    default="host",
-    help="Specify vulkan target triple or rocm/cuda target device.",
-)
-parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
-parser.add_argument(
-    "--streaming_llm",
-    action="store_true",
-    help="Compile LLM with StreamingLLM optimizations",
-)
-parser.add_argument(
-    "--decomp_attn",
-    action="store_true",
-    help="Decompose attention ops at fx graph level.",
-)
-

 def generate_schema(num_layers):
     null = None
@@ -519,51 +470,31 @@ def evict_kvcache_space(self):
 }


-class StatelessLlamaPipeline:
+class StatelessLlama:
     def __init__(
         self,
         hf_model_name: str,
-        scheduler_id: str,
-        height: int,
-        width: int,
         precision: str,
-        max_length: int,
-        batch_size: int,
-        num_inference_steps: int,
         device: str,
         iree_target_triple: str,
         ireec_flags: list = [],
-        attn_spec: str = None,
-        decomp_attn: bool = False,
         pipeline_dir: str | Path = "./shark_vmfbs",
         external_weights_dir: str | Path = "./shark_weights",
         external_weights: str = "safetensors",
-        custom_vae: str = None,
-        vae_decomp_attn: bool = True,
         hf_auth_token: str = None,
+        streaming_llm: bool = False,
     ):
         self.hf_model_name = hf_model_name
         self.iree_dtype = "float32" if precision == "fp32" else "float16"
         self.torch_dtype = torch.float32 if precision == "fp32" else torch.float16
         self.cpu_scheduling = True
-        self.scheduler_id = scheduler_id
-        self.height = height
-        self.width = width
         self.precision = precision
-        self.max_length = max_length
-        self.model_max_length = max_length
-        self.batch_size = batch_size
-        self.num_inference_steps = num_inference_steps
         self.device = device
         self.iree_target_triple = iree_target_triple
         self.ireec_flags = ireec_flags
-        self.attn_spec = attn_spec
-        self.decomp_attn = decomp_attn
         self.pipeline_dir = pipeline_dir
         self.external_weights_dir = external_weights_dir
         self.external_weights = external_weights
-        self.custom_vae = custom_vae
-        self.vae_decomp_attn = vae_decomp_attn

         self.first_input = True
         self.max_tokens = llm_model_map[self.hf_model_name]["max_tokens"]
@@ -582,10 +513,11 @@ def __init__(
         )
         self.model = None
         self.hf_auth_token=hf_auth_token
+        self.streaming_llm = streaming_llm

     # FILE MANAGEMENT AND PIPELINE SETUP

-    def check_prepared(
+    def prepare_pipeline(
         self,
         mlir: str,
         vmfb: str,
@@ -660,8 +592,8 @@ def export(
         weights_only: bool = False,
     ):
         safe_name = self.hf_model_name.replace("-", "_").replace("/", "_")
-        # if self.streaming_llm:
-        safe_name += "_streaming"
+        if self.streaming_llm:
+            safe_name += "_streaming"

         if not os.path.exists(self.pipeline_dir):
             os.makedirs(self.pipeline_dir)
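Note: this hunk makes the "_streaming" suffix conditional again, so streaming and non-streaming builds of the same model no longer collide on a single vmfb name. A standalone sketch of the resulting naming (the model name is the parser default; the flag value is illustrative):

# Standalone illustration of the safe_name logic after the fix.
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
streaming_llm = True  # illustrative; False would leave the suffix off

safe_name = hf_model_name.replace("-", "_").replace("/", "_")
if streaming_llm:
    safe_name += "_streaming"

print(safe_name)  # meta_llama_Llama_2_7b_chat_hf_streaming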
@@ -698,7 +630,7 @@
             device=self.device,
             target_triple=self.iree_target_triple,
             vulkan_max_allocation=None,
-            streaming_llm=True,
+            streaming_llm=self.streaming_llm,
             vmfb_path=os.path.join(self.pipeline_dir, safe_name + ".vmfb"),
             upload_ir=False,
             mod=None,
@@ -732,9 +664,12 @@ def format_out(results):

         history = []
         for iter in range(self.max_tokens):
-            # if self.streaming_llm:
-            token_slice = max(self.prev_token_len - 1, 0)
-            input_tensor = input_tensor[:, token_slice:]
+            if self.streaming_llm:
+                token_slice = max(self.prev_token_len - 1, 0)
+                input_tensor = input_tensor[:, token_slice:]
+            else:
+                # TODO
+                pass
             # if self.streaming_llm and self.model["get_seq_step"]() > 600:
             if self.model["get_seq_step"]() > 600:
                 print("Evicting cache space!")
@@ -743,7 +678,7 @@
             device_inputs = [
                 ireert.asdevicearray(self.device, input_tensor)
             ]
-            if self.first_input: # or not self.streaming_llm:
+            if self.first_input or not self.streaming_llm:
                 st_time = time.time()
                 token = self.model["run_initialize"](*device_inputs)
                 total_time = time.time() - st_time
@@ -820,33 +755,17 @@ def format_out(results):
     if not args.external_weights_dir and args.external_weights:
         args.external_weights_dir = args.pipeline_dir

-    sd_pipe = StatelessLlamaPipeline(
+    llama = StatelessLlama(
         args.hf_model_name,
-        args.scheduler_id,
-        args.height,
-        args.width,
         args.precision,
-        args.max_length,
-        args.batch_size,
-        args.num_inference_steps,
         args.device,
         args.iree_target_triple,
         flags,
-        args.attn_spec,
-        args.decomp_attn,
         args.pipeline_dir,
         args.external_weights_dir,
         args.external_weights,
-        args.vae_decomp_attn,
         args.hf_auth_token,
+        True,
     )
-    vmfb, weight = sd_pipe.check_prepared(mlir, vmfb, weight, interactive=False, quantization="int4")
-    sd_pipe.load_pipeline(vmfb, weight, args.rt_device, args.compiled_pipeline)
-    sd_pipe.generate_images(
-        args.prompt,
-        args.negative_prompt,
-        args.batch_count,
-        args.guidance_scale,
-        args.seed,
-        False,
-    )
+    vmfb, weight = llama.prepare_pipeline(mlir, vmfb, weight, interactive=False, quantization="int4")
+    llama.load_pipeline(vmfb, weight, args.rt_device, args.compiled_pipeline)
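Note: with the SDXL-specific parameters gone, the positional constructor call in __main__ is easy to misread, so for orientation it lines up with the slimmed signature roughly as below. This is a hedged sketch: it assumes StatelessLlama is importable from this module, and every value is a parser default or an illustrative stand-in, not taken from the commit.

from turbine_models.custom_models.stateless_llama import StatelessLlama

# Positional arguments labeled against the signature shown in the diff above.
# All values are illustrative defaults, not from the commit.
llama = StatelessLlama(
    "meta-llama/Llama-2-7b-chat-hf",  # hf_model_name
    "fp16",                           # precision
    "llvm-cpu",                       # device
    "host",                           # iree_target_triple
    [],                               # ireec_flags
    "./shark_vmfbs",                  # pipeline_dir
    "./shark_weights",                # external_weights_dir
    "safetensors",                    # external_weights
    None,                             # hf_auth_token
    True,                             # streaming_llm
)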
