|
| 1 | +#!/bin/env python |
| 2 | + |
| 3 | +from argparse import ArgumentParser, Namespace |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +from invokeai.app.services.config import InvokeAIAppConfig, get_config |
| 7 | +from invokeai.app.services.download import DownloadQueueService |
| 8 | +from invokeai.app.services.model_install import ModelInstallService |
| 9 | +from invokeai.app.services.model_records import ModelRecordServiceSQL |
| 10 | +from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase |
| 11 | +from invokeai.backend.util.logging import InvokeAILogger |
| 12 | + |
| 13 | + |
| 14 | +def get_args() -> Namespace: |
| 15 | + parser = ArgumentParser(description="Update models database from yaml file") |
| 16 | + parser.add_argument("--root", type=Path, required=False, default=None) |
| 17 | + parser.add_argument("--yaml_file", type=Path, required=False, default=None) |
| 18 | + return parser.parse_args() |
| 19 | + |
| 20 | + |
| 21 | +def populate_config() -> InvokeAIAppConfig: |
| 22 | + args = get_args() |
| 23 | + config = get_config() |
| 24 | + if args.root: |
| 25 | + config._root = args.root |
| 26 | + if args.yaml_file: |
| 27 | + config.legacy_models_yaml_path = args.yaml_file |
| 28 | + else: |
| 29 | + config.legacy_models_yaml_path = config.root_path / "configs/models.yaml" |
| 30 | + return config |
| 31 | + |
| 32 | + |
| 33 | +def initialize_installer(config: InvokeAIAppConfig) -> ModelInstallService: |
| 34 | + logger = InvokeAILogger.get_logger(config=config) |
| 35 | + db = SqliteDatabase(config.db_path, logger) |
| 36 | + record_store = ModelRecordServiceSQL(db) |
| 37 | + queue = DownloadQueueService() |
| 38 | + queue.start() |
| 39 | + installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=queue) |
| 40 | + return installer |
| 41 | + |
| 42 | + |
| 43 | +def main() -> None: |
| 44 | + config = populate_config() |
| 45 | + installer = initialize_installer(config) |
| 46 | + installer._migrate_yaml(rename_yaml=False, overwrite_db=True) |
| 47 | + print("\n<INSTALLED MODELS>") |
| 48 | + print("\t".join(["key", "name", "type", "path"])) |
| 49 | + for model in installer.record_store.all_models(): |
| 50 | + print("\t".join([model.key, model.name, model.type, (config.models_path / model.path).as_posix()])) |
| 51 | + |
| 52 | + |
| 53 | +if __name__ == "__main__": |
| 54 | + main() |
0 commit comments