Skip to content

Commit c9adf6d

Browse files
committed
prepare two more functions that make life easy
1 parent 3d0ba4a commit c9adf6d

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ def exists(val):
2525
@beartype
2626
@torch.no_grad()
2727
def calculate_input_covariance(
28-
open_clip: OpenClipAdapter,
28+
clip: OpenClipAdapter,
2929
texts: List[str],
3030
batch_size = 32,
3131
**cov_kwargs
3232
):
33-
embeds, mask = open_clip.embed_texts(texts)
33+
embeds, mask = clip.embed_texts(texts)
3434

3535
num_batches = ceil(len(texts) / batch_size)
3636

@@ -42,13 +42,45 @@ def calculate_input_covariance(
4242
start_index = batch_ind * batch_size
4343
batch_texts = texts[start_index:(start_index + batch_size)]
4444

45-
embeds, mask = open_clip.embed_texts(batch_texts)
45+
embeds, mask = clip.embed_texts(batch_texts)
4646
all_embeds.append(embeds[mask])
4747

48-
all_embeds = torch.cat((all_embeds), dim = 0)
48+
all_embeds = torch.cat(all_embeds, dim = 0)
4949

5050
return einsum('n d, n e -> d e', all_embeds, all_embeds) / length
5151

52+
@beartype
53+
def find_first_index(
54+
indices: IndicesTensor,
55+
concept_or_superclass_id: int
56+
):
57+
"""
58+
for deriving the concept_indices to be passed into the Rank1EditModule
59+
"""
60+
61+
edge = (indices == concept_or_superclass_id).cumsum(dim = -1) # [1, 3, 5, 4, 1, 1], 4 -> [0, 0, 0, 1, 0, 0, 0] -> [0, 0, 0, 1, 1, 1]
62+
return edge.sum(dim = -1)
63+
64+
@beartype
65+
def return_text_enc_with_concept_and_superclass(
66+
text_ids: IndicesTensor,
67+
concept_id: int,
68+
superclass_id: int,
69+
clip: Optional[OpenClipAdapter] = None
70+
):
71+
batch = text_ids.shape[0]
72+
batch_arange = torch.arange(batch, device = text_ids.device)
73+
concept_indices = find_first_index(text_ids, concept_id)
74+
text_ids_with_superclass = text_ids[batch_arange, concept_indices] = superclass_ids
75+
76+
if not exists(clip):
77+
return text_ids, concept_indices, text_ids_with_superclass
78+
79+
concept_text_enc = clip.embed_texts(text_ids)
80+
superclass_text_enc = clip.embed_texts(text_ids_with_superclass)
81+
82+
return concept_text_enc, concept_indices, superclass_text_enc
83+
5284
# a module that wraps the keys and values projection of the cross attentions to text encodings
5385

5486
class Rank1EditModule(Module):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.18',
6+
version = '0.0.19',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)