1
1
import json
2
- from collections import defaultdict
3
2
from dataclasses import dataclass , field
4
3
from typing import Any
5
4
6
5
import torch
7
6
8
7
from invokeai .backend .patches .layers .base_layer_patch import BaseLayerPatch
9
8
from invokeai .backend .patches .layers .utils import any_lora_layer_from_state_dict
9
+ from invokeai .backend .patches .lora_conversions .flux_diffusers_lora_conversion_utils import _group_by_layer
10
10
from invokeai .backend .patches .lora_conversions .flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
11
11
from invokeai .backend .patches .model_patch_raw import ModelPatchRaw
12
12
from invokeai .backend .util import InvokeAILogger
@@ -25,11 +25,11 @@ def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], me
25
25
26
26
@dataclass
class GroupedStateDict:
    """State-dict entries partitioned by the submodel they target.

    Currently only the FLUX transformer is populated; see the note below.
    """

    # Maps layer key -> that layer's state dict entries for the transformer submodel.
    transformer: dict[str, Any] = field(default_factory=dict)
    # might also grow CLIP and T5 submodels
30
30
31
31
32
- def _group_state_by_submodel (state_dict : dict [str , torch . Tensor ]) -> GroupedStateDict :
32
+ def _group_state_by_submodel (state_dict : dict [str , Any ]) -> GroupedStateDict :
33
33
logger = InvokeAILogger .get_logger ()
34
34
grouped = GroupedStateDict ()
35
35
for key , value in state_dict .items ():
@@ -42,11 +42,22 @@ def _group_state_by_submodel(state_dict: dict[str, torch.Tensor]) -> GroupedStat
42
42
return grouped
43
43
44
44
45
+ def _rename_peft_lora_keys (state_dict : dict [str , torch .Tensor ]) -> dict [str , torch .Tensor ]:
46
+ """Renames keys from the PEFT LoRA format to the InvokeAI format."""
47
+ renamed_state_dict = {}
48
+ for key , value in state_dict .items ():
49
+ renamed_key = key .replace (".lora_A." , ".lora_down." ).replace (".lora_B." , ".lora_up." )
50
+ renamed_state_dict [renamed_key ] = value
51
+ return renamed_state_dict
52
+
53
+
45
54
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Convert an AI-Toolkit FLUX LoRA state dict into a ModelPatchRaw.

    Pipeline: rename PEFT-style keys -> group entries by layer -> group layers
    by submodel, then wrap each transformer layer's state dict in a LoRA layer
    patch keyed under the FLUX transformer prefix.
    """
    grouped = _group_state_by_submodel(_group_by_layer(_rename_peft_lora_keys(state_dict)))

    layers: dict[str, BaseLayerPatch] = {
        FLUX_LORA_TRANSFORMER_PREFIX + layer_key: any_lora_layer_from_state_dict(layer_state_dict)
        for layer_key, layer_state_dict in grouped.transformer.items()
    }

    return ModelPatchRaw(layers=layers)
0 commit comments