Skip to content

Commit ab8c739

Browse files
Kevin Turner authored and psychedelicious committed
fix(LoRA): add ai-toolkit to lora loader
1 parent 5c5108c commit ab8c739

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

invokeai/backend/model_manager/load/model_loaders/lora.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
ModelType,
2121
SubModelType,
2222
)
23+
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
24+
is_state_dict_likely_in_aitoolkit_format,
25+
lora_model_from_flux_aitoolkit_state_dict,
26+
)
2327
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
2428
is_state_dict_likely_flux_control,
2529
lora_model_from_flux_control_state_dict,
@@ -92,6 +96,8 @@ def _load_model(
9296
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
9397
elif is_state_dict_likely_flux_control(state_dict=state_dict):
9498
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
99+
elif is_state_dict_likely_in_aitoolkit_format(state_dict=state_dict):
100+
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
95101
else:
96102
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
97103
else:

invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
1212

1313

14-
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any]) -> bool:
14+
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
1515
if metadata:
1616
software = json.loads(metadata.get("software", "{}"))
1717
return software.get("name") == "ai-toolkit"
1818
# metadata got lost somewhere
1919
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
2020

2121

22-
def lora_model_from_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
22+
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
2323
# Group keys by layer.
2424
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
2525
for key, value in state_dict.items():

0 commit comments

Comments (0)