|
 1 |  1 | import json
 2 |  2 | from collections import defaultdict
   |  3 | +from dataclasses import dataclass, field
 3 |  4 | from typing import Any
 4 |  5 |
 5 |  6 | import torch

 8 |  9 | from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
 9 | 10 | from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
10 | 11 | from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
   | 12 | +from invokeai.backend.util import InvokeAILogger
11 | 13 |
12 | 14 |
13 |    | -def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
   | 15 | +def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] | None = None) -> bool:
14 | 16 |     if metadata:
15 |    | -        software = json.loads(metadata.get("software", "{}"))
   | 17 | +        try:
   | 18 | +            software = json.loads(metadata.get("software", "{}"))
   | 19 | +        except json.JSONDecodeError:
   | 20 | +            return False
16 | 21 |         return software.get("name") == "ai-toolkit"
17 | 22 |     # metadata got lost somewhere
18 | 23 |     return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
19 | 24 |
20 | 25 |
21 |    | -def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
22 |    | -    # Group keys by layer.
23 |    | -    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
   | 26 | +@dataclass
   | 27 | +class GroupedStateDict:
   | 28 | +    transformer: dict = field(default_factory=dict)
   | 29 | +    # might also grow CLIP and T5 submodels
   | 30 | +
   | 31 | +
   | 32 | +def _group_state_by_submodel(state_dict: dict[str, torch.Tensor]) -> GroupedStateDict:
   | 33 | +    logger = InvokeAILogger.get_logger()
   | 34 | +    grouped = GroupedStateDict()
24 | 35 |     for key, value in state_dict.items():
25 |    | -        layer_name, param_name = key.split(".", 1)
26 |    | -        grouped_state_dict[layer_name][param_name] = value
   | 36 | +        submodel_name, param_name = key.split(".", 1)
   | 37 | +        match submodel_name:
   | 38 | +            case "diffusion_model":
   | 39 | +                grouped.transformer[param_name] = value
   | 40 | +            case _:
   | 41 | +                logger.warning(f"Unexpected submodel name: {submodel_name}")
   | 42 | +    return grouped
27 | 43 |
28 |    | -    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
29 | 44 |
30 |    | -    for layer_name, layer_state_dict in grouped_state_dict.items():
31 |    | -        if layer_name.startswith("diffusion_model"):
32 |    | -            transformer_grouped_sd[layer_name] = layer_state_dict
33 |    | -        else:
34 |    | -            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
   | 45 | +def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
   | 46 | +    grouped = _group_state_by_submodel(state_dict)
35 | 47 |
36 | 48 |     layers: dict[str, BaseLayerPatch] = {}
37 |    | -    for layer_key, layer_state_dict in transformer_grouped_sd.items():
   | 49 | +    for layer_key, layer_state_dict in grouped.transformer.items():
38 | 50 |         layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
39 | 51 |
40 | 52 |     return ModelPatchRaw(layers=layers)
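
Below is a minimal usage sketch showing how the two helpers fit together when loading a LoRA from disk. The file path is hypothetical, and the import path for the conversion utilities is inferred from the module layout above, so treat both as assumptions; safetensors is used here because ai-toolkit LoRAs are typically distributed as .safetensors files whose header carries the "software" metadata entry.

    # Hypothetical usage sketch: the file path and the conversion-utils module path are assumptions.
    from safetensors import safe_open
    from safetensors.torch import load_file

    from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
        is_state_dict_likely_in_flux_aitoolkit_format,
        lora_model_from_flux_aitoolkit_state_dict,
    )

    lora_path = "flux_aitoolkit_lora.safetensors"  # hypothetical file

    # The safetensors header stores user metadata, including ai-toolkit's "software" entry.
    with safe_open(lora_path, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}

    state_dict = load_file(lora_path, device="cpu")

    if is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
        # Convert the ai-toolkit keys into InvokeAI's ModelPatchRaw representation.
        model_patch = lora_model_from_flux_aitoolkit_state_dict(state_dict)

Detection prefers the header metadata and falls back to inspecting the "diffusion_model." key prefix, so files whose metadata was stripped are still recognized.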
|