Skip to content

Commit 9338648

Browse files
committed
fix(misc): fix failing tests, pipeline init in rare circumstances
1 parent 3f1f980 commit 9338648

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

api/app.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
import skimage
1717
import skimage.measure
1818
from getScheduler import getScheduler, SCHEDULERS
19-
from getPipeline import getPipelineForModel, listAvailablePipelines, clearPipelines
19+
from getPipeline import (
20+
getPipelineClass,
21+
getPipelineForModel,
22+
listAvailablePipelines,
23+
clearPipelines,
24+
)
2025
import re
2126
import requests
2227
from download import download_model, normalize_model_id
@@ -228,7 +233,7 @@ def sendStatus():
228233
model_dir = os.path.join(MODELS_DIR, normalized_model_id)
229234
pipeline_name = call_inputs.get("PIPELINE", None)
230235
if pipeline_name:
231-
pipeline_class = getattr(diffusers_pipelines, pipeline_name)
236+
pipeline_class = getPipelineClass(pipeline_name)
232237
if last_model_id != normalized_model_id:
233238
# if not downloaded_models.get(normalized_model_id, None):
234239
if not os.path.isdir(model_dir):
@@ -250,7 +255,7 @@ def sendStatus():
250255
hf_model_id=hf_model_id,
251256
model_precision=model_precision,
252257
send_opts=send_opts,
253-
pipeline_class=pipeline_class,
258+
pipeline_class=pipeline_class if pipeline_name else None,
254259
)
255260
# downloaded_models.update({normalized_model_id: True})
256261
clearPipelines()
@@ -267,7 +272,7 @@ def sendStatus():
267272
precision=model_precision,
268273
revision=model_revision,
269274
send_opts=send_opts,
270-
pipeline_class=pipeline_class,
275+
pipeline_class=pipeline_class if pipeline_name else None,
271276
)
272277
await send(
273278
"loadModel", "done", {"startRequestId": startRequestId}, send_opts

api/getPipeline.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ def clearPipelines():
4848
_pipelines = {}
4949

5050

51+
def getPipelineClass(pipeline_name: str):
52+
if hasattr(diffusers_pipelines, pipeline_name):
53+
return getattr(diffusers_pipelines, pipeline_name)
54+
elif pipeline_name in availableCommunityPipelines():
55+
return DiffusionPipeline
56+
57+
5158
def getPipelineForModel(
5259
pipeline_name: str, model, model_id, model_revision, model_precision
5360
):

api/loadModel.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def loadModel(
3131
precision=None,
3232
revision=None,
3333
send_opts={},
34-
pipeline_class=AutoPipelineForText2Image,
34+
pipeline_class=None,
3535
):
3636
torch_dtype = torch_dtype_from_precision(precision)
3737
if revision == "":
@@ -44,18 +44,23 @@ def loadModel(
4444
"load": load,
4545
"precision": precision,
4646
"revision": revision,
47+
"pipeline_class": pipeline_class,
4748
},
4849
)
50+
51+
if not pipeline_class:
52+
pipeline_class = AutoPipelineForText2Image
53+
54+
pipeline = pipeline_class if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
55+
print("pipeline", pipeline_class)
56+
4957
print(
5058
("Loading" if load else "Downloading")
5159
+ " model: "
5260
+ model_id
5361
+ (f" ({revision})" if revision else "")
5462
)
5563

56-
pipeline = pipeline_class if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
57-
print("pipeline", pipeline_class)
58-
5964
scheduler = getScheduler(model_id, DEFAULT_SCHEDULER, not load)
6065

6166
model_dir = os.path.join(MODELS_DIR, model_id)

0 commit comments

Comments
 (0)