
Commit b1174e2

cleanup for yet another round of review
1 parent bd50036 commit b1174e2

2 files changed: +13 -24 lines changed

perfusion_pytorch/perfusion.py

Lines changed: 12 additions & 23 deletions
@@ -97,7 +97,7 @@ def __init__(
 
         self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
 
-        self.superclass_outputs = nn.Parameter(torch.zeros(dim_output), requires_grad = not is_key_proj)
+        self.superclass_output = nn.Parameter(torch.zeros(dim_output), requires_grad = not is_key_proj)
 
         # C in the paper, inverse precomputed

@@ -113,10 +113,8 @@ def parameters(self):
     def forward(
         self,
         text_enc: FloatTensor,
-        text_enc_with_superclass: FloatTensor,
         concept_indices: IndicesTensor,
-        *,
-        prompt_ids: Optional[IndicesTensor] = None
+        text_enc_with_superclass: Optional[FloatTensor] = None
     ):
         assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'

@@ -161,31 +159,26 @@ def forward(
 
         superclass_output = einsum('i, o i -> o', superclass_text_enc, weights)
 
-        if self.training and exists(prompt_ids):
-            # get the initialization state
-            # as well as the exponentially smoothed text encodings
-
-            initted = self.initted.item()
+        # get the initialization state
 
-            ema_concept_text_enc = self.ema_concept_text_encs[prompt_ids]
+        initted = self.initted.item()
 
+        if self.training:
             # store the superclass i* if not all initialized
             # else fetch it from the buffer
 
             if not initted:
-                assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first epoch for all prompts to initialize the module correctly'
-
-                non_initted_prompt_ids = prompt_ids[~initted]
+                assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first batch'
 
                 # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
-                self.superclass_outputs.data.copy_(superclass_output)
-
-                superclass_output = self.superclass_outputs
+                self.superclass_output.data.copy_(superclass_output)
 
             # if any in the batch is not initialized, initialize
 
             if not initted:
                 ema_concept_text_enc = concept_text_enc
+            else:
+                ema_concept_text_enc = self.ema_concept_text_enc
 
             # exponential moving average for concept input encoding

@@ -196,16 +189,12 @@ def forward(
             if not initted:
                 self.initted.data.copy_(Tensor([True]))
                 self.ema_concept_text_encs.data.copy_(ema_concept_text_enc)
-
-            # take care of the output
-            # for the keys, make sure to turn off gradients as it is 'locked'
-
-            if self.is_key_proj:
-                superclass_output = superclass_output.detach()
+        else:
+            assert initted, 'you have not initialized or trained this module yet'
 
         # make it easier to match with paper
 
-        i, o, W = concept_text_enc, superclass_output, weights
+        i, o, W = concept_text_enc, self.superclass_output, weights
 
         # main contribution eq (3)

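Taken together, the hunks drop prompt_ids and the per-prompt buffers, keep a single superclass_output parameter, and demote text_enc_with_superclass to an optional trailing argument that is only needed until the module has initialized. Below is a minimal usage sketch of the revised forward, assuming the surrounding class is perfusion.py's Rank1EditModule; the constructor arguments and tensor shapes shown are illustrative assumptions, not the exact v0.0.16 API.

# a sketch, not the definitive API: constructor arguments and shapes are assumed
import torch
from torch import nn
from perfusion_pytorch.perfusion import Rank1EditModule

to_keys = nn.Linear(768, 320, bias = False)        # cross-attention key projection to wrap

wrapped_to_keys = Rank1EditModule(                 # hypothetical constructor call
    to_keys,
    is_key_proj = True                             # 'locked' keys: superclass output frozen
)

text_enc = torch.randn(1, 77, 768)                 # CLIP text encoding of the concept prompt
text_enc_superclass = torch.randn(1, 77, 768)      # same prompt with the superclass word
concept_indices = torch.tensor([5])                # position of the concept token in the prompt

# first training batch: the superclass encoding is required so the module can
# hard-copy the initial superclass output and seed the concept EMA
wrapped_to_keys.train()
keys = wrapped_to_keys(text_enc, concept_indices, text_enc_superclass)

# after initialization it may be omitted, per the new Optional default
keys = wrapped_to_keys(text_enc, concept_indices)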
setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.15',
+  version = '0.0.16',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
