Skip to content

Commit b785714

Browse files
authored
update tank and add validation distinction for uploaded names (#548)
This commit makes minor changes to the model testing and base models, so that there is a difference between validated and failing mlir names in the turbinetank.
1 parent 36a7091 commit b785714

File tree

9 files changed

+172
-121
lines changed

9 files changed

+172
-121
lines changed

models/turbine_models/custom_models/sd_inference/clip.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,16 @@ def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
138138
f.write(module_str)
139139
model_name_upload = hf_model_name.replace("/", "_")
140140
model_name_upload += "-clip"
141-
turbine_tank.uploadToBlobStorage(
141+
blob_name = turbine_tank.uploadToBlobStorage(
142142
str(os.path.abspath(f"{safe_name}.mlir")),
143143
f"{model_name_upload}/{model_name_upload}.mlir",
144144
)
145145
if compile_to != "vmfb":
146146
return module_str, tokenizer
147147
else:
148148
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
149+
if upload_ir:
150+
return blob_name
149151

150152

151153
if __name__ == "__main__":

models/turbine_models/custom_models/sd_inference/schedulers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,16 @@ def main(
153153
f.write(module_str)
154154
model_name_upload = hf_model_name.replace("/", "-")
155155
model_name_upload = model_name_upload + "_scheduler"
156-
turbine_tank.uploadToBlobStorage(
156+
blob_name = turbine_tank.uploadToBlobStorage(
157157
str(os.path.abspath(f"{safe_name}.mlir")),
158158
f"{model_name_upload}/{model_name_upload}.mlir",
159159
)
160160
if compile_to != "vmfb":
161161
return module_str
162162
else:
163163
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
164+
if upload_ir:
165+
return blob_name
164166

165167

166168
if __name__ == "__main__":

models/turbine_models/custom_models/sd_inference/unet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,16 @@ def main(
132132
f.write(module_str)
133133
model_name_upload = hf_model_name.replace("/", "-")
134134
model_name_upload += "_unet"
135-
turbine_tank.uploadToBlobStorage(
135+
blob_name = turbine_tank.uploadToBlobStorage(
136136
str(os.path.abspath(f"{safe_name}.mlir")),
137137
f"{model_name_upload}/{model_name_upload}.mlir",
138138
)
139139
if compile_to != "vmfb":
140140
return module_str
141141
else:
142142
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
143+
if upload_ir:
144+
return blob_name
143145

144146

145147
if __name__ == "__main__":

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
8383
with open(f"{safe_name}.vmfb", "wb+") as f:
8484
f.write(flatbuffer_blob)
8585
print("Saved to", safe_name + ".vmfb")
86-
exit()
8786

8887

8988
def create_safe_name(hf_model_name, model_name_str):

models/turbine_models/custom_models/sd_inference/vae.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,16 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)):
120120
f.write(module_str)
121121
model_name_upload = hf_model_name.replace("/", "_")
122122
model_name_upload = model_name_upload + "-vae-" + variant
123-
turbine_tank.uploadToBlobStorage(
123+
blob_name = turbine_tank.uploadToBlobStorage(
124124
str(os.path.abspath(f"{safe_name}.mlir")),
125125
f"{model_name_upload}/{model_name_upload}.mlir",
126126
)
127127
if compile_to != "vmfb":
128128
return module_str
129129
else:
130130
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
131+
if upload_ir:
132+
return blob_name
131133

132134

133135
if __name__ == "__main__":

models/turbine_models/custom_models/stateless_llama.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def evict_kvcache_space(self):
416416
with open(f"{safe_name}.mlir", "w+") as f:
417417
f.write(module_str)
418418
model_name_upload = hf_model_name.replace("/", "_")
419-
turbine_tank.uploadToBlobStorage(
419+
blob_name = turbine_tank.uploadToBlobStorage(
420420
str(os.path.abspath(f"{safe_name}.mlir")),
421421
f"{model_name_upload}/{model_name_upload}.mlir",
422422
)
@@ -478,6 +478,8 @@ def evict_kvcache_space(self):
478478
with open(vmfb_path, "wb+") as f:
479479
f.write(flatbuffer_blob)
480480
print("saved to ", safe_name + ".vmfb")
481+
if upload_ir:
482+
return blob_name
481483
return module_str, tokenizer
482484

483485

0 commit comments

Comments
 (0)