Skip to content

Commit 2981591

Browse files
Kevin Turner authored and psychedelicious committed
test: add some aitoolkit lora tests
1 parent b08f90c commit 2981591

File tree

5 files changed

+548
-18
lines changed

5 files changed

+548
-18
lines changed

invokeai/backend/model_manager/load/model_loaders/lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
SubModelType,
2222
)
2323
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
24-
is_state_dict_likely_in_aitoolkit_format,
24+
is_state_dict_likely_in_flux_aitoolkit_format,
2525
lora_model_from_flux_aitoolkit_state_dict,
2626
)
2727
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
@@ -96,7 +96,7 @@ def _load_model(
9696
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
9797
elif is_state_dict_likely_flux_control(state_dict=state_dict):
9898
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
99-
elif is_state_dict_likely_in_aitoolkit_format(state_dict=state_dict):
99+
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
100100
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
101101
else:
102102
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from collections import defaultdict
3+
from dataclasses import dataclass, field
34
from typing import Any
45

56
import torch
@@ -8,33 +9,44 @@
89
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
910
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
1011
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
12+
from invokeai.backend.util import InvokeAILogger
1113

1214

13-
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
15+
def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
1416
if metadata:
15-
software = json.loads(metadata.get("software", "{}"))
17+
try:
18+
software = json.loads(metadata.get("software", "{}"))
19+
except json.JSONDecodeError:
20+
return False
1621
return software.get("name") == "ai-toolkit"
1722
# metadata got lost somewhere
1823
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
1924

2025

21-
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
22-
# Group keys by layer.
23-
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
@dataclass
class GroupedStateDict:
    """Holds state-dict entries partitioned by the submodel they target.

    Currently only the FLUX transformer is populated (entries whose keys
    start with the "diffusion_model." prefix, with that prefix stripped).
    """

    transformer: dict = field(default_factory=dict)
    # might also grow CLIP and T5 submodels
30+
31+
def _group_state_by_submodel(state_dict: dict[str, torch.Tensor]) -> GroupedStateDict:
    """Split a flat state dict by its first key segment (the submodel name).

    Keys prefixed with "diffusion_model." land in ``GroupedStateDict.transformer``
    (prefix stripped); any other prefix is logged as a warning and dropped.
    """
    log = InvokeAILogger.get_logger()
    result = GroupedStateDict()
    for full_key, tensor in state_dict.items():
        prefix, remainder = full_key.split(".", 1)
        if prefix == "diffusion_model":
            result.transformer[remainder] = tensor
        else:
            log.warning(f"Unexpected submodel name: {prefix}")
    return result
2743

28-
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
2944

30-
for layer_name, layer_state_dict in grouped_state_dict.items():
31-
if layer_name.startswith("diffusion_model"):
32-
transformer_grouped_sd[layer_name] = layer_state_dict
33-
else:
34-
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Convert an ai-toolkit-format FLUX LoRA state dict into a ModelPatchRaw."""
    grouped = _group_state_by_submodel(state_dict)

    # Re-key each transformer entry under the FLUX transformer prefix so the
    # patcher can route it to the right submodel.
    layers: dict[str, BaseLayerPatch] = {
        FLUX_LORA_TRANSFORMER_PREFIX + name: any_lora_layer_from_state_dict(sub_sd)
        for name, sub_sd in grouped.transformer.items()
    }

    return ModelPatchRaw(layers=layers)

invokeai/backend/patches/lora_conversions/formats.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
22
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
3-
is_state_dict_likely_in_aitoolkit_format,
3+
is_state_dict_likely_in_flux_aitoolkit_format,
44
)
55
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
66
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
@@ -15,7 +15,7 @@
1515

1616

1717
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
18-
if is_state_dict_likely_in_aitoolkit_format(state_dict, metadata):
18+
if is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
1919
return FluxLoRAFormat.AIToolkit
2020
if is_state_dict_likely_in_flux_kohya_format(state_dict):
2121
return FluxLoRAFormat.Kohya

0 commit comments

Comments
 (0)