
Commit 89f1e6c

add yet another function for automatically returning a finetune optimizer

1 parent b6e261d

File tree

4 files changed: 49 additions, 14 deletions

perfusion_pytorch/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -14,6 +14,10 @@
 
 from perfusion_pytorch.save_load import (
     save,
-    load,
-    get_finetune_parameters
+    load
+)
+
+from perfusion_pytorch.optimizer import (
+    get_finetune_parameters,
+    get_finetune_optimizer
 )

perfusion_pytorch/optimizer.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+from torch.nn import Module
+from torch.optim import AdamW, Adam, Optimizer
+
+from beartype import beartype
+
+from perfusion_pytorch.embedding import EmbeddingWrapper
+from perfusion_pytorch.perfusion import Rank1EditModule
+
+# function that automatically finds all the parameters necessary for fine tuning
+
+@beartype
+def get_finetune_parameters(text_image_model: Module):
+    params = []
+    for module in text_image_model.modules():
+        if isinstance(module, (EmbeddingWrapper, Rank1EditModule)):
+            params.extend(module.parameters())
+
+    return params
+
+@beartype
+def get_finetune_optimizer(
+    text_image_model: Module,
+    lr = 1e-4,
+    wd = 1e-2,
+    betas = (0.9, 0.99),
+    eps = 1e-8,
+    **kwargs
+) -> Optimizer:
+    params = get_finetune_parameters(text_image_model)
+
+    assert len(params) > 0, 'no finetuneable parameters found'
+    total_params = sum([p.numel() for p in params])
+    print(f'optimizing {total_params} parameters')
+
+    has_weight_decay = wd > 0
+    adam_klass = AdamW if has_weight_decay else Adam
+    adam_kwargs = dict(lr = lr, betas = betas, eps = eps)
+
+    if has_weight_decay:
+        adam_kwargs.update(weight_decay = wd)
+
+    return adam_klass(params, **adam_kwargs, **kwargs)
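
For context, a minimal usage sketch of the new helper. It assumes `model` is a text-to-image network that already contains EmbeddingWrapper and Rank1EditModule submodules; the forward call and loss are illustrative placeholders, not this repo's API:

from perfusion_pytorch import get_finetune_optimizer

# only parameters belonging to EmbeddingWrapper / Rank1EditModule
# submodules are gathered and handed to the optimizer; no other
# parameters in the network receive updates from it
opt = get_finetune_optimizer(
    model,
    lr = 1e-4,
    wd = 1e-2    # wd > 0 selects AdamW; wd == 0 falls back to Adam
)

loss = model(images, texts)  # hypothetical forward returning a scalar loss
loss.backward()

opt.step()
opt.zero_grad()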

perfusion_pytorch/save_load.py

Lines changed: 0 additions & 11 deletions
@@ -14,17 +14,6 @@
 def exists(val):
     return val is not None
 
-# function that automatically finds all the parameters necessary for fine tuning
-
-@beartype
-def get_finetune_parameters(text_image_model: Module):
-    params = []
-    for module in text_image_model.modules():
-        if isinstance(module, (EmbeddingWrapper, Rank1EditModule)):
-            params.extend(module.parameters())
-
-    return params
-
 # saving and loading the necessary extra finetuned params
 
 @beartype

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.17',
+  version = '0.1.19',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
