@@ -17,7 +17,8 @@
 EmbeddingReturn = namedtuple('EmbeddingReturn', [
     'embed_with_concept',
     'embed_with_superclass',
-    'embed_mask'
+    'embed_mask',
+    'concept_indices'
 ])
 
 # helper functions
@@ -120,6 +121,11 @@ def forward(
         return_embed_with_superclass = True
     ) -> EmbeddingReturn:
 
+        assert not (self.training and self.num_concepts > 1), 'cannot train with multiple concepts'
+
+        if self.training:
+            concept_id = default(concept_id, 0)
+
         if exists(concept_id):
             if not isinstance(concept_id, tuple):
                 concept_id = (concept_id,)
@@ -136,6 +142,7 @@ def forward(
             assert superclass_mask.any(dim = -1).all(), 'superclass embed id must be present for all prompts'
 
             # automatically replace the superclass id with the concept id
+
             x = torch.where(superclass_mask, inferred_concept_id, x)
 
         # get the embedding mask, defined as not padding id
@@ -177,6 +184,14 @@ def forward(
             embeds
         )
 
+        # whether to return concept indices for the rank-1-edit modules
+
+        concept_indices = None
+
+        if self.training and exists(concept_id) and len(concept_id) == 1:
+            concept_mask, = concept_masks
+            concept_indices = (concept_mask.cumsum(dim = -1) == 0).sum(dim = -1).long()
+
         # if training, and superclass embed id given
         # also return embeddings with superclass, for deriving superclass_text_enc
 
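As an aside (not part of the commit): the cumsum trick in the added block above returns, for each prompt, the position of the first True entry in `concept_mask`, i.e. where the concept token sits. A minimal sketch with made-up data:

    import torch

    # one row per prompt; True marks the (single) concept token position
    concept_mask = torch.tensor([
        [False, False, True, False],   # concept token at position 2
        [True, False, False, False],   # concept token at position 0
    ])

    # counting the leading False entries per row gives the index of the first True
    concept_indices = (concept_mask.cumsum(dim = -1) == 0).sum(dim = -1).long()
    print(concept_indices)   # tensor([2, 0])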
@@ -186,9 +201,9 @@ def forward(
             with torch.no_grad():
                 superclass_embeds = self.embed(x)
 
-            return EmbeddingReturn(embeds, superclass_embeds, embed_mask)
+            return EmbeddingReturn(embeds, superclass_embeds, embed_mask, concept_indices)
 
-        return EmbeddingReturn(embeds, None, embed_mask)
+        return EmbeddingReturn(embeds, None, embed_mask, concept_indices)
 
 @beartype
 def merge_embedding_wrappers(
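A minimal usage sketch of the expanded return (illustrative only; `wrapper` and `token_ids` are assumed names, not from this diff): during single-concept training the wrapper now also hands back `concept_indices`, which downstream rank-1-edit modules can use to locate the concept token in each prompt.

    wrapper.train()

    # unpack the four-field EmbeddingReturn
    embeds, superclass_embeds, embed_mask, concept_indices = wrapper(
        token_ids,                            # (batch, seq) token ids containing the superclass token
        concept_id = 0,                       # defaults to 0 during training when omitted
        return_embed_with_superclass = True   # also return embeddings computed with the superclass id
    )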