Skip to content

Commit 6172c85

Browse files
feat(api): enrich starer model bundle metadata
1 parent b26fb1f commit 6172c85

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

invokeai/app/api/routers/model_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
STARTER_BUNDLES,
4242
STARTER_MODELS,
4343
StarterModel,
44+
StarterModelBundle,
4445
StarterModelWithoutDependencies,
4546
)
4647

@@ -799,7 +800,7 @@ async def convert_model(
799800

800801
class StarterModelResponse(BaseModel):
801802
starter_models: list[StarterModel]
802-
starter_bundles: dict[str, list[StarterModel]]
803+
starter_bundles: dict[str, StarterModelBundle]
803804

804805

805806
def get_is_installed(
@@ -833,7 +834,7 @@ async def get_starter_models() -> StarterModelResponse:
833834
model.dependencies = missing_deps
834835

835836
for bundle in starter_bundles.values():
836-
for model in bundle:
837+
for model in bundle.models:
837838
model.is_installed = get_is_installed(model, installed_models)
838839
# Remove already-installed dependencies
839840
missing_deps: list[StarterModelWithoutDependencies] = []

invokeai/backend/model_manager/starter_models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class StarterModel(StarterModelWithoutDependencies):
2323
dependencies: Optional[list[StarterModelWithoutDependencies]] = None
2424

2525

26-
class StarterModelBundles(BaseModel):
26+
class StarterModelBundle(BaseModel):
2727
name: str
2828
models: list[StarterModel]
2929

@@ -778,10 +778,10 @@ class StarterModelBundles(BaseModel):
778778
flux_fill,
779779
]
780780

781-
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
782-
BaseModelType.StableDiffusion1: sd1_bundle,
783-
BaseModelType.StableDiffusionXL: sdxl_bundle,
784-
BaseModelType.Flux: flux_bundle,
781+
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
782+
BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle),
783+
BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle),
784+
BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
785785
}
786786

787787
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

0 commit comments

Comments
 (0)