Skip to content

Commit 86a4b42

Browse files
committed
Offer a way to merge multiple `Rank1EditModule`s. Concept ids in the merged module will be in the same order as the merge order.
1 parent 025bf03 commit 86a4b42

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,14 @@ values = wrapped_to_values(
8282

8383
## Todo
8484

85-
- [ ] offer a way to combine separately learned concepts from multiple `Rank1EditModule` into one for inference
8685
- [ ] handle rank-1 update for multiple concepts
8786
- [x] handle training with multiple concepts
8887
- [ ] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
88+
- [ ] accept multiple concept indices
8989
- [ ] offer a magic function that automatically tries to wire up the cross attention by looking for appropriately named `nn.Linear` and auto-inferring which ones are keys or values
9090

91+
- [x] offer a way to combine separately learned concepts from multiple `Rank1EditModule` into one for inference
92+
- [x] offer function for merging `Rank1EditModule`s
9193
- [x] add the zero-shot masking of concept proposed in paper
9294
- [x] take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
9395
- [x] instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)

perfusion_pytorch/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from perfusion_pytorch.perfusion import (
22
Rank1EditModule,
3-
calculate_input_covariance
3+
calculate_input_covariance,
4+
loss_fn_weighted_by_mask,
5+
merge_rank1_edit_modules
46
)

perfusion_pytorch/perfusion.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from math import ceil
2+
from copy import deepcopy
3+
24
from beartype import beartype
3-
from beartype.typing import Union, List, Optional
5+
from beartype.typing import Union, List, Optional, Tuple
46

57
import torch
68
from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor
@@ -163,7 +165,7 @@ def __init__(
163165

164166
self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
165167

166-
self.concept_output = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)
168+
self.concept_outputs = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)
167169

168170
# C in the paper, inverse precomputed
169171

@@ -173,7 +175,7 @@ def parameters(self):
173175
if not self.is_key_proj:
174176
return []
175177

176-
return [self.concept_output]
178+
return [self.concept_outputs]
177179

178180
@beartype
179181
def forward(
@@ -240,21 +242,21 @@ def forward(
240242
assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first batch'
241243

242244
# init concept output with superclass output - fixed for keys, learned for values
243-
self.concept_output[concept_id].data.copy_(superclass_output)
245+
self.concept_outputs[concept_id].data.copy_(superclass_output)
244246

245247
elif exists(superclass_output) and self.is_key_proj:
246248
# if text enc with superclass is passed in for more than 1 batch
247249
# just take the opportunity to exponentially average it a bit more for the keys, which have fixed concept output (to superclass)
248250

249-
ema_concept_output = self.concept_output * decay + superclass_output * (1. - decay)
250-
self.concept_output[concept_id].data.copy_(ema_concept_output)
251+
ema_concept_output = self.concept_outputs[concept_id] * decay + superclass_output * (1. - decay)
252+
self.concept_outputs[concept_id].data.copy_(ema_concept_output)
251253

252254
# if any in the batch is not initialized, initialize
253255

254256
if not initted:
255257
ema_concept_text_enc = concept_text_enc
256258
else:
257-
ema_concept_text_enc = self.ema_concept_text_enc[concept_id]
259+
ema_concept_text_enc = self.ema_concept_text_encs[concept_id]
258260

259261
# exponential moving average for concept input encoding
260262

@@ -270,7 +272,7 @@ def forward(
270272

271273
# make it easier to match with paper
272274

273-
i, o, W = self.ema_concept_text_encs[concept_id], self.concept_output[concept_id], weights
275+
i, o, W = self.ema_concept_text_encs[concept_id], self.concept_outputs[concept_id], weights
274276

275277
# main contribution eq (3)
276278

@@ -289,3 +291,24 @@ def forward(
289291
W_em_orthogonal_term = text_enc_output - (sim * concept_output / i_energy)
290292

291293
return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'd -> 1 1 d')
294+
295+
# for merging trained Rank1EditModule(s) above

@beartype
def merge_rank1_edit_modules(
    *modules: Rank1EditModule
) -> Rank1EditModule:
    """
    Merge separately trained Rank1EditModule instances into a single module for inference.

    Concept ids in the merged module follow the order in which the modules are
    passed: the concepts of modules[0] come first, then modules[1], and so on.

    All modules must be initialized, share the same concept output dimension,
    and be uniformly key-projection or uniformly value-projection modules.
    """
    assert all([m.initted.item() for m in modules]), 'all modules must be initialized and ideally trained'
    assert len(set([m.concept_outputs.shape[-1] for m in modules])) == 1, 'concept output dimension must be the same'
    assert len(set([m.is_key_proj for m in modules])) == 1, 'all modules must be either for keys, or values. you cannot merge rank 1 edit modules of keys and values together'

    # use the first module as the template for all shared (non-concept-specific) state
    merged_module = deepcopy(modules[0])

    merged_module.num_concepts = sum([m.num_concepts for m in modules])

    # concatenate learned concept outputs along the concept dimension, in merge order,
    # so concept id i in the merged module maps to the i-th concept across the inputs
    # NOTE(review): only concept_outputs is merged here; per-concept EMA input encodings
    # (ema_concept_text_encs) from modules[1:] are not carried over - confirm intended
    concept_outputs = torch.cat(tuple(m.concept_outputs.data for m in modules), dim = 0)
    merged_module.concept_outputs = nn.Parameter(concept_outputs)

    return merged_module

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.22',
6+
version = '0.0.23',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)