Commit e544f92 (1 parent: 20cad68)

address final feedback. author says an approximated version without cholesky root for multiple concepts can be tried, so will offer both options
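In symbols (a reconstruction from the einsums in the diff below, not text from the commit): with text encoding e, concept inputs i_c, and the precomputed input covariance C, the Cholesky path projects e onto the orthogonal complement of the whitened concept span (equations (6) and (10) referenced in the code comments), while the approximated option simply subtracts each concept's contribution independently:

```latex
% approximated multi-concept orthogonal term (multi_concepts_use_cholesky = False),
% reconstructed from the einsums in the diff below
W_{\mathrm{em}\perp}\, e \;\approx\; W e \;-\; \sum_{c} \frac{e^{\top} C^{-1} i_c}{i_c^{\top} C^{-1} i_c}\, W i_c
```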

3 files changed, +40 -21 lines

README.md

Lines changed: 2 additions & 1 deletion

@@ -75,10 +75,11 @@ values = wrapped_to_values(text_enc)
 
 ## Todo
 
+- [ ] wire up with SD 1.5, starting with xiao's dreambooth-sd
 - [ ] show example in readme for inference with multiple concepts
-- [ ] review multiple concepts
 - [ ] automatically infer where keys and values projection are if not specified for the `make_key_value_proj_rank1_edit_modules_` function
 
+- [x] review multiple concepts - thanks to Yoad
 - [x] offer a function that wires up the cross attention
 - [x] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
 - [x] accept multiple concept indices

perfusion_pytorch/perfusion.py

Lines changed: 37 additions & 19 deletions

@@ -126,20 +126,22 @@ def __init__(
         key_or_values_proj: nn.Linear,
         *,
         num_concepts: int = 1,
-        C: Tensor,                          # covariance of input, precomputed from 100K laion text
+        C: Tensor,                              # covariance of input, precomputed from 100K laion text
         text_seq_len: int = 77,
         is_key_proj: bool = False,
         input_decay = 0.99,
         train_beta = 0.75,
         train_temperature = 0.1,
-        eval_beta = 0.70,                   # in paper, specified a range (0.6 - 0.75) for local-key lock, and (0.4 - 0.6) for global-key lock
+        eval_beta = 0.70,                       # in paper, specified a range (0.6 - 0.75) for local-key lock, and (0.4 - 0.6) for global-key lock
         eval_temperature = 0.15,
-        frac_gradient_concept_embed = 0.1   # they use a slower learning rate for the embed - this can be achieved by a trick to reduce the gradients going backwards through an operation
+        frac_gradient_concept_embed = 0.1,      # they use a slower learning rate for the embed - this can be achieved by a trick to reduce the gradients going backwards through an operation
+        multi_concepts_use_cholesky = False     # whether to use the Cholesky root for multiple concepts; False uses the author's approximated technique without it
     ):
         super().__init__()
         assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
 
+        self.multi_concepts_use_cholesky = multi_concepts_use_cholesky
         self.num_concepts = num_concepts
 
         self.weight = key_or_values_proj.weight
         dim_output, dim_input = self.weight.shape
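A minimal construction sketch for the new flag (editor's example, not from the commit; the 768/320 dimensions and the identity covariance are placeholders - in practice `C` is the covariance precomputed from ~100K laion text encodings):

```python
import torch
from torch import nn

from perfusion_pytorch.perfusion import Rank1EditModule

# cross attention key projection to be wrapped - must be bias-free
to_keys = nn.Linear(768, 320, bias = False)

# placeholder input covariance (identity); the real one is precomputed from laion text
C = torch.eye(768)

wrapped_to_keys = Rank1EditModule(
    to_keys,
    C = C,
    is_key_proj = True,                    # local key lock
    multi_concepts_use_cholesky = False    # new flag: False = approximated technique, True = Cholesky root path
)
```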
@@ -183,7 +185,7 @@ def num_concepts(self):
     def num_concepts(self, value):
         self._num_concepts = value
 
-        if value == 1:
+        if value == 1 or not self.multi_concepts_use_cholesky:
             return
 
         # for multiple concepts
@@ -330,36 +332,50 @@ def forward(
 
         # main contribution eq (3)
 
-        i_energy = opt_einsum('c o, o i, c i ->', i, Ci, i)
+        i_energy = opt_einsum('c o, o i, c i -> c', i, Ci, i)
+        i_energy_inv = i_energy ** -1
 
         sim = opt_einsum('b n o, o i, c i -> c b n', text_enc, Ci, i)
-        sim = rearrange(sim, '... -> ... 1')
 
-        sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()
+        # calculate W_em_orthogonal_term - depends on single or multiple concepts
 
         if is_multi_concepts:
-            L_T, L_T_inv = self.L_T, self.L_T_inv
+            if self.multi_concepts_use_cholesky:
+                L_T, L_T_inv = self.L_T, self.L_T_inv
+
+                # metric - metric space - variable with tilde in Appendix B
 
-            # metric - metric space - variable with tilde in Appendix B
+                # equation (6)
 
-            # equation (6)
+                i_metric = einsum('o i, c i -> c o', L_T, i)
+                u_metric, _ = torch.linalg.qr(i_metric.T)
+                u = einsum('o i, i c -> c o', L_T_inv, u_metric)
 
-            i_metric = einsum('o i, c i -> c o', L_T, i)
-            u_metric, _ = torch.linalg.qr(i_metric.T)
-            u = einsum('o i, i c -> c o', L_T_inv, u_metric)
+                # equation (10)
 
-            # equation (10)
+                em_orthogonal = text_enc - opt_einsum('c o, b n i, c i -> b n o', u, text_enc, u)
+
+                W_em_orthogonal_term = einsum('b n i, o i -> b n o', em_orthogonal, W)
+            else:
+                # an approximated version, without the Cholesky root
+                # author says to use this preferentially, and fall back to the Cholesky root if there are issues
 
-            em_orthogonal = text_enc - opt_einsum('c o, b n i, c i -> b n o', u, text_enc, u)
+                text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)
 
-            W_em_orthogonal_term = einsum('b n i, o i -> b n o', em_orthogonal, W)
+                W_em_orthogonal_term = text_enc_output - opt_einsum('c b n, c i, o i, c -> b n o', sim, i, W, i_energy_inv)
         else:
             text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)
 
             concept_output = einsum('c i, o i -> c o', i, W)
-            concept_output = rearrange(concept_output, 'c d -> c 1 1 d')
 
-            W_em_orthogonal_term = text_enc_output - reduce(sim * concept_output / i_energy, 'c ... -> ...', 'sum')
+            W_em_orthogonal_term = text_enc_output - opt_einsum('c b n, c o, c -> b n o', sim, concept_output, i_energy_inv)
+
+        # calculate sigmoid_term (gating)
+
+        sim = rearrange(sim, 'c b n -> c b n 1')
+        i_energy = rearrange(i_energy, 'c -> c 1 1 1')
+
+        sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()
 
         gated_term = sigmoid_term * rearrange(o, 'c d -> c 1 1 d')
         gated_term = reduce(gated_term, 'c ... -> ...', 'sum')
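To sanity check the refactor of the single-concept branch (the fused einsum with `i_energy_inv` replacing the old broadcast-and-reduce), here is a small self-contained sketch with placeholder shapes, using plain `torch.einsum` in place of the repo's `einsum` / `opt_einsum` wrappers:

```python
import torch
from einops import rearrange, reduce

c, b, n, o = 2, 3, 77, 8   # concepts, batch, tokens, output dim

sim            = torch.randn(c, b, n)
concept_output = torch.randn(c, o)
i_energy       = torch.rand(c) + 0.5

# previous form: broadcast per concept, divide by the energy, then sum over concepts
old = reduce(
    rearrange(sim, 'c b n -> c b n 1')
    * rearrange(concept_output, 'c d -> c 1 1 d')
    / rearrange(i_energy, 'c -> c 1 1 1'),
    'c ... -> ...', 'sum'
)

# new form in this commit: a single einsum over the precomputed inverse energies
new = torch.einsum('c b n, c o, c -> b n o', sim, concept_output, i_energy ** -1)

assert torch.allclose(old, new, atol = 1e-5)
```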
@@ -370,7 +386,8 @@ def forward(
 
 @beartype
 def merge_rank1_edit_modules(
-    *modules: Rank1EditModule
+    *modules: Rank1EditModule,
+    use_cholesky = False
 ) -> Rank1EditModule:
 
     assert all([m.initted.all() for m in modules]), 'all modules must be initialized and ideally trained'
@@ -379,6 +396,7 @@ def merge_rank1_edit_modules(
 
     first_module = modules[0]
     merged_module = deepcopy(first_module)
+    merged_module.multi_concepts_use_cholesky = use_cholesky
 
     total_concepts = sum([m.num_concepts for m in modules])
     merged_module.num_concepts = total_concepts
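Usage of the widened `merge_rank1_edit_modules` signature (editor's sketch; `teddy_module` and `dog_module` are hypothetical placeholders for two already-trained `Rank1EditModule` instances, since the function asserts they are initted):

```python
from perfusion_pytorch.perfusion import merge_rank1_edit_modules

# merge two separately learned concepts into one module
merged = merge_rank1_edit_modules(
    teddy_module,
    dog_module,
    use_cholesky = False   # False (default): approximated technique; True: Cholesky root + QR path
)

assert merged.num_concepts == 2
```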

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.31',
+  version = '0.1.0',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
