 from enum import Enum
 from collections import deque
 from itertools import chain, product
+from time import time
 from typing import Any, Optional, Sequence

 from .api import WorkflowInput
@@ -164,31 +165,49 @@ async def connect(url=default_url, access_token=""):
         loras = nodes.options("LoraLoader", "lora_name")
         available_resources.update(_find_loras(loras))

-        # Retrieve list of checkpoints
-        checkpoints = await client.try_inspect("checkpoints")
-        diffusion_models = await client.try_inspect("diffusion_models")
-        diffusion_models.update(await client.try_inspect("unet_gguf"))
-        client._refresh_models(nodes, checkpoints, diffusion_models)
-
-        # Check supported base models and make sure there is at least one
-        client._supported_archs = {ver: client._check_workload(ver) for ver in Arch.list()}
-        supported_workloads = [
-            arch for arch, miss in client._supported_archs.items() if len(miss) == 0
-        ]
-        log.info("Supported workloads: " + ", ".join(arch.value for arch in supported_workloads))
-        if len(supported_workloads) == 0 and settings.check_server_resources:
-            raise MissingResources(client._supported_archs)
-
         # Workarounds for DirectML
         if client.device_info.type == "privateuseone":
            # OmniSR causes a crash
            for n in [2, 3, 4]:
                id = resource_id(ResourceKind.upscaler, Arch.all, UpscalerName.fast_x(n))
                available_resources[id] = models.default_upscaler

-        _ensure_supported_style(client)
         return client

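+    # Discovers installed models, yielding progress updates while the model lists
+    # are fetched from the server, then refreshes the cached model index.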
+    async def discover_models(self, refresh: bool):
+        if refresh:
+            nodes = ComfyObjectInfo(await self._get("object_info"))
+        else:
+            nodes = self.models.node_inputs
+
+        checkpoints: dict[str, dict] = {}
+        diffusion_models: dict[str, dict] = {}
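+        # Inspect each model folder in turn, forwarding progress updates to the caller.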
+        async for status, result in self.try_inspect("checkpoints"):
+            yield status
+            checkpoints.update(result)
+        async for status, result in self.try_inspect("diffusion_models"):
+            yield status
+            diffusion_models.update(result)
+        async for status, result in self.try_inspect("unet_gguf"):
+            yield status
+            diffusion_models.update(result)
+        self._refresh_models(nodes, checkpoints, diffusion_models)
+
+        # Check supported base models and make sure there is at least one
+        self._supported_archs = {ver: self._check_workload(ver) for ver in Arch.list()}
+        supported_workloads = [
+            arch for arch, miss in self._supported_archs.items() if len(miss) == 0
+        ]
+        log.info("Supported workloads: " + ", ".join(arch.value for arch in supported_workloads))
+        if not refresh and len(supported_workloads) == 0 and settings.check_server_resources:
+            raise MissingResources(self._supported_archs)
+
+        _ensure_supported_style(self)
+
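+    # Re-runs model discovery, discarding the progress updates.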
+    async def refresh(self):
+        async for __ in self.discover_models(refresh=True):
+            pass
+
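+    # Usage sketch (hypothetical caller): drive a progress indicator from the
+    # yielded status objects, e.g.
+    #   async for status in client.discover_models(refresh=True):
+    #       ...update progress UI from `status`...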
     async def _get(self, op: str, timeout: float | None = 60):
         return await self._requests.get(f"{self.url}/{op}", timeout=timeout)

@@ -386,14 +405,26 @@ async def disconnect(self):
             self._unsubscribe_workflows(),
         )

-    async def try_inspect(self, folder_name: str) -> dict[str, Any]:
+    async def try_inspect(self, folder_name: str):
         if "gguf" in folder_name and not self.features.gguf:
-            return {}
+            return
         try:
-            return await self._get(f"api/etn/model_info/{folder_name}", timeout=120)
+            log.info(f"Inspecting models at {self.url}/api/etn/model_info/{folder_name}")
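+            # Page through the folder's model metadata (8 entries per request)
+            # until all entries are received or the time budget runs out.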
+            start, timeout = time(), 180
+            offset, total = 0, 100
+            while offset < total and (time() - start) < timeout:
+                r = await self._get(f"api/etn/model_info/{folder_name}?offset={offset}&limit=8")
+                if "_meta" not in r:  # server doesn't support pagination
+                    yield (Client.DiscoverStatus(folder_name, len(r), len(r)), r)
+                    return
+                total = r["_meta"]["total"]
+                del r["_meta"]
+                yield (Client.DiscoverStatus(folder_name, offset + len(r), total), r)
+                offset += 8
+            if offset < total:
+                log.warning(f"Timeout while inspecting models, received {offset}/{total} entries")
         except NetworkError as e:
             log.error(f"Error while inspecting models in {folder_name}: {str(e)}")
-            return {}

     @property
     def queued_count(self):
@@ -403,16 +434,6 @@ def queued_count(self):
     def is_executing(self):
         return self._active is not None

-    async def refresh(self):
-        nodes, checkpoints, diffusion_models, diffusion_gguf = await asyncio.gather(
-            self._get("object_info"),
-            self.try_inspect("checkpoints"),
-            self.try_inspect("diffusion_models"),
-            self.try_inspect("unet_gguf"),
-        )
-        diffusion_models.update(diffusion_gguf)
-        self._refresh_models(ComfyObjectInfo(nodes), checkpoints, diffusion_models)
-
     def _refresh_models(
         self, nodes: ComfyObjectInfo, checkpoints: dict | None, diffusion_models: dict | None
     ):