Skip to content

Commit 794669d

Browse files
committed
superclass does not need ema
1 parent a88ab93 commit 794669d

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888

8989
self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
9090
self.register_buffer('ema_concept_text_encs', torch.zeros(num_finetune_prompts, dim_input))
91-
self.register_buffer('ema_superclass_text_encs', torch.zeros(num_finetune_prompts, dim_input))
91+
self.register_buffer('superclass_text_encs', torch.zeros(num_finetune_prompts, dim_input))
9292
self.register_buffer('superclass_outputs', torch.zeros(num_finetune_prompts, dim_output))
9393

9494
# C in the paper, inverse precomputed
@@ -154,7 +154,12 @@ def forward(
154154
all_initted = initted.all()
155155

156156
ema_concept_text_enc = self.ema_concept_text_encs[prompt_ids]
157-
ema_superclass_text_enc = self.ema_superclass_text_encs[prompt_ids]
157+
158+
# fetch superclass
159+
160+
assert exists(superclass_text_enc) or all_initted
161+
162+
stored_superclass_text_enc = self.superclass_text_encs[prompt_ids]
158163

159164
# for keys, the superclass output (o*) is stored on init
160165
# and never optimized
@@ -170,9 +175,9 @@ def forward(
170175
concept_text_enc
171176
)
172177

173-
ema_superclass_text_enc = torch.where(
178+
superclass_text_enc = torch.where(
174179
initted,
175-
ema_superclass_text_enc,
180+
stored_superclass_text_enc,
176181
superclass_text_enc
177182
)
178183

@@ -182,17 +187,16 @@ def forward(
182187
superclass_output
183188
)
184189

185-
# exponential moving average of both concept and superclass
190+
# exponential moving average for concept input encoding
186191

187192
concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
188-
superclass_text_enc = ema_superclass_text_enc * decay + superclass_text_enc * (1. - decay)
189193

190194
# store
191195

192196
if not all_initted:
193197
self.initted[prompt_ids] = True
194198
self.ema_concept_text_encs[prompt_ids] = ema_concept_text_enc
195-
self.ema_superclass_text_encs[prompt_ids] = ema_superclass_text_enc
199+
self.superclass_text_encs[prompt_ids] = superclass_text_enc
196200
self.superclass_outputs[prompt_ids] = superclass_output
197201

198202
# take care of the output

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.10',
6+
version = '0.0.11',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)