
Commit fe83c2f

jazzhaiku authored and psychedelicious committed
Add OMI vendor files
1 parent 17dead3 commit fe83c2f

File tree

10 files changed: +472 −0 lines changed


invokeai/backend/model_manager/omi/vendor/__init__.py

Whitespace-only changes.

invokeai/backend/model_manager/omi/vendor/convert/__init__.py

Whitespace-only changes.

invokeai/backend/model_manager/omi/vendor/convert/lora/__init__.py

Whitespace-only changes.
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_clip.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range


def map_clip(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
    keys = []

    keys += [LoraConversionKeySet("text_projection", "text_projection", parent=key_prefix)]

    for k in map_prefix_range("text_model.encoder.layers", "text_model.encoder.layers", parent=key_prefix):
        keys += [LoraConversionKeySet("mlp.fc1", "mlp.fc1", parent=k)]
        keys += [LoraConversionKeySet("mlp.fc2", "mlp.fc2", parent=k)]
        keys += [LoraConversionKeySet("self_attn.k_proj", "self_attn.k_proj", parent=k)]
        keys += [LoraConversionKeySet("self_attn.out_proj", "self_attn.out_proj", parent=k)]
        keys += [LoraConversionKeySet("self_attn.q_proj", "self_attn.q_proj", parent=k)]
        keys += [LoraConversionKeySet("self_attn.v_proj", "self_attn.v_proj", parent=k)]

    return keys
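
For orientation, a minimal sketch (not part of this commit) of what map_clip produces when given the clip_l / lora_te1 parent pair used by convert_flux_lora_key_sets below; the printed string comes from LoraConversionKeySet.__str__ defined in convert_lora_util.py later in this diff:

import torch  # noqa: F401  (torch is a dependency of the vendored module)

from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet

# Expand every CLIP sub-module mapping under the "clip_l" (OMI) / "lora_te1" (diffusers) parent.
clip_keys = map_clip(LoraConversionKeySet("clip_l", "lora_te1"))

# The first entry is the text_projection mapping; the parent prefixes are prepended on both sides.
print(clip_keys[0])
# omi: clip_l.text_projection, diffusers: lora_te1.text_projection, legacy: lora_te1_text_projection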
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_flux.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_t5 import map_t5


def __map_double_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
    keys = []

    keys += [LoraConversionKeySet("img_attn.qkv.0", "attn.to_q", parent=key_prefix)]
    keys += [LoraConversionKeySet("img_attn.qkv.1", "attn.to_k", parent=key_prefix)]
    keys += [LoraConversionKeySet("img_attn.qkv.2", "attn.to_v", parent=key_prefix)]

    keys += [LoraConversionKeySet("txt_attn.qkv.0", "attn.add_q_proj", parent=key_prefix)]
    keys += [LoraConversionKeySet("txt_attn.qkv.1", "attn.add_k_proj", parent=key_prefix)]
    keys += [LoraConversionKeySet("txt_attn.qkv.2", "attn.add_v_proj", parent=key_prefix)]

    keys += [LoraConversionKeySet("img_attn.proj", "attn.to_out.0", parent=key_prefix)]
    keys += [LoraConversionKeySet("img_mlp.0", "ff.net.0.proj", parent=key_prefix)]
    keys += [LoraConversionKeySet("img_mlp.2", "ff.net.2", parent=key_prefix)]
    keys += [LoraConversionKeySet("img_mod.lin", "norm1.linear", parent=key_prefix)]

    keys += [LoraConversionKeySet("txt_attn.proj", "attn.to_add_out", parent=key_prefix)]
    keys += [LoraConversionKeySet("txt_mlp.0", "ff_context.net.0.proj", parent=key_prefix)]
    keys += [LoraConversionKeySet("txt_mlp.2", "ff_context.net.2", parent=key_prefix)]
    keys += [LoraConversionKeySet("txt_mod.lin", "norm1_context.linear", parent=key_prefix)]

    return keys


def __map_single_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
    keys = []

    keys += [LoraConversionKeySet("linear1.0", "attn.to_q", parent=key_prefix)]
    keys += [LoraConversionKeySet("linear1.1", "attn.to_k", parent=key_prefix)]
    keys += [LoraConversionKeySet("linear1.2", "attn.to_v", parent=key_prefix)]
    keys += [LoraConversionKeySet("linear1.3", "proj_mlp", parent=key_prefix)]

    keys += [LoraConversionKeySet("linear2", "proj_out", parent=key_prefix)]
    keys += [LoraConversionKeySet("modulation.lin", "norm.linear", parent=key_prefix)]

    return keys


def __map_transformer(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
    keys = []

    keys += [LoraConversionKeySet("txt_in", "context_embedder", parent=key_prefix)]
    keys += [LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)]
    keys += [LoraConversionKeySet("final_layer.linear", "proj_out", parent=key_prefix)]
    keys += [LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)]
    keys += [LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)]
    keys += [LoraConversionKeySet("vector_in.in_layer", "time_text_embed.text_embedder.linear_1", parent=key_prefix)]
    keys += [LoraConversionKeySet("vector_in.out_layer", "time_text_embed.text_embedder.linear_2", parent=key_prefix)]
    keys += [LoraConversionKeySet("time_in.in_layer", "time_text_embed.timestep_embedder.linear_1", parent=key_prefix)]
    keys += [LoraConversionKeySet("time_in.out_layer", "time_text_embed.timestep_embedder.linear_2", parent=key_prefix)]
    keys += [LoraConversionKeySet("img_in.proj", "x_embedder", parent=key_prefix)]

    for k in map_prefix_range("double_blocks", "transformer_blocks", parent=key_prefix):
        keys += __map_double_transformer_block(k)

    for k in map_prefix_range("single_blocks", "single_transformer_blocks", parent=key_prefix):
        keys += __map_single_transformer_block(k)

    return keys


def convert_flux_lora_key_sets() -> list[LoraConversionKeySet]:
    keys = []

    keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
    keys += __map_transformer(LoraConversionKeySet("transformer", "lora_transformer"))
    keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
    keys += map_t5(LoraConversionKeySet("t5", "lora_te2"))

    return keys
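
As a quick illustration (again, not part of the commit), this sketch builds the full FLUX key-set list and pulls out the mapping for the first double block's image-attention query projection; it assumes the convert_t5 module imported above is present among the other files in this commit:

from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_flux import convert_flux_lora_key_sets

key_sets = convert_flux_lora_key_sets()

# Each key set pairs one OMI prefix with its diffusers and legacy-diffusers counterparts.
qkv_q = next(k for k in key_sets if k.omi_prefix == "transformer.double_blocks.0.img_attn.qkv.0")
print(qkv_q)
# omi: transformer.double_blocks.0.img_attn.qkv.0, diffusers: lora_transformer.transformer_blocks.0.attn.to_q, legacy: lora_transformer_transformer_blocks_0_attn_to_q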
invokeai/backend/model_manager/omi/vendor/convert/lora/convert_lora_util.py

Lines changed: 211 additions & 0 deletions

@@ -0,0 +1,211 @@
import torch
from torch import Tensor

from typing_extensions import Self


class LoraConversionKeySet:
    def __init__(
            self,
            omi_prefix: str,
            diffusers_prefix: str,
            legacy_diffusers_prefix: str | None = None,
            parent: Self | None = None,
            swap_chunks: bool = False,
            filter_is_last: bool | None = None,
            next_omi_prefix: str | None = None,
            next_diffusers_prefix: str | None = None,
    ):
        if parent is not None:
            self.omi_prefix = combine(parent.omi_prefix, omi_prefix)
            self.diffusers_prefix = combine(parent.diffusers_prefix, diffusers_prefix)
        else:
            self.omi_prefix = omi_prefix
            self.diffusers_prefix = diffusers_prefix

        if legacy_diffusers_prefix is None:
            self.legacy_diffusers_prefix = self.diffusers_prefix.replace('.', '_')
        elif parent is not None:
            self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace('.', '_')
        else:
            self.legacy_diffusers_prefix = legacy_diffusers_prefix

        self.parent = parent
        self.swap_chunks = swap_chunks
        self.filter_is_last = filter_is_last
        self.prefix = parent

        if next_omi_prefix is None and parent is not None:
            self.next_omi_prefix = parent.next_omi_prefix
            self.next_diffusers_prefix = parent.next_diffusers_prefix
            self.next_legacy_diffusers_prefix = parent.next_legacy_diffusers_prefix
        elif next_omi_prefix is not None and parent is not None:
            self.next_omi_prefix = combine(parent.omi_prefix, next_omi_prefix)
            self.next_diffusers_prefix = combine(parent.diffusers_prefix, next_diffusers_prefix)
            self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace('.', '_')
        elif next_omi_prefix is not None and parent is None:
            self.next_omi_prefix = next_omi_prefix
            self.next_diffusers_prefix = next_diffusers_prefix
            self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace('.', '_')
        else:
            self.next_omi_prefix = None
            self.next_diffusers_prefix = None
            self.next_legacy_diffusers_prefix = None

    def __get_omi(self, in_prefix: str, key: str) -> str:
        return self.omi_prefix + key.removeprefix(in_prefix)

    def __get_diffusers(self, in_prefix: str, key: str) -> str:
        return self.diffusers_prefix + key.removeprefix(in_prefix)

    def __get_legacy_diffusers(self, in_prefix: str, key: str) -> str:
        key = self.legacy_diffusers_prefix + key.removeprefix(in_prefix)

        suffix = key[key.rfind('.'):]
        if suffix not in ['.alpha', '.dora_scale']:  # some keys only have a single . in the suffix
            suffix = key[key.removesuffix(suffix).rfind('.'):]
        key = key.removesuffix(suffix)

        return key.replace('.', '_') + suffix

    def get_key(self, in_prefix: str, key: str, target: str) -> str:
        if target == 'omi':
            return self.__get_omi(in_prefix, key)
        elif target == 'diffusers':
            return self.__get_diffusers(in_prefix, key)
        elif target == 'legacy_diffusers':
            return self.__get_legacy_diffusers(in_prefix, key)
        return key

    def __str__(self) -> str:
        return f"omi: {self.omi_prefix}, diffusers: {self.diffusers_prefix}, legacy: {self.legacy_diffusers_prefix}"


def combine(left: str, right: str) -> str:
    left = left.rstrip('.')
    right = right.lstrip('.')
    if left == "" or left is None:
        return right
    elif right == "" or right is None:
        return left
    else:
        return left + "." + right


def map_prefix_range(
        omi_prefix: str,
        diffusers_prefix: str,
        parent: LoraConversionKeySet,
) -> list[LoraConversionKeySet]:
    # 100 should be a safe upper bound. increase if it's not enough in the future
    return [LoraConversionKeySet(
        omi_prefix=f"{omi_prefix}.{i}",
        diffusers_prefix=f"{diffusers_prefix}.{i}",
        parent=parent,
        next_omi_prefix=f"{omi_prefix}.{i + 1}",
        next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
    ) for i in range(100)]


def __convert(
        state_dict: dict[str, Tensor],
        key_sets: list[LoraConversionKeySet],
        source: str,
        target: str,
) -> dict[str, Tensor]:
    out_states = {}

    if source == target:
        return dict(state_dict)

    # TODO: maybe replace with a non O(n^2) algorithm
    for key, tensor in state_dict.items():
        for key_set in key_sets:
            in_prefix = ''

            if source == 'omi':
                in_prefix = key_set.omi_prefix
            elif source == 'diffusers':
                in_prefix = key_set.diffusers_prefix
            elif source == 'legacy_diffusers':
                in_prefix = key_set.legacy_diffusers_prefix

            if not key.startswith(in_prefix):
                continue

            if key_set.filter_is_last is not None:
                next_prefix = None
                if source == 'omi':
                    next_prefix = key_set.next_omi_prefix
                elif source == 'diffusers':
                    next_prefix = key_set.next_diffusers_prefix
                elif source == 'legacy_diffusers':
                    next_prefix = key_set.next_legacy_diffusers_prefix

                is_last = not any(k.startswith(next_prefix) for k in state_dict)
                if key_set.filter_is_last != is_last:
                    continue

            name = key_set.get_key(in_prefix, key, target)

            can_swap_chunks = target == 'omi' or source == 'omi'
            if key_set.swap_chunks and name.endswith('.lora_up.weight') and can_swap_chunks:
                chunk_0, chunk_1 = tensor.chunk(2, dim=0)
                tensor = torch.cat([chunk_1, chunk_0], dim=0)

            out_states[name] = tensor

            break  # only map the first matching key set

    return out_states


def __detect_source(
        state_dict: dict[str, Tensor],
        key_sets: list[LoraConversionKeySet],
) -> str:
    omi_count = 0
    diffusers_count = 0
    legacy_diffusers_count = 0

    for key in state_dict:
        for key_set in key_sets:
            if key.startswith(key_set.omi_prefix):
                omi_count += 1
            if key.startswith(key_set.diffusers_prefix):
                diffusers_count += 1
            if key.startswith(key_set.legacy_diffusers_prefix):
                legacy_diffusers_count += 1

    if omi_count > diffusers_count and omi_count > legacy_diffusers_count:
        return 'omi'
    if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count:
        return 'diffusers'
    if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count:
        return 'legacy_diffusers'

    return ''


def convert_to_omi(
        state_dict: dict[str, Tensor],
        key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
    source = __detect_source(state_dict, key_sets)
    return __convert(state_dict, key_sets, source, 'omi')


def convert_to_diffusers(
        state_dict: dict[str, Tensor],
        key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
    source = __detect_source(state_dict, key_sets)
    return __convert(state_dict, key_sets, source, 'diffusers')


def convert_to_legacy_diffusers(
        state_dict: dict[str, Tensor],
        key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
    source = __detect_source(state_dict, key_sets)
    return __convert(state_dict, key_sets, source, 'legacy_diffusers')
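
To show how these pieces fit together, a minimal usage sketch (not part of this commit); the state-dict key below is hypothetical, while the import paths and function names are the ones added above:

import torch

from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_flux import convert_flux_lora_key_sets
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import convert_to_diffusers

# Hypothetical OMI-format FLUX LoRA state dict containing a single mapped weight.
state_dict = {
    "transformer.img_in.proj.lora_down.weight": torch.zeros(16, 64),
}

# __detect_source identifies the OMI naming, then __convert remaps every key to diffusers-style naming.
converted = convert_to_diffusers(state_dict, convert_flux_lora_key_sets())
print(list(converted))
# ['lora_transformer.x_embedder.lora_down.weight']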
