@@ -1,20 +1,19 @@
 import torch
 from torch import Tensor
-
 from typing_extensions import Self


 class LoraConversionKeySet:
     def __init__(
-            self,
-            omi_prefix: str,
-            diffusers_prefix: str,
-            legacy_diffusers_prefix: str | None = None,
-            parent: Self | None = None,
-            swap_chunks: bool = False,
-            filter_is_last: bool | None = None,
-            next_omi_prefix: str | None = None,
-            next_diffusers_prefix: str | None = None,
+        self,
+        omi_prefix: str,
+        diffusers_prefix: str,
+        legacy_diffusers_prefix: str | None = None,
+        parent: Self | None = None,
+        swap_chunks: bool = False,
+        filter_is_last: bool | None = None,
+        next_omi_prefix: str | None = None,
+        next_diffusers_prefix: str | None = None,
     ):
         if parent is not None:
             self.omi_prefix = combine(parent.omi_prefix, omi_prefix)
@@ -24,9 +23,11 @@ def __init__(
             self.diffusers_prefix = diffusers_prefix

         if legacy_diffusers_prefix is None:
-            self.legacy_diffusers_prefix = self.diffusers_prefix.replace('.', '_')
+            self.legacy_diffusers_prefix = self.diffusers_prefix.replace(".", "_")
         elif parent is not None:
-            self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace('.', '_')
+            self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace(
+                ".", "_"
+            )
         else:
             self.legacy_diffusers_prefix = legacy_diffusers_prefix

@@ -42,11 +43,13 @@ def __init__(
         elif next_omi_prefix is not None and parent is not None:
             self.next_omi_prefix = combine(parent.omi_prefix, next_omi_prefix)
             self.next_diffusers_prefix = combine(parent.diffusers_prefix, next_diffusers_prefix)
-            self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace('.', '_')
+            self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace(
+                ".", "_"
+            )
         elif next_omi_prefix is not None and parent is None:
             self.next_omi_prefix = next_omi_prefix
             self.next_diffusers_prefix = next_diffusers_prefix
-            self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace('.', '_')
+            self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace(".", "_")
         else:
             self.next_omi_prefix = None
             self.next_diffusers_prefix = None
@@ -61,19 +64,19 @@ def __get_diffusers(self, in_prefix: str, key: str) -> str:
     def __get_legacy_diffusers(self, in_prefix: str, key: str) -> str:
         key = self.legacy_diffusers_prefix + key.removeprefix(in_prefix)

-        suffix = key[key.rfind('.'):]
-        if suffix not in ['.alpha', '.dora_scale']:  # some keys only have a single . in the suffix
-            suffix = key[key.removesuffix(suffix).rfind('.'):]
+        suffix = key[key.rfind(".") :]
+        if suffix not in [".alpha", ".dora_scale"]:  # some keys only have a single . in the suffix
+            suffix = key[key.removesuffix(suffix).rfind(".") :]
         key = key.removesuffix(suffix)

-        return key.replace('.', '_') + suffix
+        return key.replace(".", "_") + suffix

     def get_key(self, in_prefix: str, key: str, target: str) -> str:
-        if target == 'omi':
+        if target == "omi":
             return self.__get_omi(in_prefix, key)
-        elif target == 'diffusers':
+        elif target == "diffusers":
             return self.__get_diffusers(in_prefix, key)
-        elif target == 'legacy_diffusers':
+        elif target == "legacy_diffusers":
             return self.__get_legacy_diffusers(in_prefix, key)
         return key

@@ -82,8 +85,8 @@ def __str__(self) -> str:


 def combine(left: str, right: str) -> str:
-    left = left.rstrip('.')
-    right = right.lstrip('.')
+    left = left.rstrip(".")
+    right = right.lstrip(".")
     if left == "" or left is None:
         return right
     elif right == "" or right is None:
@@ -93,25 +96,28 @@ def combine(left: str, right: str) -> str:


 def map_prefix_range(
-        omi_prefix: str,
-        diffusers_prefix: str,
-        parent: LoraConversionKeySet,
+    omi_prefix: str,
+    diffusers_prefix: str,
+    parent: LoraConversionKeySet,
 ) -> list[LoraConversionKeySet]:
     # 100 should be a safe upper bound. increase if it's not enough in the future
-    return [LoraConversionKeySet(
-        omi_prefix=f"{omi_prefix}.{i}",
-        diffusers_prefix=f"{diffusers_prefix}.{i}",
-        parent=parent,
-        next_omi_prefix=f"{omi_prefix}.{i + 1}",
-        next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
-    ) for i in range(100)]
+    return [
+        LoraConversionKeySet(
+            omi_prefix=f"{omi_prefix}.{i}",
+            diffusers_prefix=f"{diffusers_prefix}.{i}",
+            parent=parent,
+            next_omi_prefix=f"{omi_prefix}.{i + 1}",
+            next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
+        )
+        for i in range(100)
+    ]


 def __convert(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
-        source: str,
-        target: str,
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
+    source: str,
+    target: str,
 ) -> dict[str, Tensor]:
     out_states = {}

@@ -121,25 +127,25 @@ def __convert(
     # TODO: maybe replace with a non O(n^2) algorithm
     for key, tensor in state_dict.items():
         for key_set in key_sets:
-            in_prefix = ''
+            in_prefix = ""

-            if source == 'omi':
+            if source == "omi":
                 in_prefix = key_set.omi_prefix
-            elif source == 'diffusers':
+            elif source == "diffusers":
                 in_prefix = key_set.diffusers_prefix
-            elif source == 'legacy_diffusers':
+            elif source == "legacy_diffusers":
                 in_prefix = key_set.legacy_diffusers_prefix

             if not key.startswith(in_prefix):
                 continue

             if key_set.filter_is_last is not None:
                 next_prefix = None
-                if source == 'omi':
+                if source == "omi":
                     next_prefix = key_set.next_omi_prefix
-                elif source == 'diffusers':
+                elif source == "diffusers":
                     next_prefix = key_set.next_diffusers_prefix
-                elif source == 'legacy_diffusers':
+                elif source == "legacy_diffusers":
                     next_prefix = key_set.next_legacy_diffusers_prefix

                 is_last = not any(k.startswith(next_prefix) for k in state_dict)
@@ -148,8 +154,8 @@ def __convert(

             name = key_set.get_key(in_prefix, key, target)

-            can_swap_chunks = target == 'omi' or source == 'omi'
-            if key_set.swap_chunks and name.endswith('.lora_up.weight') and can_swap_chunks:
+            can_swap_chunks = target == "omi" or source == "omi"
+            if key_set.swap_chunks and name.endswith(".lora_up.weight") and can_swap_chunks:
                 chunk_0, chunk_1 = tensor.chunk(2, dim=0)
                 tensor = torch.cat([chunk_1, chunk_0], dim=0)

@@ -161,8 +167,8 @@ def __convert(


 def __detect_source(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
 ) -> str:
     omi_count = 0
     diffusers_count = 0
@@ -178,34 +184,34 @@ def __detect_source(
                 legacy_diffusers_count += 1

     if omi_count > diffusers_count and omi_count > legacy_diffusers_count:
-        return 'omi'
+        return "omi"
     if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count:
-        return 'diffusers'
+        return "diffusers"
     if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count:
-        return 'legacy_diffusers'
+        return "legacy_diffusers"

-    return ''
+    return ""


 def convert_to_omi(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
 ) -> dict[str, Tensor]:
     source = __detect_source(state_dict, key_sets)
-    return __convert(state_dict, key_sets, source, 'omi')
+    return __convert(state_dict, key_sets, source, "omi")


 def convert_to_diffusers(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
 ) -> dict[str, Tensor]:
     source = __detect_source(state_dict, key_sets)
-    return __convert(state_dict, key_sets, source, 'diffusers')
+    return __convert(state_dict, key_sets, source, "diffusers")


 def convert_to_legacy_diffusers(
-        state_dict: dict[str, Tensor],
-        key_sets: list[LoraConversionKeySet],
+    state_dict: dict[str, Tensor],
+    key_sets: list[LoraConversionKeySet],
 ) -> dict[str, Tensor]:
     source = __detect_source(state_dict, key_sets)
-    return __convert(state_dict, key_sets, source, 'legacy_diffusers')
+    return __convert(state_dict, key_sets, source, "legacy_diffusers")
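
For reference, a minimal usage sketch of the converter API touched by this commit, assuming the names are imported from this module. All prefixes, key names, and tensor shapes below are invented for illustration; the real LoraConversionKeySet lists are presumably built per model elsewhere in the codebase.

# Hypothetical example only: prefixes and keys are made up for illustration.
parent = LoraConversionKeySet(omi_prefix="lora_transformer", diffusers_prefix="transformer")
key_sets = map_prefix_range("blocks", "transformer_blocks", parent=parent)

# A state dict would normally come from a loaded LoRA file (e.g. safetensors);
# here a tiny dict of dummy tensors stands in for it.
state_dict = {
    "transformer.transformer_blocks.0.attn.to_q.lora_down.weight": torch.zeros(4, 320),
    "transformer.transformer_blocks.0.attn.to_q.lora_up.weight": torch.zeros(320, 4),
}

omi_state_dict = convert_to_omi(state_dict, key_sets)  # the source format is auto-detected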