Skip to content

Commit 06bd7f8

Browse files
committed
tweak
1 parent 86a4b42 commit 06bd7f8

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,13 @@ def merge_rank1_edit_modules(
303303
assert len(set([m.concept_outputs.shape[-1] for m in modules])) == 1, 'concept output dimension must be the same'
304304
assert len(set([m.is_key_proj for m in modules])) == 1, 'all modules must be either for keys, or values. you cannot merge rank 1 edit modules of keys and values together'
305305

306-
merged_module = deepcopy(modules[0])
306+
first_module = modules[0]
307+
merged_module = deepcopy(first_module)
307308

308309
print(len(modules))
309310
merged_module.num_concepts = sum([m.num_concepts for m in modules])
310311

311312
concept_outputs = torch.cat(tuple(m.concept_outputs.data for m in modules), dim = 0)
312-
merged_module.concept_outputs = nn.Parameter(concept_outputs)
313+
merged_module.concept_outputs = nn.Parameter(concept_outputs, requires_grad = not first_module.is_key_proj)
313314

314315
return merged_module

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.23',
6+
version = '0.0.24',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)