Skip to content

Commit cd84e73

Browse files
committed
Use pagination for model inspection #2079 #1796
* inspect models in chunks of 8 * show some progress in the UI rather than just "Connecting..." * timeout for one page is 1min * total timeout is 3min * -> issues are reported faster, and large model collections have more time * (untested: should also support partial results)
1 parent 40e1849 commit cd84e73

File tree

5 files changed

+84
-44
lines changed

5 files changed

+84
-44
lines changed

ai_diffusion/client.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,13 @@ def deduce_from_filename(filename: str):
130130
class ClientModels:
131131
"""Collects names of AI models the client has access to."""
132132

133-
checkpoints: dict[str, CheckpointInfo]
134-
vae: list[str]
135-
loras: list[str]
136-
upscalers: list[str]
137-
resources: dict[str, str | None]
138-
node_inputs: ComfyObjectInfo
139-
140133
def __init__(self) -> None:
134+
self.checkpoints: dict[str, CheckpointInfo] = {}
135+
self.vae: list[str] = []
136+
self.loras: list[str] = []
137+
self.upscalers: list[str] = []
141138
self.node_inputs = ComfyObjectInfo({})
142-
self.resources = {}
139+
self.resources: dict[str, str | None] = {}
143140

144141
def resource(
145142
self, kind: ResourceKind, identifier: ControlMode | UpscalerName | str, arch: Arch
@@ -297,6 +294,14 @@ class Client(ABC):
297294
@abstractmethod
298295
async def connect(url: str, access_token: str = "") -> Client: ...
299296

297+
class DiscoverStatus(NamedTuple):
298+
folder: str
299+
current: int
300+
total: int
301+
302+
def discover_models(self, refresh: bool) -> AsyncGenerator[DiscoverStatus, Any]:
303+
raise NotImplementedError()
304+
300305
@abstractmethod
301306
async def enqueue(self, work: WorkflowInput, front: bool = False) -> str: ...
302307

ai_diffusion/comfy_client.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from enum import Enum
88
from collections import deque
99
from itertools import chain, product
10+
from time import time
1011
from typing import Any, Optional, Sequence
1112

1213
from .api import WorkflowInput
@@ -164,31 +165,49 @@ async def connect(url=default_url, access_token=""):
164165
loras = nodes.options("LoraLoader", "lora_name")
165166
available_resources.update(_find_loras(loras))
166167

167-
# Retrieve list of checkpoints
168-
checkpoints = await client.try_inspect("checkpoints")
169-
diffusion_models = await client.try_inspect("diffusion_models")
170-
diffusion_models.update(await client.try_inspect("unet_gguf"))
171-
client._refresh_models(nodes, checkpoints, diffusion_models)
172-
173-
# Check supported base models and make sure there is at least one
174-
client._supported_archs = {ver: client._check_workload(ver) for ver in Arch.list()}
175-
supported_workloads = [
176-
arch for arch, miss in client._supported_archs.items() if len(miss) == 0
177-
]
178-
log.info("Supported workloads: " + ", ".join(arch.value for arch in supported_workloads))
179-
if len(supported_workloads) == 0 and settings.check_server_resources:
180-
raise MissingResources(client._supported_archs)
181-
182168
# Workarounds for DirectML
183169
if client.device_info.type == "privateuseone":
184170
# OmniSR causes a crash
185171
for n in [2, 3, 4]:
186172
id = resource_id(ResourceKind.upscaler, Arch.all, UpscalerName.fast_x(n))
187173
available_resources[id] = models.default_upscaler
188174

189-
_ensure_supported_style(client)
190175
return client
191176

177+
async def discover_models(self, refresh: bool):
178+
if refresh:
179+
nodes = ComfyObjectInfo(await self._get("object_info"))
180+
else:
181+
nodes = self.models.node_inputs
182+
183+
checkpoints: dict[str, dict] = {}
184+
diffusion_models: dict[str, dict] = {}
185+
async for status, result in self.try_inspect("checkpoints"):
186+
yield status
187+
checkpoints.update(result)
188+
async for status, result in self.try_inspect("diffusion_models"):
189+
yield status
190+
diffusion_models.update(result)
191+
async for status, result in self.try_inspect("unet_gguf"):
192+
yield status
193+
diffusion_models.update(result)
194+
self._refresh_models(nodes, checkpoints, diffusion_models)
195+
196+
# Check supported base models and make sure there is at least one
197+
self._supported_archs = {ver: self._check_workload(ver) for ver in Arch.list()}
198+
supported_workloads = [
199+
arch for arch, miss in self._supported_archs.items() if len(miss) == 0
200+
]
201+
log.info("Supported workloads: " + ", ".join(arch.value for arch in supported_workloads))
202+
if not refresh and len(supported_workloads) == 0 and settings.check_server_resources:
203+
raise MissingResources(self._supported_archs)
204+
205+
_ensure_supported_style(self)
206+
207+
async def refresh(self):
208+
async for __ in self.discover_models(refresh=True):
209+
pass
210+
192211
async def _get(self, op: str, timeout: float | None = 60):
193212
return await self._requests.get(f"{self.url}/{op}", timeout=timeout)
194213

@@ -386,14 +405,26 @@ async def disconnect(self):
386405
self._unsubscribe_workflows(),
387406
)
388407

