
Commit b5fab92

committed: "almost there"

1 parent c6c04fb

File tree

4 files changed: +28 −3 lines


README.md

Lines changed: 2 additions & 1 deletion
@@ -75,10 +75,11 @@ values = wrapped_to_values(text_enc)
 
 ## Todo
 
-- [ ] offer a magic function that automatically tries to wire up the cross attention by looking for appropriately named `nn.Linear` and auto-inferring which ones are keys or values
 - [ ] show example in readme for inference with multiple concepts
 - [ ] review multiple concepts
+- [ ] automatically infer where keys and values projection are if not specified for the `make_key_value_proj_rank1_edit_modules_` function
 
+- [x] offer a function that wires up the cross attention
 - [x] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
 - [x] accept multiple concept indices
 - [x] offer a way to combine separately learned concepts from multiple `Rank1EditModule` into one for inference

perfusion_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,8 @@
     Rank1EditModule,
     calculate_input_covariance,
     loss_fn_weighted_by_mask,
-    merge_rank1_edit_modules
+    merge_rank1_edit_modules,
+    make_key_value_proj_rank1_edit_modules_
 )
 
 from perfusion_pytorch.embedding import (
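With this export added, the new helper should be importable straight from the package root alongside the existing utilities (a small sketch of the assumed import path):

```python
# importable from the package root, per the __init__.py change above
from perfusion_pytorch import (
    Rank1EditModule,
    calculate_input_covariance,
    make_key_value_proj_rank1_edit_modules_
)
```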

perfusion_pytorch/perfusion.py

Lines changed: 23 additions & 0 deletions
@@ -388,3 +388,26 @@ def merge_rank1_edit_modules(
     merged_module.register_buffer('initted', torch.ones(total_concepts, 1).bool())
 
     return merged_module
+
+# function for wiring up the cross attention
+
+@beartype
+def make_key_value_proj_rank1_edit_modules_(
+    cross_attention: nn.Module,
+    *,
+    input_covariance: Tensor,
+    key_proj_name: str,
+    value_proj_name: str,
+    **rank1_edit_module_kwargs
+):
+    linear_key = getattr(cross_attention, key_proj_name, None)
+    linear_values = getattr(cross_attention, value_proj_name, None)
+
+    assert isinstance(linear_key, nn.Linear), f'{key_proj_name} must point to the keys projection (ex. self.to_keys = nn.Linear(in, out, bias = False) -> key_proj_name = "to_keys")'
+    assert isinstance(linear_values, nn.Linear), f'{value_proj_name} must point to the values projection (ex. self.to_values = nn.Linear(in, out, bias = False) -> value_proj_name = "to_values")'
+
+    rank1_edit_module_keys = Rank1EditModule(linear_key, input_covariance = input_covariance, is_key_proj = True, **rank1_edit_module_kwargs)
+    rank1_edit_module_values = Rank1EditModule(linear_values, input_covariance = input_covariance, is_key_proj = False, **rank1_edit_module_kwargs)
+
+    setattr(cross_attention, key_proj_name, rank1_edit_module_keys)
+    setattr(cross_attention, value_proj_name, rank1_edit_module_values)
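The new function swaps a cross attention module's key and value `nn.Linear` projections for `Rank1EditModule` wrappers in place. Below is a minimal usage sketch, not a definitive example: the `CrossAttention` class and its dimensions are hypothetical stand-ins, `input_covariance` is an identity placeholder where `calculate_input_covariance` would normally be used, and depending on the library version `Rank1EditModule` may require further kwargs (forwarded via `rank1_edit_module_kwargs`).

```python
import torch
from torch import nn

from perfusion_pytorch import make_key_value_proj_rank1_edit_modules_

# hypothetical cross attention module - any nn.Module holding its key and
# value projections as named nn.Linear attributes should work

class CrossAttention(nn.Module):
    def __init__(self, dim = 512, dim_context = 768):
        super().__init__()
        self.to_keys = nn.Linear(dim_context, dim, bias = False)
        self.to_values = nn.Linear(dim_context, dim, bias = False)

cross_attn = CrossAttention()

# identity placeholder - in practice, compute from text encodings
# with calculate_input_covariance
input_covariance = torch.eye(768)

make_key_value_proj_rank1_edit_modules_(
    cross_attn,
    input_covariance = input_covariance,
    key_proj_name = 'to_keys',
    value_proj_name = 'to_values'
    # any further Rank1EditModule kwargs are forwarded through here
)

# both projections are now Rank1EditModule wrappers, edited in place
print(type(cross_attn.to_keys).__name__)   # Rank1EditModule
print(type(cross_attn.to_values).__name__) # Rank1EditModule
```

The trailing underscore in the function name follows the PyTorch convention for in-place operations: the module passed in is mutated via `setattr` rather than a new module being returned.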

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.28',
+  version = '0.0.29',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
