@@ -22,55 +22,6 @@

 BATCH_SIZE = 1

-import argparse
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--hf_auth_token", type=str, help="The Hugging Face auth token, required"
-)
-parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
-parser.add_argument(
-    "--hf_model_name",
-    type=str,
-    help="HF model name",
-    default="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
-)
-parser.add_argument("--quantization", type=str, default="unquantized")
-parser.add_argument("--external_weight_file", type=str, default="")
-parser.add_argument(
-    "--vmfb_path", type=str, default=None, help="Path/name to store compiled vmfb."
-)
-parser.add_argument(
-    "--external_weights",
-    type=str,
-    default=None,
-    help="saves ir/vmfb without global weights for size and readability, options [gguf, safetensors]",
-)
-parser.add_argument(
-    "--precision", type=str, default="fp16", help="dtype of model [f16, f32]"
-)
-parser.add_argument(
-    "--device", type=str, default="llvm-cpu", help="llvm-cpu, cuda, vulkan, rocm"
-)
-# TODO: Bring in detection for target triple
-parser.add_argument(
-    "--iree_target_triple",
-    type=str,
-    default="host",
-    help="Specify vulkan target triple or rocm/cuda target device.",
-)
-parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
-parser.add_argument(
-    "--streaming_llm",
-    action="store_true",
-    help="Compile LLM with StreamingLLM optimizations",
-)
-parser.add_argument(
-    "--decomp_attn",
-    action="store_true",
-    help="Decompose attention ops at fx graph level.",
-)
-

 def generate_schema(num_layers):
     null = None
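This hunk deletes the module-level argparse setup. For orientation, here is a condensed, runnable sketch of the CLI surface the refactored class below still consumes; flag names and defaults are copied from the deleted lines, not a new API:

# Condensed sketch of the CLI flags this hunk removes, trimmed to the
# options the refactored class still uses; names and defaults are taken
# verbatim from the deleted parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--hf_auth_token", type=str, help="The Hugging Face auth token")
parser.add_argument(
    "--hf_model_name",
    type=str,
    default="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
)
parser.add_argument("--precision", type=str, default="fp16")
parser.add_argument("--device", type=str, default="llvm-cpu")
parser.add_argument("--iree_target_triple", type=str, default="host")
parser.add_argument("--streaming_llm", action="store_true")
args = parser.parse_args()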
@@ -519,51 +470,31 @@ def evict_kvcache_space(self):
 }


-class StatelessLlamaPipeline:
+class StatelessLlama:
     def __init__(
         self,
         hf_model_name: str,
-        scheduler_id: str,
-        height: int,
-        width: int,
         precision: str,
-        max_length: int,
-        batch_size: int,
-        num_inference_steps: int,
         device: str,
         iree_target_triple: str,
         ireec_flags: list = [],
-        attn_spec: str = None,
-        decomp_attn: bool = False,
         pipeline_dir: str | Path = "./shark_vmfbs",
         external_weights_dir: str | Path = "./shark_weights",
         external_weights: str = "safetensors",
-        custom_vae: str = None,
-        vae_decomp_attn: bool = True,
         hf_auth_token: str = None,
+        streaming_llm: bool = False,
     ):
         self.hf_model_name = hf_model_name
         self.iree_dtype = "float32" if precision == "fp32" else "float16"
         self.torch_dtype = torch.float32 if precision == "fp32" else torch.float16
         self.cpu_scheduling = True
-        self.scheduler_id = scheduler_id
-        self.height = height
-        self.width = width
         self.precision = precision
-        self.max_length = max_length
-        self.model_max_length = max_length
-        self.batch_size = batch_size
-        self.num_inference_steps = num_inference_steps
         self.device = device
         self.iree_target_triple = iree_target_triple
         self.ireec_flags = ireec_flags
-        self.attn_spec = attn_spec
-        self.decomp_attn = decomp_attn
         self.pipeline_dir = pipeline_dir
         self.external_weights_dir = external_weights_dir
         self.external_weights = external_weights
-        self.custom_vae = custom_vae
-        self.vae_decomp_attn = vae_decomp_attn

         self.first_input = True
         self.max_tokens = llm_model_map[self.hf_model_name]["max_tokens"]
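The constructor sheds the Stable Diffusion leftovers (scheduler, image dimensions, VAE and attention options) and gains a streaming_llm switch. A minimal construction under the new signature might look like the sketch below, reusing the defaults from the deleted argparse flags; llm_model_map must contain an entry for the model name:

# Minimal sketch, assuming the defaults from the removed CLI flags;
# values are illustrative, not prescribed by the diff.
llama = StatelessLlama(
    hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
    precision="fp16",
    device="llvm-cpu",
    iree_target_triple="host",
    ireec_flags=[],
    pipeline_dir="./shark_vmfbs",
    external_weights_dir="./shark_weights",
    external_weights="safetensors",
    hf_auth_token=None,
    streaming_llm=True,
)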
@@ -582,10 +513,11 @@ def __init__(
         )
         self.model = None
         self.hf_auth_token = hf_auth_token
+        self.streaming_llm = streaming_llm

     # FILE MANAGEMENT AND PIPELINE SETUP

-    def check_prepared(
+    def prepare_pipeline(
         self,
         mlir: str,
         vmfb: str,
@@ -660,8 +592,8 @@ def export(
         weights_only: bool = False,
     ):
         safe_name = self.hf_model_name.replace("-", "_").replace("/", "_")
-        # if self.streaming_llm:
-        safe_name += "_streaming"
+        if self.streaming_llm:
+            safe_name += "_streaming"

         if not os.path.exists(self.pipeline_dir):
             os.makedirs(self.pipeline_dir)
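Guarding the suffix on self.streaming_llm keeps streaming and non-streaming artifacts from sharing a file name. This sketch traces the stem export() now produces in each mode, reusing the default model name from above:

# Sketch of the cache-file stem computed by export() in each mode.
hf_model_name = "Trelis/Llama-2-7b-chat-hf-function-calling-v2"
for streaming_llm in (False, True):
    safe_name = hf_model_name.replace("-", "_").replace("/", "_")
    if streaming_llm:
        safe_name += "_streaming"
    print(safe_name)
# Trelis_Llama_2_7b_chat_hf_function_calling_v2
# Trelis_Llama_2_7b_chat_hf_function_calling_v2_streaming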
@@ -698,7 +630,7 @@ def export(
             device=self.device,
             target_triple=self.iree_target_triple,
             vulkan_max_allocation=None,
-            streaming_llm=True,
+            streaming_llm=self.streaming_llm,
             vmfb_path=os.path.join(self.pipeline_dir, safe_name + ".vmfb"),
             upload_ir=False,
             mod=None,
@@ -732,9 +664,12 @@ def format_out(results):

         history = []
         for iter in range(self.max_tokens):
-            # if self.streaming_llm:
-            token_slice = max(self.prev_token_len - 1, 0)
-            input_tensor = input_tensor[:, token_slice:]
+            if self.streaming_llm:
+                token_slice = max(self.prev_token_len - 1, 0)
+                input_tensor = input_tensor[:, token_slice:]
+            else:
+                # TODO
+                pass
             # if self.streaming_llm and self.model["get_seq_step"]() > 600:
             if self.model["get_seq_step"]() > 600:
                 print("Evicting cache space!")
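With streaming enabled, each iteration feeds the model only the tokens it has not consumed yet: prev_token_len counts tokens already fed, and the window starts one token earlier so decoding continues from the last seen token. A standalone sketch of the slice, with illustrative values:

import torch

# Illustrative values: 5 tokens already fed, 3 new ones appended.
prev_token_len = 5
input_tensor = torch.arange(8).unsqueeze(0)   # shape (1, 8)

token_slice = max(prev_token_len - 1, 0)      # 4: last seen token onward
input_tensor = input_tensor[:, token_slice:]
print(input_tensor)  # tensor([[4, 5, 6, 7]])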
@@ -743,7 +678,7 @@ def format_out(results):
             device_inputs = [
                 ireert.asdevicearray(self.device, input_tensor)
             ]
-            if self.first_input:  # or not self.streaming_llm:
+            if self.first_input or not self.streaming_llm:
                 st_time = time.time()
                 token = self.model["run_initialize"](*device_inputs)
                 total_time = time.time() - st_time
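Restoring or not self.streaming_llm means the non-streaming path reruns full-prompt initialization on every iteration, while streaming mode runs it only once and then decodes incrementally. A self-contained sketch of that dispatch; the stub dict stands in for the loaded vmfb module, and "run_forward" is a hypothetical name for the decode entry point, which sits outside this diff:

# Stub standing in for the loaded vmfb module; "run_forward" is a
# hypothetical decode entry point (only run_initialize appears in the diff).
model = {
    "run_initialize": lambda *inputs: "first token",
    "run_forward": lambda *inputs: "next token",
}
first_input, streaming_llm = True, True
device_inputs = []

if first_input or not streaming_llm:
    token = model["run_initialize"](*device_inputs)  # consume the whole prompt
    first_input = False
else:
    token = model["run_forward"](*device_inputs)     # decode one more token
print(token)  # first token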
@@ -820,33 +755,17 @@ def format_out(results):
     if not args.external_weights_dir and args.external_weights:
         args.external_weights_dir = args.pipeline_dir

-    sd_pipe = StatelessLlamaPipeline(
+    llama = StatelessLlama(
         args.hf_model_name,
-        args.scheduler_id,
-        args.height,
-        args.width,
         args.precision,
-        args.max_length,
-        args.batch_size,
-        args.num_inference_steps,
         args.device,
         args.iree_target_triple,
         flags,
-        args.attn_spec,
-        args.decomp_attn,
         args.pipeline_dir,
         args.external_weights_dir,
         args.external_weights,
-        args.vae_decomp_attn,
         args.hf_auth_token,
+        True,
     )
-    vmfb, weight = sd_pipe.check_prepared(mlir, vmfb, weight, interactive=False, quantization="int4")
-    sd_pipe.load_pipeline(vmfb, weight, args.rt_device, args.compiled_pipeline)
-    sd_pipe.generate_images(
-        args.prompt,
-        args.negative_prompt,
-        args.batch_count,
-        args.guidance_scale,
-        args.seed,
-        False,
-    )
+    vmfb, weight = llama.prepare_pipeline(mlir, vmfb, weight, interactive=False, quantization="int4")
+    llama.load_pipeline(vmfb, weight, args.rt_device, args.compiled_pipeline)