Skip to content

Commit 5c5108c

Browse files
Kevin Turner authored and
psychedelicious committed
feat(LoRA): support AI Toolkit LoRA for FLUX [WIP]
1 parent 3df7cfd commit 5c5108c

File tree

4 files changed

+49
-2
lines changed

4 files changed

+49
-2
lines changed

invokeai/backend/model_manager/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def flux_lora_format(cls, mod: ModelOnDisk):
296296
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
297297

298298
sd = mod.load_state_dict(mod.path)
299-
value = flux_format_from_state_dict(sd)
299+
value = flux_format_from_state_dict(sd, mod.metadata())
300300
mod.cache[key] = value
301301
return value
302302

invokeai/backend/model_manager/taxonomy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class FluxLoRAFormat(str, Enum):
137137
Kohya = "flux.kohya"
138138
OneTrainer = "flux.onetrainer"
139139
Control = "flux.control"
140+
AIToolkit = "flux.aitoolkit"
140141

141142

142143
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import json
2+
from collections import defaultdict
3+
from typing import Any
4+
5+
import torch
6+
7+
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
8+
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
9+
lora_layers_from_flux_diffusers_grouped_state_dict,
10+
)
11+
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
12+
13+
14+
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any]) -> bool:
15+
if metadata:
16+
software = json.loads(metadata.get("software", "{}"))
17+
return software.get("name") == "ai-toolkit"
18+
# metadata got lost somewhere
19+
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
20+
21+
22+
def lora_model_from_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Build a ModelPatchRaw from an AI Toolkit FLUX LoRA state dict.

    Args:
        state_dict: Flat mapping of dotted parameter names to tensors.

    Returns:
        A ModelPatchRaw wrapping the converted LoRA layers.

    Raises:
        ValueError: If any key does not belong to the FLUX transformer
            (i.e. its first dotted component is not "diffusion_model").
    """
    # Bucket every parameter by the text before the first ".".
    # NOTE(review): since every AI Toolkit key starts with "diffusion_model.",
    # this collapses the entire state dict into a single group — confirm that
    # lora_layers_from_flux_diffusers_grouped_state_dict expects this
    # granularity rather than per-layer groups.
    grouped: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
    for full_key, tensor in state_dict.items():
        prefix, remainder = full_key.split(".", 1)
        grouped[prefix][remainder] = tensor

    # Keep only transformer groups; anything else means this is not a FLUX LoRA.
    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
    for prefix, params in grouped.items():
        if not prefix.startswith("diffusion_model"):
            raise ValueError(f"Layer '{prefix}' does not match the expected pattern for FLUX LoRA weights.")
        transformer_grouped_sd[prefix] = params

    return ModelPatchRaw(
        layers=lora_layers_from_flux_diffusers_grouped_state_dict(transformer_grouped_sd, alpha=None)
    )

invokeai/backend/patches/lora_conversions/formats.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
2+
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
3+
is_state_dict_likely_in_aitoolkit_format,
4+
)
25
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
36
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
47
is_state_dict_likely_in_flux_diffusers_format,
@@ -11,7 +14,9 @@
1114
)
1215

1316

14-
def flux_format_from_state_dict(state_dict):
17+
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
18+
if is_state_dict_likely_in_aitoolkit_format(state_dict, metadata):
19+
return FluxLoRAFormat.AIToolkit
1520
if is_state_dict_likely_in_flux_kohya_format(state_dict):
1621
return FluxLoRAFormat.Kohya
1722
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):

0 commit comments

Comments
 (0)