
Commit 4028b4e

return the concept indices from the embedding wrapper as well
1 parent 03059d9 commit 4028b4e

3 files changed (+21, −6 lines)


README.md

Lines changed: 2 additions & 2 deletions
@@ -92,11 +92,11 @@ wrapped_embed = EmbeddingWrapper(
 
 # now just pass in your prompts with the superclass id
 
-embeds_with_new_concept, embeds_with_superclass, embed_mask = wrapped_embed([
+embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = wrapped_embed([
     'a portrait of dog',
     'dog running through a green field',
     'a man walking his dog'
-]) # (3, 77, 512), (3, 77, 512), (3, 77)
+]) # (3, 77, 512), (3, 77, 512), (3, 77), (3,)
 
 # now pass both embeds through clip text transformer
 # the embed_mask needs to be passed to the cross attention as key padding mask
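For readers following along, the new fourth return value carries one integer per prompt: the token position where the superclass/concept token ('dog' above) lands after tokenization, which the comment in the embedding.py diff below ties to the rank-1-edit modules. A rough sketch of inspecting it, reusing the README's `wrapped_embed`; the printed index values are illustrative and depend on the tokenizer, they are not taken from the repository:

```python
# continuing from the README snippet above; illustrative values only
embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = wrapped_embed([
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
])

print(concept_indices.shape)  # torch.Size([3]) - one position per prompt
print(concept_indices)        # e.g. tensor([4, 1, 5]) - where 'dog' sits in each tokenized prompt
```

Per the embedding.py changes below, `concept_indices` is only populated when the wrapper is in training mode with a single concept id; otherwise it comes back as None.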

perfusion_pytorch/embedding.py

Lines changed: 18 additions & 3 deletions
@@ -17,7 +17,8 @@
 EmbeddingReturn = namedtuple('EmbeddingReturn', [
     'embed_with_concept',
     'embed_with_superclass',
-    'embed_mask'
+    'embed_mask',
+    'concept_indices'
 ])
 
 # helper functions

@@ -120,6 +121,11 @@ def forward(
         return_embed_with_superclass = True
     ) -> EmbeddingReturn:
 
+        assert not (self.training and self.num_concepts > 1), 'cannot train with multiple concepts'
+
+        if self.training:
+            concept_id = default(concept_id, 0)
+
         if exists(concept_id):
             if not isinstance(concept_id, tuple):
                 concept_id = (concept_id,)

@@ -136,6 +142,7 @@ def forward(
         assert superclass_mask.any(dim = -1).all(), 'superclass embed id must be present for all prompts'
 
         # automatically replace the superclass id with the concept id
+
         x = torch.where(superclass_mask, inferred_concept_id, x)
 
         # get the embedding mask, defined as not padding id

@@ -177,6 +184,14 @@ def forward(
             embeds
         )
 
+        # whether to return concept indices for the rank-1-edit modules
+
+        concept_indices = None
+
+        if self.training and exists(concept_id) and len(concept_id) == 1:
+            concept_mask, = concept_masks
+            concept_indices = (concept_mask.cumsum(dim = -1) == 0).sum(dim = -1).long()
+
         # if training, and superclass embed id given
        # also return embeddings with superclass, for deriving superclass_text_enc

@@ -186,9 +201,9 @@ def forward(
             with torch.no_grad():
                 superclass_embeds = self.embed(x)
 
-            return EmbeddingReturn(embeds, superclass_embeds, embed_mask)
+            return EmbeddingReturn(embeds, superclass_embeds, embed_mask, concept_indices)
 
-        return EmbeddingReturn(embeds, None, embed_mask)
+        return EmbeddingReturn(embeds, None, embed_mask, concept_indices)
 
 @beartype
 def merge_embedding_wrappers(
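The one-liner that derives the indices can look opaque: positions before the first True in the concept mask have a cumulative sum of zero, so counting them per row yields the index of the concept token in each prompt. A standalone sketch with a toy mask (not repository code) shows the behaviour:

```python
import torch

# toy boolean mask marking where the concept token appears in each tokenized prompt
concept_mask = torch.tensor([
    [False, False, True,  False],
    [True,  False, False, False],
    [False, False, False, True ]
])

# positions before the first True have a cumulative sum of 0,
# so summing them counts how many tokens precede the concept token
concept_indices = (concept_mask.cumsum(dim = -1) == 0).sum(dim = -1).long()

print(concept_indices)  # tensor([2, 0, 3])
```

By construction this picks out only the first occurrence, should the concept token appear more than once in a prompt.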

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.10',
+  version = '0.1.11',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
