
Commit 2325984

Stateless llama testing + SDXL CI Job (#513)
This commit cleans up stateless llama testing and improves memory efficiency by creating the model on the meta device (with empty weights) and then loading the state dict into it (shard by shard in the case of a sharded checkpoint). The SD test for `vae_encode` throws an error; I took a look at the generated MLIR, and the `encode_inp` func is there with valid-looking inputs and return. It looks like this is related to the IREE bump (the `FuncConversion` pass). This commit also adds a CI job to run Jinchen's SDXL script nightly (one failure right now, hopefully fixed soon).
1 parent 1cccf4d commit 2325984
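
As context for the memory note above, here is a minimal sketch of the loading pattern the new test uses (it mirrors the `setUpClass` changes in `stateless_llama_test.py`; `low_cpu_mem_usage=True` is what makes `transformers` build the model with empty weights on the meta device and then stream the checkpoint in shard by shard):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Trelis/Llama-2-7b-chat-hf-function-calling-v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

# Build the model skeleton with empty weights, then load the state dict into it
# (shard by shard for a sharded checkpoint) instead of materializing random
# weights first and overwriting them.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float,
    low_cpu_mem_usage=True,
    device_map="auto",
)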

File tree

6 files changed: +143 -47 lines

.github/workflows/test_models.yml

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ jobs:
     strategy:
       matrix:
         version: [3.11]
-        os: [nodai-ubuntu-builder-large]
+        os: [nodai-amdgpu-w7900-x86-64]
 
     runs-on: ${{matrix.os}}
     steps:

.github/workflows/test_sdxl.yml

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+name: SDXL Models Nightly
+
+on:
+  schedule:
+    - cron: '30 6 * * *'
+
+jobs:
+  test-sdxl-models:
+    strategy:
+      matrix:
+        version: [3.11]
+        os: [nodai-amdgpu-w7900-x86-64]
+
+    runs-on: ${{matrix.os}}
+    steps:
+      - name: "Setting up Python"
+        uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3
+        with:
+          python-version: ${{matrix.version}}
+
+      - name: "Checkout Code"
+        uses: actions/checkout@v2
+        with:
+          ref: ean-sd-fp16
+
+      - name: Sync source deps
+        # build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile
+        run: |
+          python -m pip install --upgrade pip
+          # Note: We install in three steps in order to satisfy requirements
+          # from non default locations first. Installing the PyTorch CPU
+          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+          pip install --index-url https://download.pytorch.org/whl/cpu \
+            -r core/pytorch-cpu-requirements.txt \
+            -r core/torchvision-requirements.txt
+          pip install --upgrade -r core/requirements.txt
+          pip install -e core[testing,torch-cpu-nightly]
+          pip install --upgrade -r models/requirements.txt
+          pip install -e models
+
+      - name: Show current free memory
+        run: |
+          free -mh
+
+      - name: Run sdxl tests
+        run: |
+          pip install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+          pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu
+          pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
+          pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device rocm --iree_target_triple gfx90a

models/turbine_models/custom_models/llm_runner.py

Lines changed: 14 additions & 20 deletions
@@ -168,12 +168,14 @@ def run_llm(
     streaming_llm=False,
     chat_mode=False,
     chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,
+    tokenizer=None,
 ):
-    tokenizer = AutoTokenizer.from_pretrained(
-        hf_model_name,
-        use_fast=False,
-        token=hf_auth_token,
-    )
+    if tokenizer == None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            hf_model_name,
+            use_fast=False,
+            token=hf_auth_token,
+        )
     llm = SharkLLM(
         device=device,
         vmfb_path=vmfb_path,
@@ -204,43 +206,35 @@ def run_torch_llm(
     prompt,
     streaming_llm=False,
     chat_sys_prompt=DEFAULT_CHAT_SYS_PROMPT,
+    model=None,
+    tokenizer=None,
 ):
-    from turbine_models.model_builder import HFTransformerBuilder
-    from transformers import AutoModelForCausalLM
-
-    model_builder = HFTransformerBuilder(
-        example_input=None,
-        hf_id=hf_model_name,
-        auto_model=AutoModelForCausalLM,
-        hf_auth_token=hf_auth_token,
-        auto_tokenizer=AutoTokenizer,
-    )
     if streaming_llm is True:
-        enable_llama_pos_shift_attention(model_builder.model)
+        enable_llama_pos_shift_attention(model)
 
     def get_token_from_logits(logits):
         return torch.argmax(logits[:, -1, :], dim=1)
 
     prompt = append_user_prompt(chat_sys_prompt, prompt)
-    initial_input = model_builder.tokenizer(prompt, return_tensors="pt")
+    initial_input = tokenizer(prompt, return_tensors="pt")
     example_input_id = initial_input.input_ids
 
-    model_results = model_builder.model.forward(example_input_id)
+    model_results = model.forward(example_input_id)
     model_token = get_token_from_logits(model_results.logits)
 
     pkv = model_results.past_key_values
 
     torch_results = []
     torch_results.append(int(model_token))
     while model_token != 2:
-        model_results = model_builder.model.forward(
+        model_results = model.forward(
             torch.unsqueeze(model_token, 0), past_key_values=pkv
         )
         model_token = get_token_from_logits(model_results.logits)
         pkv = model_results.past_key_values
         torch_results.append(int(model_token[0]))
 
-    return model_builder.tokenizer.decode(torch_results)
+    return tokenizer.decode(torch_results)
 
 
 if __name__ == "__main__":
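
With these changes `run_llm` and `run_torch_llm` accept an already-loaded tokenizer (and, for the torch path, a model) instead of rebuilding them on every call. A usage sketch, mirroring how the updated test invokes it; `model`, `tokenizer`, and `DEFAULT_PROMPT` are assumed to be set up by the caller, and the second positional argument is the HF auth token:

from turbine_models.custom_models import llm_runner

torch_str = llm_runner.run_torch_llm(
    "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
    None,  # hf_auth_token
    DEFAULT_PROMPT,
    streaming_llm=False,
    model=model,          # preloaded AutoModelForCausalLM, reused across calls
    tokenizer=tokenizer,  # preloaded AutoTokenizer, reused across calls
)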

models/turbine_models/custom_models/stateless_llama.py

Lines changed: 16 additions & 12 deletions
@@ -121,18 +121,21 @@ def export_transformer_model(
     streaming_llm=False,
     vmfb_path=None,
     upload_ir=False,
+    mod=None,
+    tokenizer=None,
 ):
-    tokenizer = AutoTokenizer.from_pretrained(
-        hf_model_name,
-        use_fast=False,
-        token=hf_auth_token,
-    )
-
-    mod = AutoModelForCausalLM.from_pretrained(
-        hf_model_name,
-        torch_dtype=torch.float,
-        token=hf_auth_token,
-    )
+    if tokenizer == None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            hf_model_name,
+            use_fast=False,
+            token=hf_auth_token,
+        )
+    if mod == None:
+        mod = AutoModelForCausalLM.from_pretrained(
+            hf_model_name,
+            torch_dtype=torch.float,
+            token=hf_auth_token,
+        )
     schema_json = generate_schema(mod.config.num_hidden_layers)
     state_schema = pytree.treespec_loads(schema_json)
     if streaming_llm:
@@ -165,7 +168,8 @@ def export_transformer_model(
         for name in mod_params:
             mapper["params." + name] = name
         if external_weight_file:
-            safetensors.torch.save_file(mod_params, external_weight_file)
+            if os.path.exists(external_weight_file) == False:
+                safetensors.torch.save_file(mod_params, external_weight_file)
 
     elif external_weights == "gguf":
         tensor_mapper = remap_gguf.TensorNameMap(remap_gguf.MODEL_ARCH.LLAMA, HEADS)
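
The export entry point can likewise reuse a caller-supplied model and tokenizer, and it now skips rewriting an external weight file that already exists. A sketch of the intended call, assuming `model` and `tokenizer` were loaded up front; keyword names other than `mod` and `tokenizer` are taken from the parameters visible in this diff and in the updated test, and other export options are omitted:

from turbine_models.custom_models import stateless_llama

stateless_llama.export_transformer_model(
    hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
    hf_auth_token=None,
    device="llvm-cpu",
    target_triple="host",
    mod=model,            # skips AutoModelForCausalLM.from_pretrained inside the export
    tokenizer=tokenizer,  # skips AutoTokenizer.from_pretrained inside the export
)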

models/turbine_models/model_builder.py

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ def __init__(
         model=None,
         model_type: str = None,
         compile_to_vmfb: bool = None,
+        tokenizer=None,
     ) -> None:
         self.example_input = example_input
         self.hf_id = hf_id
@@ -38,7 +39,7 @@ def __init__(
         self.auto_config = auto_config
         self.hf_auth_token = hf_auth_token
         self.model = model
-        self.tokenizer = None
+        self.tokenizer = tokenizer
         self.upload_ir = upload_ir
         self.model_type = model_type
         self.compile_to_vmfb = compile_to_vmfb
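
A small usage sketch of the new constructor argument, assuming the keyword names shown elsewhere in this commit (`example_input`, `hf_id`, and `auto_model` come from the builder call removed from `llm_runner.py`):

from transformers import AutoModelForCausalLM, AutoTokenizer
from turbine_models.model_builder import HFTransformerBuilder

tokenizer = AutoTokenizer.from_pretrained(
    "Trelis/Llama-2-7b-chat-hf-function-calling-v2", use_fast=False
)

# The builder now stores a caller-provided tokenizer instead of always
# initializing self.tokenizer to None.
builder = HFTransformerBuilder(
    example_input=None,
    hf_id="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
    auto_model=AutoModelForCausalLM,
    tokenizer=tokenizer,
)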

models/turbine_models/tests/stateless_llama_test.py

Lines changed: 60 additions & 13 deletions
@@ -9,6 +9,11 @@
 import os
 import unittest
 import difflib
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+import torch
+from accelerate import init_empty_weights
+from transformers.modeling_utils import load_sharded_checkpoint
+import tempfile
 
 os.environ["TORCH_LOGS"] = "dynamic"
 from shark_turbine.aot import *
@@ -18,18 +23,6 @@
     gen_external_params,
 )
 
-quantization = "unquantized"
-precision = "f32"
-gen_external_params(
-    hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
-    quantization=quantization,
-    hf_auth_token=None,
-    precision=precision,
-)
-DEFAULT_PROMPT = """<s>[INST] <<SYS>>
-Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
-"""
-
 
 def check_output_string(reference, output):
     # Calculate and print diff
@@ -43,7 +36,45 @@ def check_output_string(reference, output):
     assert reference == output, "".join(diff)
 
 
+quantization = "unquantized"
+precision = "f32"
+
+DEFAULT_PROMPT = """<s>[INST] <<SYS>>
+Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
+"""
+
+
 class StatelessLlamaChecks(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        gen_external_params(
+            hf_model_name="Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+            quantization=quantization,
+            hf_auth_token=None,
+            precision=precision,
+        )
+
+        cls.tokenizer = AutoTokenizer.from_pretrained(
+            "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+            use_fast=False,
+        )
+
+        # The model is first created on the Meta device (with empty weights) and the state dict
+        # is then loaded inside it (shard by shard in the case of a sharded checkpoint).
+        # This avoids using twice the size of model with creating whole model with random weights,
+        # then loading pretrained weights.
+        cls.mod = AutoModelForCausalLM.from_pretrained(
+            "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+            torch_dtype=torch.float,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.tokenizer = None
+        cls.mod = None
+
     def test_vmfb_comparison(self):
         """
         Test that the vmfb model produces the same output as the torch model
@@ -66,6 +97,8 @@ def test_vmfb_comparison(self):
             device="llvm-cpu",
             target_triple="host",
             upload_ir=upload_ir_var == "upload",
+            mod=self.mod,
+            tokenizer=self.tokenizer,
         )
 
         torch_str_cache_path = (
@@ -77,7 +110,11 @@ def test_vmfb_comparison(self):
             torch_str = f.read()
         else:
             torch_str = llm_runner.run_torch_llm(
-                "Trelis/Llama-2-7b-chat-hf-function-calling-v2", None, DEFAULT_PROMPT
+                "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
+                None,
+                self.DEFAULT_PROMPT,
+                model=self.mod,
+                tokenizer=self.tokenizer,
             )
 
         with open(torch_str_cache_path, "w") as f:
@@ -90,6 +127,7 @@ def test_vmfb_comparison(self):
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
             None,
             f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors",
+            tokenizer=self.tokenizer,
         )
         check_output_string(torch_str, turbine_str)
 
@@ -109,6 +147,8 @@ def test_streaming_vmfb_comparison(self):
             target_triple="host",
             streaming_llm=True,
             vmfb_path="streaming_llama.vmfb",
+            mod=self.mod,
+            tokenizer=self.tokenizer,
        )
 
         torch_str_cache_path = (
@@ -124,6 +164,8 @@ def test_streaming_vmfb_comparison(self):
             None,
             DEFAULT_PROMPT,
             streaming_llm=True,
+            model=self.mod,
+            tokenizer=self.tokenizer,
         )
 
         with open(torch_str_cache_path, "w") as f:
@@ -137,6 +179,7 @@ def test_streaming_vmfb_comparison(self):
             None,
             f"Llama_2_7b_chat_hf_function_calling_v2_{precision}_{quantization}.safetensors",
             streaming_llm=True,
+            tokenizer=self.tokenizer,
         )
         check_output_string(torch_str, turbine_str)
 
@@ -145,12 +188,16 @@ def test_rerotated_torch_comparison(self):
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
             None,
             DEFAULT_PROMPT,
+            model=self.mod,
+            tokenizer=self.tokenizer,
         )
         rotated_torch_str = llm_runner.run_torch_llm(
             "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
             None,
             DEFAULT_PROMPT,
             streaming_llm=True,
+            model=self.mod,
+            tokenizer=self.tokenizer,
         )
         check_output_string(torch_str, rotated_torch_str)
