Commit 76b64cc

lower the effective learning rate of the concept embedding through the module itself
Parent: 86d1896

File tree

2 files changed, +14 -4 lines


perfusion_pytorch/perfusion.py

Lines changed: 13 additions & 3 deletions
@@ -57,14 +57,15 @@ def __init__(
         key_or_values_proj: nn.Linear,
         *,
         num_finetune_prompts: int,
-        C: Tensor,                        # covariance of input, precomputed from 100K laion text
+        C: Tensor,                            # covariance of input, precomputed from 100K laion text
         text_seq_len: int = 77,
         is_key_proj: bool = False,
         input_decay = 0.99,
         train_beta = 0.75,
         train_temperature = 0.1,
-        eval_beta = 0.70,                 # in paper, specified a range (0.6 - 0.75) for local-key lock, and (0.4 - 0.6) for global-key lock
-        eval_temperature = 0.15
+        eval_beta = 0.70,                     # in paper, specified a range (0.6 - 0.75) for local-key lock, and (0.4 - 0.6) for global-key lock
+        eval_temperature = 0.15,
+        frac_gradient_concept_embed = 0.1     # they use a slower learning rate for the embed - this can be achieved by a trick to reduce the gradients going backwards through an operation
     ):
         super().__init__()
         assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
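
The new frac_gradient_concept_embed argument sets what fraction of the gradient is allowed to flow back into the concept embedding. A minimal sketch of the underlying trick, pulled out into a standalone helper (the frac_gradient name and signature here are illustrative, not part of the library):

    import torch
    from torch import Tensor

    def frac_gradient(t: Tensor, frac: float = 0.1) -> Tensor:
        # forward pass is the identity: frac * t + (1 - frac) * t == t
        # backward pass only flows through the first term, so gradients are scaled by frac
        return t * frac + t.detach() * (1. - frac)

Since the optimizer's update on a parameter is proportional to its gradient (exactly for plain SGD, approximately for Adam), scaling gradients by 0.1 behaves like a roughly 10x lower learning rate on that path, without needing a separate optimizer parameter group.
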
@@ -81,6 +82,11 @@ def __init__(
 
         self.text_seq_len = text_seq_len
 
+        # for the lowered learning rate on the concept embed (0.006 vs 0.03 or something)
+
+        assert 0 < frac_gradient_concept_embed <= 1.
+        self.frac_gradient_concept_embed = frac_gradient_concept_embed
+
         # for exponentially smoothing the inputs
         # will smooth both concept and superclass token inputs
 
@@ -129,6 +135,10 @@ def forward(
 
         weights, decay, Ci = self.weight, self.input_decay, self.C_inv
 
+        # reduce learning rate going back to the text encoder and into the concept embed
+
+        text_enc = text_enc * self.frac_gradient_concept_embed + text_enc.detach() * (1 - self.frac_gradient_concept_embed)
+
         # beta and temperature depends on whether training or inference
 
         beta, temperature = (self.train_beta, self.train_temperature) if self.training else (self.eval_beta, self.eval_temperature)
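
A quick sanity check of what the added line does (the shapes below are illustrative; in the module, text_enc comes from the text encoder):

    import torch

    frac = 0.1
    text_enc = torch.randn(2, 77, 512, requires_grad = True)   # (batch, text_seq_len, dim)

    scaled = text_enc * frac + text_enc.detach() * (1 - frac)

    assert torch.allclose(scaled, text_enc)                    # forward values are unchanged

    scaled.sum().backward()
    # every gradient element is frac rather than 1 - a 10x lower effective learning rate
    assert torch.allclose(text_enc.grad, torch.full_like(text_enc.grad, frac))
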

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.12',
+  version = '0.0.14',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
