4
4
from fastapi_events .registry .payload_schema import registry as payload_schema
5
5
from pydantic import BaseModel , ConfigDict , Field
6
6
7
+ from invokeai .app .services .model_install .model_install_common import ModelInstallJob , ModelSource
7
8
from invokeai .app .services .session_processor .session_processor_common import ProgressImage
8
9
from invokeai .app .services .session_queue .session_queue_common import (
9
10
QUEUE_ITEM_STATUS ,
18
19
19
20
if TYPE_CHECKING :
20
21
from invokeai .app .services .download .download_base import DownloadJob
21
- from invokeai .app .services .model_install .model_install_common import ModelInstallJob
22
+ from invokeai .app .services .model_install .model_install_common import ModelInstallJob , ModelSource
22
23
23
24
24
25
class EventBase (BaseModel ):
@@ -422,7 +423,7 @@ class ModelInstallDownloadStartedEvent(ModelEventBase):
422
423
__event_name__ = "model_install_download_started"
423
424
424
425
id : int = Field (description = "The ID of the install job" )
425
- source : str = Field (description = "Source of the model; local path, repo_id or url" )
426
+ source : ModelSource = Field (description = "Source of the model; local path, repo_id or url" )
426
427
local_path : str = Field (description = "Where model is downloading to" )
427
428
bytes : int = Field (description = "Number of bytes downloaded so far" )
428
429
total_bytes : int = Field (description = "Total size of download, including all files" )
@@ -443,7 +444,7 @@ def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
443
444
]
444
445
return cls (
445
446
id = job .id ,
446
- source = str ( job .source ) ,
447
+ source = job .source ,
447
448
local_path = job .local_path .as_posix (),
448
449
parts = parts ,
449
450
bytes = job .bytes ,
@@ -458,7 +459,7 @@ class ModelInstallDownloadProgressEvent(ModelEventBase):
458
459
__event_name__ = "model_install_download_progress"
459
460
460
461
id : int = Field (description = "The ID of the install job" )
461
- source : str = Field (description = "Source of the model; local path, repo_id or url" )
462
+ source : ModelSource = Field (description = "Source of the model; local path, repo_id or url" )
462
463
local_path : str = Field (description = "Where model is downloading to" )
463
464
bytes : int = Field (description = "Number of bytes downloaded so far" )
464
465
total_bytes : int = Field (description = "Total size of download, including all files" )
@@ -479,7 +480,7 @@ def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
479
480
]
480
481
return cls (
481
482
id = job .id ,
482
- source = str ( job .source ) ,
483
+ source = job .source ,
483
484
local_path = job .local_path .as_posix (),
484
485
parts = parts ,
485
486
bytes = job .bytes ,
@@ -494,11 +495,11 @@ class ModelInstallDownloadsCompleteEvent(ModelEventBase):
494
495
__event_name__ = "model_install_downloads_complete"
495
496
496
497
id : int = Field (description = "The ID of the install job" )
497
- source : str = Field (description = "Source of the model; local path, repo_id or url" )
498
+ source : ModelSource = Field (description = "Source of the model; local path, repo_id or url" )
498
499
499
500
@classmethod
500
501
def build (cls , job : "ModelInstallJob" ) -> "ModelInstallDownloadsCompleteEvent" :
501
- return cls (id = job .id , source = str ( job .source ) )
502
+ return cls (id = job .id , source = job .source )
502
503
503
504
504
505
@payload_schema .register
@@ -508,11 +509,11 @@ class ModelInstallStartedEvent(ModelEventBase):
508
509
__event_name__ = "model_install_started"
509
510
510
511
id : int = Field (description = "The ID of the install job" )
511
- source : str = Field (description = "Source of the model; local path, repo_id or url" )
512
+ source : ModelSource = Field (description = "Source of the model; local path, repo_id or url" )
512
513
513
514
@classmethod
514
515
def build (cls , job : "ModelInstallJob" ) -> "ModelInstallStartedEvent" :
515
- return cls (id = job .id , source = str ( job .source ) )
516
+ return cls (id = job .id , source = job .source )
516
517
517
518
518
519
@payload_schema .register
@@ -522,14 +523,14 @@ class ModelInstallCompleteEvent(ModelEventBase):
522
523
__event_name__ = "model_install_complete"
523
524
524
525
id : int = Field (description = "The ID of the install job" )
525
- source : str = Field (description = "Source of the model; local path, repo_id or url" )
526
+ source : ModelSource = Field (description = "Source of the model; local path, repo_id or url" )
526
527
key : str = Field (description = "Model config record key" )
527
528
total_bytes : Optional [int ] = Field (description = "Size of the model (may be None for installation of a local path)" )
528
529
529
530
@classmethod
530
531
def build (cls , job : "ModelInstallJob" ) -> "ModelInstallCompleteEvent" :
531
532
assert job .config_out is not None
532
- return cls (id = job .id , source = str ( job .source ) , key = (job .config_out .key ), total_bytes = job .total_bytes )
533
+ return cls (id = job .id , source = job .source , key = (job .config_out .key ), total_bytes = job .total_bytes )
533
534
534
535
535
536
@payload_schema .register
@@ -539,11 +540,11 @@ class ModelInstallCancelledEvent(ModelEventBase):
539
540
__event_name__ = "model_install_cancelled"
540
541
541
542
id : int = Field (description = "The ID of the install job" )
542
- source : str = Field (description = "Source of the model; local path, repo_id or url" )
543
+ source : ModelSource = Field (description = "Source of the model; local path, repo_id or url" )
543
544
544
545
@classmethod
545
546
def build (cls , job : "ModelInstallJob" ) -> "ModelInstallCancelledEvent" :
546
- return cls (id = job .id , source = str ( job .source ) )
547
+ return cls (id = job .id , source = job .source )
547
548
548
549
549
550
@payload_schema .register
@@ -553,15 +554,15 @@ class ModelInstallErrorEvent(ModelEventBase):
553
554
__event_name__ = "model_install_error"
554
555
555
556
id : int = Field (description = "The ID of the install job" )
556
- source : str = Field (description = "Source of the model; local path, repo_id or url" )
557
+ source : ModelSource = Field (description = "Source of the model; local path, repo_id or url" )
557
558
error_type : str = Field (description = "The name of the exception" )
558
559
error : str = Field (description = "A text description of the exception" )
559
560
560
561
@classmethod
561
562
def build (cls , job : "ModelInstallJob" ) -> "ModelInstallErrorEvent" :
562
563
assert job .error_type is not None
563
564
assert job .error is not None
564
- return cls (id = job .id , source = str ( job .source ) , error_type = job .error_type , error = job .error )
565
+ return cls (id = job .id , source = job .source , error_type = job .error_type , error = job .error )
565
566
566
567
567
568
class BulkDownloadEventBase (EventBase ):
0 commit comments