
Commit 75bdb98

remove some code in favor of embedding and clip wrappers, also make clip wrapper more flexible
1 parent 89f1e6c commit 75bdb98

File tree

3 files changed: +13 −38 lines


perfusion_pytorch/embedding.py

Lines changed: 12 additions & 5 deletions
@@ -226,21 +226,28 @@ def forward(
 # and on forward, passes the concept embeddings + superclass concept embeddings through the text transformer + final layernorm
 # as well as make the forward pass the ids and superclass_ids through the modified text encoder twice (will attempt to substitute the nn.Embedding with an nn.Identity)
 
-from open_clip import CLIP
-
 class OpenClipEmbedWrapper(Module):
     @beartype
     def __init__(
         self,
-        clip: CLIP,
+        clip: Module,
+        text_transformer_path = 'transformer',
+        ln_final_path = 'ln_final', # in CLIP, they had the final layernorm separate from the transformer
         **embedding_wrapper_kwargs
     ):
         super().__init__()
         self.wrapped_embed = EmbeddingWrapper(clip.token_embedding, **embedding_wrapper_kwargs)
 
+        path_to_modules = dict([(path, mod) for path, mod in clip.named_modules()])
+
+        assert text_transformer_path in path_to_modules
+
+        text_transformer = path_to_modules[text_transformer_path]
+        ln_final = path_to_modules.get(ln_final_path, nn.Identity())
+
         self.text_transformer = nn.Sequential(
-            clip.transformer,
-            clip.ln_final
+            text_transformer,
+            ln_final
         )
 
     def forward(
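Note: with this change, OpenClipEmbedWrapper no longer imports open_clip's CLIP class directly; it accepts any Module exposing a token_embedding, and resolves the text transformer and final layernorm by their named_modules() paths, falling back to nn.Identity() when no final layernorm is found. A minimal usage sketch follows — the model name, pretrained tag, and the omission of any extra EmbeddingWrapper keyword arguments are assumptions for illustration, not part of this commit:

# hypothetical usage sketch -- model choice and defaults are assumptions
import open_clip
from perfusion_pytorch.embedding import OpenClipEmbedWrapper

# any module exposing `token_embedding` plus the named submodules works now, not just open_clip.CLIP
clip, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained = 'laion2b_s34b_b79k')

wrapper = OpenClipEmbedWrapper(
    clip,
    text_transformer_path = 'transformer',  # default matches open_clip's attribute name
    ln_final_path = 'ln_final'              # resolves to nn.Identity() if this path is absent
)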

perfusion_pytorch/perfusion.py

Lines changed: 0 additions & 32 deletions
@@ -69,38 +69,6 @@ def calculate_input_covariance(
 
     return einsum('n d, n e -> d e', all_embeds, all_embeds) / length
 
-@beartype
-def find_first_index(
-    indices: IndicesTensor,
-    concept_or_superclass_id: int
-):
-    """
-    for deriving the concept_indices to be passed into the Rank1EditModule
-    """
-
-    edge = (indices == concept_or_superclass_id).cumsum(dim = -1) # [1, 3, 5, 4, 1, 1], 4 -> [0, 0, 0, 1, 0, 0, 0] -> [0, 0, 0, 1, 1, 1]
-    return edge.sum(dim = -1)
-
-@beartype
-def return_text_enc_with_concept_and_superclass(
-    text_ids: IndicesTensor,
-    concept_id: int,
-    superclass_id: int,
-    clip: Optional[OpenClipAdapter] = None
-):
-    batch = text_ids.shape[0]
-    batch_arange = torch.arange(batch, device = text_ids.device)
-    concept_indices = find_first_index(text_ids, concept_id)
-    text_ids_with_superclass = text_ids[batch_arange, concept_indices] = superclass_ids
-
-    if not exists(clip):
-        return text_ids, concept_indices, text_ids_with_superclass
-
-    concept_text_enc = clip.embed_texts(text_ids)
-    superclass_text_enc = clip.embed_texts(text_ids_with_superclass)
-
-    return concept_text_enc, concept_indices, superclass_text_enc
-
 # loss weighted by the mask
 
 @beartype
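Note: the deleted find_first_index helper documented a cumulative-sum trick for locating the concept token inside a tokenized prompt, which is what the concept_indices passed to Rank1EditModule encode. For reference, a standalone sketch of that idea with hypothetical inputs (not code from this repository): positions before the first match have a running match count of zero, so counting them gives the first occurrence index.

import torch

# hypothetical example -- a batch of token ids containing a concept token (id 4)
token_ids = torch.tensor([[1, 3, 5, 4, 1, 1]])
concept_id = 4

# running count of matches is zero before the first hit; summing those
# positions yields the first occurrence index, tensor([3]) here
concept_indices = (token_ids == concept_id).cumsum(dim = -1).eq(0).sum(dim = -1)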

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.19',
+  version = '0.1.20',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
