@@ -88,7 +88,7 @@ def __init__(
8888
8989 self .register_buffer ('initted' , torch .zeros (num_finetune_prompts ).bool ())
9090 self .register_buffer ('ema_concept_text_encs' , torch .zeros (num_finetune_prompts , dim_input ))
91- self .register_buffer ('ema_superclass_text_encs ' , torch .zeros (num_finetune_prompts , dim_input ))
91+ self .register_buffer ('superclass_text_encs ' , torch .zeros (num_finetune_prompts , dim_input ))
9292 self .register_buffer ('superclass_outputs' , torch .zeros (num_finetune_prompts , dim_output ))
9393
9494 # C in the paper, inverse precomputed
@@ -154,7 +154,12 @@ def forward(
154154 all_initted = initted .all ()
155155
156156 ema_concept_text_enc = self .ema_concept_text_encs [prompt_ids ]
157- ema_superclass_text_enc = self .ema_superclass_text_encs [prompt_ids ]
157+
158+ # fetch superclass
159+
160+ assert exists (superclass_text_enc ) or all_initted
161+
162+ stored_superclass_text_enc = self .superclass_text_encs [prompt_ids ]
158163
159164 # for keys, the superclass output (o*) is stored on init
160165 # and never optimized
@@ -170,9 +175,9 @@ def forward(
170175 concept_text_enc
171176 )
172177
173- ema_superclass_text_enc = torch .where (
178+ superclass_text_enc = torch .where (
174179 initted ,
175- ema_superclass_text_enc ,
180+ stored_superclass_text_enc ,
176181 superclass_text_enc
177182 )
178183
@@ -182,17 +187,16 @@ def forward(
182187 superclass_output
183188 )
184189
185- # exponential moving average of both concept and superclass
190+ # exponential moving average for concept input encoding
186191
187192 concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay )
188- superclass_text_enc = ema_superclass_text_enc * decay + superclass_text_enc * (1. - decay )
189193
190194 # store
191195
192196 if not all_initted :
193197 self .initted [prompt_ids ] = True
194198 self .ema_concept_text_encs [prompt_ids ] = ema_concept_text_enc
195- self .ema_superclass_text_encs [prompt_ids ] = ema_superclass_text_enc
199+ self .superclass_text_encs [prompt_ids ] = superclass_text_enc
196200 self .superclass_outputs [prompt_ids ] = superclass_output
197201
198202 # take care of the output
0 commit comments