Skip to content

Commit 7141779

Browse files
Remove trust region, make preconditioner updates deterministic, add normalize_grads arg
1 parent 97a2b5e commit 7141779

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

kron_torch/kron.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
def precond_update_prob_schedule(
17-
max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=250
17+
max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500
1818
):
1919
"""Anneal preconditioner update probability during beginning of training.
2020
@@ -45,6 +45,7 @@ class Kron(torch.optim.Optimizer):
4545
parameter groups.
4646
lr (float): Learning rate.
4747
b1 (float): Momentum parameter.
48+
normalize_grads (bool): Whether to normalize incoming gradients layer-wise.
4849
weight_decay (float): Weight decay (L2 penalty).
4950
preconditioner_update_probability (callable or float, optional): Probability of
5051
updating the preconditioner. If None, defaults to a schedule that anneals
@@ -67,6 +68,7 @@ def __init__(
6768
params,
6869
lr=0.001,
6970
b1=0.9,
71+
normalize_grads=False,
7072
weight_decay=0.0,
7173
preconditioner_update_probability=None,
7274
max_size_triangular=8192,
@@ -89,6 +91,7 @@ def __init__(
8991
defaults = dict(
9092
lr=lr,
9193
b1=b1,
94+
normalize_grads=normalize_grads,
9295
weight_decay=weight_decay,
9396
preconditioner_update_probability=preconditioner_update_probability,
9497
max_size_triangular=max_size_triangular,
@@ -104,6 +107,7 @@ def __init__(
104107

105108
self._tiny = torch.finfo(torch.bfloat16).tiny
106109
self._prob_step = 0
110+
self._update_counter = 0
107111
self.rng = random.Random(5318008)
108112

109113
@torch.no_grad()
@@ -118,13 +122,17 @@ def step(self, closure=None):
118122
total_precond_size = 0
119123
total_precond_mb = 0
120124

121-
# update preconditioners all together
125+
# update preconditioners all together deterministically
122126
update_prob = self.param_groups[0]["preconditioner_update_probability"]
123127
if callable(update_prob):
124128
update_prob = update_prob(self._prob_step)
125-
do_update = self.rng.random() < update_prob
129+
self._update_counter += 1
130+
do_update = self._update_counter >= 1 / update_prob
131+
if do_update:
132+
self._update_counter = 0
126133
self._prob_step += 1
127134

135+
# balance preconditioners roughly every 100 updates
128136
balance = self.rng.random() < 0.01 and do_update
129137

130138
for group in self.param_groups:
@@ -174,6 +182,9 @@ def step(self, closure=None):
174182

175183
state["step"] += 1
176184

185+
if group["normalize_grads"]:
186+
grad /= torch.norm(grad) + 1e-12
187+
177188
# Update momentum buffer
178189
beta = group["b1"]
179190
bias_correction = 1 - beta ** state["step"]
@@ -209,13 +220,6 @@ def step(self, closure=None):
209220
state["Q"], state["exprs"], debiased_momentum
210221
).to(dtype=p.dtype, non_blocking=True)
211222

212-
trust_region_fn = lambda x: 0.1 * torch.sign(x) * torch.log(
213-
torch.abs(x) + 1
214-
) + 0.9 * torch.tanh(x)
215-
pre_grad = torch.clip(
216-
trust_region_fn(pre_grad / 1.5) * 1.5, min=-2, max=2
217-
)
218-
219223
# Apply weight decay and update parameters
220224
if group["weight_decay"] != 0 and p.dim() >= 2:
221225
pre_grad.add_(p, alpha=group["weight_decay"])

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "kron-torch"
7-
version = "0.2.5"
7+
version = "0.2.6"
88
description = "An implementation of PSGD Kron optimizer in PyTorch."
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
license = { file = "LICENSE" }
@@ -34,7 +34,7 @@ classifiers = [
3434
dependencies = [
3535
"torch",
3636
"opt_einsum",
37-
"triton==3.0.0"
37+
"triton"
3838
]
3939

4040
[project.urls]

0 commit comments

Comments
 (0)