389-
async def try_inspect(self, folder_name: str) -> dict[str, Any]:
408+
async def try_inspect(self, folder_name: str):
390409
if "gguf" in folder_name and not self.features.gguf:
391-
return {}
410+
return
392411
try:
393-
return await self._get(f"api/etn/model_info/{folder_name}", timeout=120)
412+
log.info(f"Inspecting models at {self.url}/api/etn/model_info/{folder_name}")
413+
start, timeout = time(), 180
414+
offset, total = 0, 100
415+
while offset < total and (time() - start) < timeout:
416+
r = await self._get(f"api/etn/model_info/{folder_name}?offset={offset}&limit=8")
417+
if "_meta" not in r: # server doesn't support pagination
418+
yield (Client.DiscoverStatus(folder_name, len(r), len(r)), r)
419+
return
420+
total = r["_meta"]["total"]
421+
del r["_meta"]
422+
yield (Client.DiscoverStatus(folder_name, offset + len(r), total), r)
423+
offset += 8
424+
if offset < total:
425+
log.warning(f"Timeout while inspecting models, received {offset}/{total} entries")
394426
except NetworkError as e:
395427
log.error(f"Error while inspecting models in {folder_name}: {str(e)}")
396-
return {}
397428

398429
@property
399430
def queued_count(self):
@@ -403,16 +434,6 @@ def queued_count(self):
403434
def is_executing(self):
404435
return self._active is not None
405436

406-
async def refresh(self):
407-
nodes, checkpoints, diffusion_models, diffusion_gguf = await asyncio.gather(
408-
self._get("object_info"),
409-
self.try_inspect("checkpoints"),
410-
self.try_inspect("diffusion_models"),
411-
self.try_inspect("unet_gguf"),
412-
)
413-
diffusion_models.update(diffusion_gguf)
414-
self._refresh_models(ComfyObjectInfo(nodes), checkpoints, diffusion_models)
415-
416437
def _refresh_models(
417438
self, nodes: ComfyObjectInfo, checkpoints: dict | None, diffusion_models: dict | None
418439
):

ai_diffusion/connection.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@ class ConnectionState(Enum):
2020
connected = 2
2121
error = 3
2222

23-
auth_missing = 4
24-
auth_requesting = 5
25-
auth_pending = 6
26-
auth_error = 7
23+
discover_models = 10
24+
25+
auth_missing = 20
26+
auth_requesting = 21
27+
auth_pending = 22
28+
auth_error = 23
2729

2830

2931
class Connection(QObject, ObservableProperties):
3032
state = Property(ConnectionState.disconnected)
3133
error = Property("")
34+
progress = Property((1, 1))
3235

3336
state_changed = pyqtSignal(ConnectionState)
3437
error_changed = pyqtSignal(str)
38+
progress_changed = pyqtSignal(tuple)
3539
models_changed = pyqtSignal()
3640
message_received = pyqtSignal(ClientMessage)
3741
workflow_published = pyqtSignal(str)
@@ -90,6 +94,9 @@ async def _connect(self, url: str, mode: ServerMode, access_token=""):
9094
self._client = await CloudClient.connect(CloudClient.default_api_url, access_token)
9195
else:
9296
self._client = await ComfyClient.connect(url)
97+
self.state = ConnectionState.discover_models
98+
async for status in self._client.discover_models(refresh=False):
99+
self.progress = (status.current, status.total)
93100
self.missing_resources = self._client.missing_resources
94101

95102
apply_performance_preset(settings, self._client.device_info)

ai_diffusion/resources.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class CustomNode(NamedTuple):
4242
"External Tooling Nodes",
4343
"comfyui-tooling-nodes",
4444
"https://github.com/Acly/comfyui-tooling-nodes",
45-
"fbf99f2a08368211513bcc3a3ab5b8bc62a476b1",
45+
"a088a2dde2ca4a0db9193ac1c01ac77c8b31a70d",
4646
["ETN_LoadImageCache", "ETN_SaveImageCache", "ETN_Translate"],
4747
),
4848
CustomNode(

ai_diffusion/ui/settings.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def __init__(self, server: Server):
299299

300300
root.connection.state_changed.connect(self.update_server_status)
301301
root.connection.error_changed.connect(self.update_server_status)
302+
root.connection.progress_changed.connect(self.update_server_status)
302303
self.update_server_status()
303304

304305
@property
@@ -352,7 +353,7 @@ def _connect(self):
352353
def update_server_status(self):
353354
connection = root.connection
354355
self._cloud_widget.update_connection_state(connection.state)
355-
self._connect_button.setEnabled(connection.state != ConnectionState.connecting)
356+
self._connect_button.setEnabled(True)
356357
self._client_id.setVisible(False)
357358
if connection.state == ConnectionState.connected:
358359
self._connection_status.setText(_("Connected"))
@@ -362,6 +363,12 @@ def update_server_status(self):
362363
elif connection.state == ConnectionState.connecting:
363364
self._connection_status.setText(_("Connecting"))
364365
self._connection_status.setStyleSheet(f"color: {yellow}; font-weight:bold")
366+
self._connect_button.setEnabled(False)
367+
elif connection.state == ConnectionState.discover_models:
368+
progress = f" ({connection.progress[0]}/{connection.progress[1]})"
369+
self._connection_status.setText(_("Discovering models") + progress)
370+
self._connection_status.setStyleSheet(f"color: {yellow}; font-weight:bold")
371+
self._connect_button.setEnabled(False)
365372
elif connection.state == ConnectionState.disconnected:
366373
self._connection_status.setText(_("Disconnected"))
367374
self._connection_status.setStyleSheet(f"color: {grey}; font-style:italic")

0 commit comments

Comments
 (0)