Skip to content

Commit b6d7a44

Browse files
psychedelicioushipsterusername
authored andcommitted
refactor(events): include full model source in model install events
This is required to fix an issue with the MM UI's error handling. Previously, we only included the model source as a string. That could be an arbitrary URL, file path or HF repo id, but the frontend has no parsing logic to differentiate between these different model sources. Without access to the type of model source, it is difficult to determine how the user should proceed. For example, if it's HF URL with an HTTP unauthorized error, we should direct the user to log in to HF. But if it's a civitai URL with the same error, we should not direct the user to HF. There are a variety of related edge cases. With this change, the full `ModelSource` object is included in each model install event, including error events. I had to fix some circular import issues, hence the import changes to files other than `events_common.py`.
1 parent e18100a commit b6d7a44

File tree

5 files changed

+29
-22
lines changed

5 files changed

+29
-22
lines changed

invokeai/app/services/download/download_default.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import traceback
99
from pathlib import Path
1010
from queue import Empty, PriorityQueue
11-
from typing import Any, Dict, List, Literal, Optional, Set
11+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
1212

1313
import requests
1414
from pydantic.networks import AnyHttpUrl
@@ -28,11 +28,13 @@
2828
ServiceInactiveException,
2929
UnknownJobIDException,
3030
)
31-
from invokeai.app.services.events.events_base import EventServiceBase
3231
from invokeai.app.util.misc import get_iso_timestamp
3332
from invokeai.backend.model_manager.metadata import RemoteModelFile
3433
from invokeai.backend.util.logging import InvokeAILogger
3534

35+
if TYPE_CHECKING:
36+
from invokeai.app.services.events.events_base import EventServiceBase
37+
3638
# Maximum number of bytes to download during each call to requests.iter_content()
3739
DOWNLOAD_CHUNK_SIZE = 100000
3840

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
from .events_base import EventServiceBase # noqa F401

invokeai/app/services/events/events_common.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from fastapi_events.registry.payload_schema import registry as payload_schema
55
from pydantic import BaseModel, ConfigDict, Field
66

