Skip to content

Commit 52edf6b

Browse files
committed
fix(build-download): support regular HF download not just cloud cache
1 parent 8248ba0 commit 52edf6b

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

api/app.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939

4040
torch.set_grad_enabled(False)
41+
always_normalize_model_id = None
4142

4243

4344
class DummySafetyChecker:
@@ -71,7 +72,6 @@ def init():
7172
last_model_id = None
7273

7374
if not RUNTIME_DOWNLOADS:
74-
# Uh doesn't this break non-cached images? TODO... IMAGE_CACHE
7575
normalized_model_id = normalize_model_id(MODEL_ID, MODEL_REVISION)
7676
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
7777
if os.path.isdir(model_dir):
@@ -80,7 +80,7 @@ def init():
8080
normalized_model_id = MODEL_ID
8181

8282
model = loadModel(
83-
model_id = model_dir,
83+
model_id=always_normalize_model_id or MODEL_ID,
8484
load=True,
8585
precision=MODEL_PRECISION,
8686
revision=MODEL_REVISION,
@@ -187,7 +187,9 @@ def inference(all_inputs: dict) -> dict:
187187
clearPipelines()
188188
if model:
189189
model.to("cpu") # Necessary to avoid a memory leak
190-
model = loadModel(model_id=normalized_model_id, load=True, precision=model_precision)
190+
model = loadModel(
191+
model_id=normalized_model_id, load=True, precision=model_precision
192+
)
191193
last_model_id = normalized_model_id
192194
else:
193195
if always_normalize_model_id:

0 commit comments

Comments
 (0)