Commit 97a2b5e

expose momentum into precond, trust region
1 parent f6a911f commit 97a2b5e

File tree

2 files changed: +12 −11 lines changed


kron_torch/kron.py

Lines changed: 11 additions & 10 deletions
@@ -56,11 +56,10 @@ class Kron(torch.optim.Optimizer):
             to set all preconditioners to be triangular, 'one_diag' sets the largest
             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
             to be diagonal.
+        momentum_into_precond_update: (bool), whether to send momentum into preconditioner
+            update instead of raw gradients.
         mu_dtype (torch.dtype, optional): Dtype of the momentum accumulator.
         precond_dtype (torch.dtype, optional): Dtype of the preconditioner.
-        trust_region_scale (float): Trust region on preconditioned grads. Normally this
-            doesn't need to be changed but if things seem unstable you can try reducing
-            this to 1.5.
         """
 
     def __init__(
@@ -73,9 +72,9 @@ def __init__(
         max_size_triangular=8192,
         min_ndim_triangular=2,
         memory_save_mode=None,
+        momentum_into_precond_update=True,
         mu_dtype=None,
         precond_dtype=None,
-        trust_region_scale=1.5,
     ):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -95,11 +94,11 @@ def __init__(
             max_size_triangular=max_size_triangular,
             min_ndim_triangular=min_ndim_triangular,
             memory_save_mode=memory_save_mode,
+            momentum_into_precond_update=momentum_into_precond_update,
             precond_lr=0.1,  # precond lr hardcoded to 0.1
             precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
-            trust_region_scale=trust_region_scale,
         )
         super(Kron, self).__init__(params, defaults)
 
@@ -129,8 +128,11 @@ def step(self, closure=None):
         balance = self.rng.random() < 0.01 and do_update
 
         for group in self.param_groups:
-            precond_dtype = group.get("precond_dtype", torch.float32)
             mu_dtype = group.get("mu_dtype")
+            precond_dtype = group.get("precond_dtype", torch.float32)
+            momentum_into_precond_update = group.get(
+                "momentum_into_precond_update", True
+            )
 
             for p in group["params"]:
                 if p.grad is None:
@@ -197,7 +199,7 @@ def step(self, closure=None):
                         state["Q"],
                         state["exprs"],
                         torch.randn_like(debiased_momentum, dtype=precond_dtype),
-                        debiased_momentum,
+                        debiased_momentum if momentum_into_precond_update else grad,
                         group["precond_lr"],
                         self._tiny,
                     )
@@ -210,9 +212,8 @@ def step(self, closure=None):
                     trust_region_fn = lambda x: 0.1 * torch.sign(x) * torch.log(
                         torch.abs(x) + 1
                     ) + 0.9 * torch.tanh(x)
-                    pre_grad = (
-                        trust_region_fn(pre_grad / group["trust_region_scale"])
-                        * group["trust_region_scale"]
+                    pre_grad = torch.clip(
+                        trust_region_fn(pre_grad / 1.5) * 1.5, min=-2, max=2
                     )
 
                     # Apply weight decay and update parameters
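
For context, a minimal usage sketch of what this change means for callers, assuming the package exposes Kron at the top level (from kron_torch import Kron) and that the other constructor defaults are unchanged; the toy model and loss are illustrative, not part of this commit:

import torch
from kron_torch import Kron

model = torch.nn.Linear(16, 4)  # toy model; any nn.Module works

# The new momentum_into_precond_update flag (default True) routes the
# momentum buffer into the preconditioner update; passing False feeds the
# raw gradients instead. trust_region_scale is no longer a constructor
# argument: the scale is now fixed at 1.5 inside step().
opt = Kron(model.parameters(), lr=0.001, momentum_into_precond_update=False)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()

And a standalone copy of the new trust-region clamp from the last kron.py hunk, handy for checking its shape numerically (the helper name trust_region is mine, not from the source):

import torch

def trust_region(x: torch.Tensor) -> torch.Tensor:
    # Same math as the diff: soft sign-log/tanh squash at the fixed
    # scale of 1.5, then a hard clip of the result to [-2, 2].
    tr = lambda t: 0.1 * torch.sign(t) * torch.log(torch.abs(t) + 1) + 0.9 * torch.tanh(t)
    return torch.clip(tr(x / 1.5) * 1.5, min=-2, max=2)

# Near-identity for small entries, saturating toward +/-2 for large ones:
for v in [0.1, 1.0, 10.0, 1000.0]:
    print(v, trust_region(torch.tensor(v)).item())
# -> approx 0.10, 0.86, 1.66, 2.00

The hard clip caps every entry of the preconditioned update at magnitude 2 no matter how extreme the input, which is presumably what allowed the user-facing trust_region_scale knob to be dropped.
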

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
 
 [project]
 name = "kron-torch"
-version = "0.2.4"
+version = "0.2.5"
 description = "An implementation of PSGD Kron optimizer in PyTorch."
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
