Skip to content

Commit df6169d

Browse files
committed
more feedback; authors used uncentered covariance matrix, detail in appendix C
1 parent b1174e2 commit df6169d

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def calculate_input_covariance(
3636

3737
all_embeds = []
3838

39+
length = len(texts)
40+
3941
for batch_ind in range(num_batches):
4042
start_index = batch_ind * batch_size
4143
batch_texts = texts[start_index:(start_index + batch_size)]
@@ -44,9 +46,8 @@ def calculate_input_covariance(
4446
all_embeds.append(embeds[mask])
4547

4648
all_embeds = torch.cat((all_embeds), dim = 0)
47-
all_embeds = rearrange(all_embeds, 'n d -> d n')
4849

49-
return torch.cov(all_embeds, correction = 0, **cov_kwargs)
50+
return einsum('n d, n e -> d e', all_embeds, all_embeds) / length
5051

5152
# a module that wraps the keys and values projection of the cross attentions to text encodings
5253

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.16',
6+
version = '0.0.17',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)