
Commit 6afbf31

jazzhaiku authored and psychedelicious committed
Ruff formatting
1 parent 3cd4306 commit 6afbf31

File tree

9 files changed: +125 −97 lines
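The changes below are mechanical: isort-style import sorting, single-to-double quote normalization, and wrapping of over-long lines. As a rough sketch of how a pass like this is typically produced (the exact paths, rule selection, and Ruff configuration are assumptions, not part of this commit):

# Hypothetical reproduction of a Ruff formatting pass (illustrative only).
import subprocess

# "ruff format" normalizes quotes, wraps long lines, and re-indents arguments.
subprocess.run(["ruff", "format", "invokeai/"], check=True)
# "ruff check --select I --fix" applies isort-style import sorting.
subprocess.run(["ruff", "check", "--select", "I", "--fix", "invokeai/"], check=True)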

invokeai/backend/model_manager/config.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@
 from invokeai.backend.model_hash.hash_validator import validate_hash
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
-from invokeai.backend.model_manager.omi import stable_diffusion_xl_1_lora, flux_dev_1_lora
+from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
 from invokeai.backend.model_manager.taxonomy import (
     AnyVariant,
     BaseModelType,
invokeai/backend/model_manager/omi/__init__.py

Lines changed: 3 additions & 2 deletions

@@ -1,6 +1,7 @@
+from invokeai.backend.model_manager.omi.omi import convert_from_omi
 from invokeai.backend.model_manager.omi.vendor.model_spec.architecture import (
+    flux_dev_1_lora,
     stable_diffusion_xl_1_lora,
-    flux_dev_1_lora
 )
 
-from invokeai.backend.model_manager.omi.omi import convert_from_omi
+__all__ = ["flux_dev_1_lora", "stable_diffusion_xl_1_lora", "convert_from_omi"]

invokeai/backend/model_manager/omi/omi.py

Lines changed: 6 additions & 3 deletions

@@ -1,10 +1,13 @@
+from invokeai.backend.model_manager.model_on_disk import StateDict
 from invokeai.backend.model_manager.omi.vendor.convert.lora import (
-    convert_lora_util as lora_util,
     convert_flux_lora as omi_flux,
+)
+from invokeai.backend.model_manager.omi.vendor.convert.lora import (
+    convert_lora_util as lora_util,
+)
+from invokeai.backend.model_manager.omi.vendor.convert.lora import (
     convert_sdxl_lora as omi_sdxl,
 )
-
-from invokeai.backend.model_manager.model_on_disk import StateDict
 from invokeai.backend.model_manager.taxonomy import BaseModelType
 
 
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_clip.py

Lines changed: 4 additions & 1 deletion

@@ -1,4 +1,7 @@
-from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
+from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
+    LoraConversionKeySet,
+    map_prefix_range,
+)
 
 
 def map_clip(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
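For context, map_prefix_range imported above expands an indexed prefix family into one key set per layer index. A toy illustration under the names visible in this commit (the prefix strings here are hypothetical, not taken from map_clip):

# Sketch: map_prefix_range (defined in convert_lora_util.py below) emits one
# LoraConversionKeySet per index 0..99, with next_*_prefix pointing at i + 1
# so "is this the last layer?" checks can work. Prefixes are hypothetical.
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
    LoraConversionKeySet,
    map_prefix_range,
)

parent = LoraConversionKeySet("clip_l", "lora_te1")
layer_keys = map_prefix_range("text_model.encoder.layers", "text_model.encoder.layers", parent=parent)
assert len(layer_keys) == 100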

invokeai/backend/model_manager/omi/vendor/convert/lora/convert_flux_lora.py

Lines changed: 13 additions & 4 deletions

@@ -1,5 +1,8 @@
 from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
-from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
+from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
+    LoraConversionKeySet,
+    map_prefix_range,
+)
 from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_t5 import map_t5
 
 

@@ -45,10 +48,16 @@ def __map_transformer(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
     keys = []
 
     keys += [LoraConversionKeySet("txt_in", "context_embedder", parent=key_prefix)]
-    keys += [LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)]
+    keys += [
+        LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)
+    ]
     keys += [LoraConversionKeySet("final_layer.linear", "proj_out", parent=key_prefix)]
-    keys += [LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)]
-    keys += [LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)]
+    keys += [
+        LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)
+    ]
+    keys += [
+        LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)
+    ]
     keys += [LoraConversionKeySet("vector_in.in_layer", "time_text_embed.text_embedder.linear_1", parent=key_prefix)]
     keys += [LoraConversionKeySet("vector_in.out_layer", "time_text_embed.text_embedder.linear_2", parent=key_prefix)]
     keys += [LoraConversionKeySet("time_in.in_layer", "time_text_embed.timestep_embedder.linear_1", parent=key_prefix)]
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_lora_util.py

Lines changed: 67 additions & 61 deletions

@@ -1,20 +1,19 @@
 import torch
 from torch import Tensor
-
 from typing_extensions import Self
 
 
 class LoraConversionKeySet:
     def __init__(
-            self,
-            omi_prefix: str,
-            diffusers_prefix: str,
-            legacy_diffusers_prefix: str | None = None,
-            parent: Self | None = None,
-            swap_chunks: bool = False,
-            filter_is_last: bool | None = None,
-            next_omi_prefix: str | None = None,
-            next_diffusers_prefix: str | None = None,
+        self,
+        omi_prefix: str,
+        diffusers_prefix: str,
+        legacy_diffusers_prefix: str | None = None,
+        parent: Self | None = None,
+        swap_chunks: bool = False,
+        filter_is_last: bool | None = None,
+        next_omi_prefix: str | None = None,
+        next_diffusers_prefix: str | None = None,
     ):
         if parent is not None:
             self.omi_prefix = combine(parent.omi_prefix, omi_prefix)

@@ -24,9 +23,11 @@ def __init__(
         self.diffusers_prefix = diffusers_prefix
 
         if legacy_diffusers_prefix is None:
-            self.legacy_diffusers_prefix = self.diffusers_prefix.replace('.', '_')
+            self.legacy_diffusers_prefix = self.diffusers_prefix.replace(".", "_")
         elif parent is not None:
-            self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace('.', '_')
+            self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace(
+                ".", "_"
+            )
         else:
             self.legacy_diffusers_prefix = legacy_diffusers_prefix
 

@@ -42,11 +43,13 @@ def __init__(
         elif next_omi_prefix is not None and parent is not None:
             self.next_omi_prefix = combine(parent.omi_prefix, next_omi_prefix)
             self.next_diffusers_prefix = combine(parent.diffusers_prefix, next_diffusers_prefix)
-            self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace('.', '_')
+            self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace(
+                ".", "_"
+            )
         elif next_omi_prefix is not None and parent is None:
             self.next_omi_prefix = next_omi_prefix
             self.next_diffusers_prefix = next_diffusers_prefix
-            self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace('.', '_')
+            self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace(".", "_")
         else:
             self.next_omi_prefix = None
             self.next_diffusers_prefix = None

@@ -61,19 +64,19 @@ def __get_diffusers(self, in_prefix: str, key: str) -> str:
     def __get_legacy_diffusers(self, in_prefix: str, key: str) -> str:
         key = self.legacy_diffusers_prefix + key.removeprefix(in_prefix)
 
-        suffix = key[key.rfind('.'):]
-        if suffix not in ['.alpha', '.dora_scale']:  # some keys only have a single . in the suffix
-            suffix = key[key.removesuffix(suffix).rfind('.'):]
+        suffix = key[key.rfind(".") :]
+        if suffix not in [".alpha", ".dora_scale"]:  # some keys only have a single . in the suffix
+            suffix = key[key.removesuffix(suffix).rfind(".") :]
         key = key.removesuffix(suffix)
 
-        return key.replace('.', '_') + suffix
+        return key.replace(".", "_") + suffix
 
     def get_key(self, in_prefix: str, key: str, target: str) -> str:
-        if target == 'omi':
+        if target == "omi":
             return self.__get_omi(in_prefix, key)
-        elif target == 'diffusers':
+        elif target == "diffusers":
             return self.__get_diffusers(in_prefix, key)
-        elif target == 'legacy_diffusers':
+        elif target == "legacy_diffusers":
             return self.__get_legacy_diffusers(in_prefix, key)
         return key
 

@@ -82,8 +85,8 @@ def __str__(self) -> str:
 
 
 def combine(left: str, right: str) -> str:
-    left = left.rstrip('.')
-    right = right.lstrip('.')
+    left = left.rstrip(".")
+    right = right.lstrip(".")
     if left == "" or left is None:
         return right
     elif right == "" or right is None:

@@ -93,25 +96,28 @@ def combine(left: str, right: str) -> str:
 
 
 def map_prefix_range(
-        omi_prefix: str,
-        diffusers_prefix: str,
-        parent: LoraConversionKeySet,
+    omi_prefix: str,
+    diffusers_prefix: str,
+    parent: LoraConversionKeySet,
 ) -> list[LoraConversionKeySet]:
     # 100 should be a safe upper bound. increase if it's not enough in the future
-    return [LoraConversionKeySet(
-        omi_prefix=f"{omi_prefix}.{i}",
-        diffusers_prefix=f"{diffusers_prefix}.{i}",
-        parent=parent,
-        next_omi_prefix=f"{omi_prefix}.{i + 1}",
-        next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
-    ) for i in range(100)]
+    return [
+        LoraConversionKeySet(
+            omi_prefix=f"{omi_prefix}.{i}",
+            diffusers_prefix=f"{diffusers_prefix}.{i}",
+            parent=parent,
+            next_omi_prefix=f"{omi_prefix}.{i + 1}",
+            next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
+        )
+        for i in range(100)
+    ]
 
 
 def __convert(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
-        source: str,
-        target: str,
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
+    source: str,
+    target: str,
 ) -> dict[str, Tensor]:
     out_states = {}
 

@@ -121,25 +127,25 @@ def __convert(
     # TODO: maybe replace with a non O(n^2) algorithm
     for key, tensor in state_dict.items():
         for key_set in key_sets:
-            in_prefix = ''
+            in_prefix = ""
 
-            if source == 'omi':
+            if source == "omi":
                 in_prefix = key_set.omi_prefix
-            elif source == 'diffusers':
+            elif source == "diffusers":
                 in_prefix = key_set.diffusers_prefix
-            elif source == 'legacy_diffusers':
+            elif source == "legacy_diffusers":
                 in_prefix = key_set.legacy_diffusers_prefix
 
             if not key.startswith(in_prefix):
                 continue
 
             if key_set.filter_is_last is not None:
                 next_prefix = None
-                if source == 'omi':
+                if source == "omi":
                     next_prefix = key_set.next_omi_prefix
-                elif source == 'diffusers':
+                elif source == "diffusers":
                     next_prefix = key_set.next_diffusers_prefix
-                elif source == 'legacy_diffusers':
+                elif source == "legacy_diffusers":
                     next_prefix = key_set.next_legacy_diffusers_prefix
 
                 is_last = not any(k.startswith(next_prefix) for k in state_dict)

@@ -148,8 +154,8 @@
 
             name = key_set.get_key(in_prefix, key, target)
 
-            can_swap_chunks = target == 'omi' or source == 'omi'
-            if key_set.swap_chunks and name.endswith('.lora_up.weight') and can_swap_chunks:
+            can_swap_chunks = target == "omi" or source == "omi"
+            if key_set.swap_chunks and name.endswith(".lora_up.weight") and can_swap_chunks:
                 chunk_0, chunk_1 = tensor.chunk(2, dim=0)
                 tensor = torch.cat([chunk_1, chunk_0], dim=0)
 

@@ -161,8 +167,8 @@
 
 
 def __detect_source(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
 ) -> str:
     omi_count = 0
     diffusers_count = 0

@@ -178,34 +184,34 @@
             legacy_diffusers_count += 1
 
     if omi_count > diffusers_count and omi_count > legacy_diffusers_count:
-        return 'omi'
+        return "omi"
     if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count:
-        return 'diffusers'
+        return "diffusers"
     if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count:
-        return 'legacy_diffusers'
+        return "legacy_diffusers"
 
-    return ''
+    return ""
 
 
 def convert_to_omi(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
 ) -> dict[str, Tensor]:
     source = __detect_source(state_dict, key_sets)
-    return __convert(state_dict, key_sets, source, 'omi')
+    return __convert(state_dict, key_sets, source, "omi")
 
 
 def convert_to_diffusers(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
 ) -> dict[str, Tensor]:
     source = __detect_source(state_dict, key_sets)
-    return __convert(state_dict, key_sets, source, 'diffusers')
+    return __convert(state_dict, key_sets, source, "diffusers")
 
 
 def convert_to_legacy_diffusers(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
 ) -> dict[str, Tensor]:
     source = __detect_source(state_dict, key_sets)
-    return __convert(state_dict, key_sets, source, 'legacy_diffusers')
+    return __convert(state_dict, key_sets, source, "legacy_diffusers")

invokeai/backend/model_manager/omi/vendor/convert/lora/convert_sdxl_lora.py

Lines changed: 5 additions & 2 deletions

@@ -1,5 +1,8 @@
 from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
-from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
+from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
+    LoraConversionKeySet,
+    map_prefix_range,
+)
 
 
 def __map_unet_resnet_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:

@@ -115,7 +118,7 @@ def convert_sdxl_lora_key_sets() -> list[LoraConversionKeySet]:
     keys = []
 
     keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
-    keys += __map_unet(LoraConversionKeySet( "unet", "lora_unet"))
+    keys += __map_unet(LoraConversionKeySet("unet", "lora_unet"))
     keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
     keys += map_clip(LoraConversionKeySet("clip_g", "lora_te2"))
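Together with convert_lora_util.py above, these key sets drive the actual conversion. A hedged end-to-end sketch (the empty state dict is a stand-in; both imported functions appear in this commit):

# Convert a LoRA state dict to diffusers naming using the SDXL key sets.
import torch

from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
    convert_to_diffusers,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_sdxl_lora import (
    convert_sdxl_lora_key_sets,
)

state_dict: dict[str, torch.Tensor] = {}  # would hold the loaded LoRA tensors
converted = convert_to_diffusers(state_dict, convert_sdxl_lora_key_sets())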

invokeai/backend/model_manager/omi/vendor/convert/lora/convert_t5.py

Lines changed: 4 additions & 1 deletion

@@ -1,4 +1,7 @@
-from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
+from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
+    LoraConversionKeySet,
+    map_prefix_range,
+)
 
 
 def map_t5(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