7+
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
78
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
89
from invokeai.app.services.session_queue.session_queue_common import (
910
QUEUE_ITEM_STATUS,
@@ -18,7 +19,7 @@
1819

1920
if TYPE_CHECKING:
2021
from invokeai.app.services.download.download_base import DownloadJob
21-
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
22+
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
2223

2324

2425
class EventBase(BaseModel):
@@ -422,7 +423,7 @@ class ModelInstallDownloadStartedEvent(ModelEventBase):
422423
__event_name__ = "model_install_download_started"
423424

424425
id: int = Field(description="The ID of the install job")
425-
source: str = Field(description="Source of the model; local path, repo_id or url")
426+
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
426427
local_path: str = Field(description="Where model is downloading to")
427428
bytes: int = Field(description="Number of bytes downloaded so far")
428429
total_bytes: int = Field(description="Total size of download, including all files")
@@ -443,7 +444,7 @@ def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
443444
]
444445
return cls(
445446
id=job.id,
446-
source=str(job.source),
447+
source=job.source,
447448
local_path=job.local_path.as_posix(),
448449
parts=parts,
449450
bytes=job.bytes,
@@ -458,7 +459,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
458459
__event_name__ = "model_install_download_progress"
459460

460461
id: int = Field(description="The ID of the install job")
461-
source: str = Field(description="Source of the model; local path, repo_id or url")
462+
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
462463
local_path: str = Field(description="Where model is downloading to")
463464
bytes: int = Field(description="Number of bytes downloaded so far")
464465
total_bytes: int = Field(description="Total size of download, including all files")
@@ -479,7 +480,7 @@ def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
479480
]
480481
return cls(
481482
id=job.id,
482-
source=str(job.source),
483+
source=job.source,
483484
local_path=job.local_path.as_posix(),
484485
parts=parts,
485486
bytes=job.bytes,
@@ -494,11 +495,11 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
494495
__event_name__ = "model_install_downloads_complete"
495496

496497
id: int = Field(description="The ID of the install job")
497-
source: str = Field(description="Source of the model; local path, repo_id or url")
498+
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
498499

499500
@classmethod
500501
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
501-
return cls(id=job.id, source=str(job.source))
502+
return cls(id=job.id, source=job.source)
502503

503504

504505
@payload_schema.register
@@ -508,11 +509,11 @@ class ModelInstallStartedEvent(ModelEventBase):
508509
__event_name__ = "model_install_started"
509510

510511
id: int = Field(description="The ID of the install job")
511-
source: str = Field(description="Source of the model; local path, repo_id or url")
512+
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
512513

513514
@classmethod
514515
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
515-
return cls(id=job.id, source=str(job.source))
516+
return cls(id=job.id, source=job.source)
516517

517518

518519
@payload_schema.register
@@ -522,14 +523,14 @@ class ModelInstallCompleteEvent(ModelEventBase):
522523
__event_name__ = "model_install_complete"
523524

524525
id: int = Field(description="The ID of the install job")
525-
source: str = Field(description="Source of the model; local path, repo_id or url")
526+
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
526527
key: str = Field(description="Model config record key")
527528
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
528529

529530
@classmethod
530531
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
531532
assert job.config_out is not None
532-
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
533+
return cls(id=job.id, source=job.source, key=(job.config_out.key), total_bytes=job.total_bytes)
533534

534535

535536
@payload_schema.register
@@ -539,11 +540,11 @@ class ModelInstallCancelledEvent(ModelEventBase):
539540
__event_name__ = "model_install_cancelled"
540541

541542
id: int = Field(description="The ID of the install job")
542-
source: str = Field(description="Source of the model; local path, repo_id or url")
543+
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
543544

544545
@classmethod
545546
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
546-
return cls(id=job.id, source=str(job.source))
547+
return cls(id=job.id, source=job.source)
547548

548549

549550
@payload_schema.register
@@ -553,15 +554,15 @@ class ModelInstallErrorEvent(ModelEventBase):
553554
__event_name__ = "model_install_error"
554555

555556
id: int = Field(description="The ID of the install job")
556-
source: str = Field(description="Source of the model; local path, repo_id or url")
557+
source: ModelSource = Field(description="Source of the model; local path, repo_id or url")
557558
error_type: str = Field(description="The name of the exception")
558559
error: str = Field(description="A text description of the exception")
559560

560561
@classmethod
561562
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
562563
assert job.error_type is not None
563564
assert job.error is not None
564-
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
565+
return cls(id=job.id, source=job.source, error_type=job.error_type, error=job.error)
565566

566567

567568
class BulkDownloadEventBase(EventBase):

invokeai/app/services/model_install/model_install_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@
33

44
from abc import ABC, abstractmethod
55
from pathlib import Path
6-
from typing import List, Optional, Union
6+
from typing import TYPE_CHECKING, List, Optional, Union
77

88
from pydantic.networks import AnyHttpUrl
99

1010
from invokeai.app.services.config import InvokeAIAppConfig
1111
from invokeai.app.services.download import DownloadQueueServiceBase
12-
from invokeai.app.services.events.events_base import EventServiceBase
1312
from invokeai.app.services.invoker import Invoker
1413
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
1514
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
1615
from invokeai.backend.model_manager import AnyModelConfig
1716

17+
if TYPE_CHECKING:
18+
from invokeai.app.services.events.events_base import EventServiceBase
19+
1820

1921
class ModelInstallServiceBase(ABC):
2022
"""Abstract base class for InvokeAI model installation."""

invokeai/app/services/model_install/model_install_default.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from queue import Empty, Queue
1010
from shutil import copyfile, copytree, move, rmtree
1111
from tempfile import mkdtemp
12-
from typing import Any, Dict, List, Optional, Tuple, Type, Union
12+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
1313

1414
import torch
1515
import yaml
@@ -20,7 +20,6 @@
2020

2121
from invokeai.app.services.config import InvokeAIAppConfig
2222
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
23-
from invokeai.app.services.events.events_base import EventServiceBase
2423
from invokeai.app.services.invoker import Invoker
2524
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
2625
from invokeai.app.services.model_install.model_install_common import (
@@ -57,6 +56,10 @@
5756
from invokeai.backend.util.devices import TorchDevice
5857
from invokeai.backend.util.util import slugify
5958

59+
if TYPE_CHECKING:
60+
from invokeai.app.services.events.events_base import EventServiceBase
61+
62+
6063
TMPDIR_PREFIX = "tmpinstall_"
6164

6265

0 commit comments

Comments
 (0)