Skip to content

Commit a214f4f

Browse files
Kevin Turner and psychedelicious
authored and committed
fix: group aitoolkit lora layers
1 parent 2981591 commit a214f4f

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import json
2-
from collections import defaultdict
32
from dataclasses import dataclass, field
43
from typing import Any
54

65
import torch
76

87
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
98
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
9+
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import _group_by_layer
1010
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
1111
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
1212
from invokeai.backend.util import InvokeAILogger
@@ -25,11 +25,11 @@ def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], me
2525

2626
@dataclass
2727
class GroupedStateDict:
28-
transformer: dict = field(default_factory=dict)
28+
transformer: dict[str, Any] = field(default_factory=dict)
2929
# might also grow CLIP and T5 submodels
3030

3131

32-
def _group_state_by_submodel(state_dict: dict[str, torch.Tensor]) -> GroupedStateDict:
32+
def _group_state_by_submodel(state_dict: dict[str, Any]) -> GroupedStateDict:
3333
logger = InvokeAILogger.get_logger()
3434
grouped = GroupedStateDict()
3535
for key, value in state_dict.items():
@@ -42,11 +42,22 @@ def _group_state_by_submodel(state_dict: dict[str, torch.Tensor]) -> GroupedStat
4242
return grouped
4343

4444

45+
def _rename_peft_lora_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
46+
"""Renames keys from the PEFT LoRA format to the InvokeAI format."""
47+
renamed_state_dict = {}
48+
for key, value in state_dict.items():
49+
renamed_key = key.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up.")
50+
renamed_state_dict[renamed_key] = value
51+
return renamed_state_dict
52+
53+
4554
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
46-
grouped = _group_state_by_submodel(state_dict)
55+
state_dict = _rename_peft_lora_keys(state_dict)
56+
by_layer = _group_by_layer(state_dict)
57+
by_model = _group_state_by_submodel(by_layer)
4758

4859
layers: dict[str, BaseLayerPatch] = {}
49-
for layer_key, layer_state_dict in grouped.transformer.items():
60+
for layer_key, layer_state_dict in by_model.transformer.items():
5061
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
5162

5263
return ModelPatchRaw(layers=layers)

tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
1212
state_dict_keys as flux_onetrainer_state_dict_keys,
1313
)
14+
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_aitoolkit_format import (
15+
state_dict_keys as flux_aitoolkit_state_dict_keys,
16+
)
1417
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
1518
state_dict_keys as flux_diffusers_state_dict_keys,
1619
)
17-
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_aitoolkit_format import state_dict_keys as flux_aitoolkit_state_dict_keys
1820
from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict
1921

2022

@@ -46,15 +48,12 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
4648
model_keys = set(model.state_dict().keys())
4749

4850
for converted_key_prefix in converted_key_prefixes:
49-
assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), f"'{converted_key_prefix}' did not match any model keys."
51+
assert any(
52+
model_key.startswith(converted_key_prefix) for model_key in model_keys
53+
), f"'{converted_key_prefix}' did not match any model keys."
5054

5155

5256
def test_lora_model_from_flux_aitoolkit_state_dict():
5357
state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
5458

55-
lora_model = lora_model_from_flux_aitoolkit_state_dict(state_dict)
56-
57-
# Assert that the lora_model has the expected layers.
58-
# lora_model_keys = set(lora_model.layers.keys())
59-
# lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
60-
# assert lora_model_keys == expected_layer_keys
59+
assert lora_model_from_flux_aitoolkit_state_dict(state_dict)

0 commit comments

Comments (0)