@@ -97,7 +97,7 @@ def __init__(
 
         self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
 
-        self.superclass_outputs = nn.Parameter(torch.zeros(dim_output), requires_grad = not is_key_proj)
+        self.superclass_output = nn.Parameter(torch.zeros(dim_output), requires_grad = not is_key_proj)
 
         # C in the paper, inverse precomputed
 
@@ -113,10 +113,8 @@ def parameters(self):
     def forward(
         self,
         text_enc: FloatTensor,
-        text_enc_with_superclass: FloatTensor,
         concept_indices: IndicesTensor,
-        *,
-        prompt_ids: Optional[IndicesTensor] = None
+        text_enc_with_superclass: Optional[FloatTensor] = None
     ):
         assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'
 
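After this change, text_enc_with_superclass is only needed on the first training pass (to initialize the module's buffers), and prompt_ids is gone entirely. A minimal usage sketch, assuming an existing instance rank1_edit of the edited module; all names below are illustrative, not taken from this diff:

    # first training step: superclass encoding is required to initialize the buffers
    values = rank1_edit(
        text_enc,                                       # encoding of the concept prompt
        concept_indices,                                # position of the concept token per prompt
        text_enc_with_superclass = superclass_text_enc  # needed only until initted flips to True
    )

    # subsequent steps: the stored superclass output is read from the buffer
    values = rank1_edit(text_enc, concept_indices)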
@@ -161,31 +159,26 @@ def forward(
 
         superclass_output = einsum('i, o i -> o', superclass_text_enc, weights)
 
-        if self.training and exists(prompt_ids):
-            # get the initialization state
-            # as well as the exponentially smoothed text encodings
-
-            initted = self.initted.item()
+        # get the initialization state
 
-            ema_concept_text_enc = self.ema_concept_text_encs[prompt_ids]
+        initted = self.initted.item()
 
+        if self.training:
             # store the superclass i* if not all initialized
             # else fetch it from the buffer
 
             if not initted:
-                assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first epoch for all prompts to initialize the module correctly'
-
-                non_initted_prompt_ids = prompt_ids[~initted]
+                assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first batch'
 
                 # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
-                self.superclass_outputs.data.copy_(superclass_output)
-
-            superclass_output = self.superclass_outputs
+                self.superclass_output.data.copy_(superclass_output)
 
             # if any in the batch is not initialized, initialize
 
             if not initted:
                 ema_concept_text_enc = concept_text_enc
+            else:
+                ema_concept_text_enc = self.ema_concept_text_enc
 
         # exponential moving average for concept input encoding
 
@@ -196,16 +189,12 @@ def forward(
             if not initted:
                 self.initted.data.copy_(Tensor([True]))
                 self.ema_concept_text_encs.data.copy_(ema_concept_text_enc)
-
-        # take care of the output
-        # for the keys, make sure to turn off gradients as it is 'locked'
-
-        if self.is_key_proj:
-            superclass_output = superclass_output.detach()
+        else:
+            assert initted, 'you have not initialized or trained this module yet'
 
         # make it easier to match with paper
 
-        i, o, W = concept_text_enc, superclass_output, weights
+        i, o, W = concept_text_enc, self.superclass_output, weights
 
         # main contribution eq (3)
 
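For orientation: with i, o, W unpacked as above and C^{-1} precomputed on the module (see the "C in the paper, inverse precomputed" comment in the first hunk), eq (3) builds on the standard rank-one editing update. Its ungated core, applied to an input encoding x, reads as follows (a paraphrase under those assumptions, not copied from the source; the paper additionally passes the similarity ratio through a sigmoid gate):

    W' x = W x + (i^T C^{-1} x) / (i^T C^{-1} i) * (o - W i)

Here i is the (EMA-smoothed) concept input encoding, o is the locked superclass output, and C is the input covariance, so the correction (o - W i) is blended in only when x aligns with the concept direction under the C^{-1} metric.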