Commit 032ac2e

Initial kernel changes for 2-state optimizers to support GaLore
Parent: 0548376

File tree: 7 files changed, +102 −261 lines (two of the seven file diffs are shown below).

The changes thread a new optional `return_updates` tensor through the 32-bit, 8-bit, and blockwise 8-bit optimizer update paths: when a buffer is passed, the computed update is written to it instead of being applied to the parameter in place, which is the hook a GaLore-style projected update needs.

bitsandbytes/functional.py

Lines changed: 20 additions & 9 deletions
@@ -1520,9 +1520,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =

 def optimizer_update_32bit(
     optimizer_name: str,
-    g: Tensor,
-    p: Tensor,
-    state1: Tensor,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
     beta1: float,
     eps: float,
     step: int,
@@ -1534,6 +1534,7 @@ def optimizer_update_32bit(
     unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
     skip_zeros=False,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     """
     Performs an inplace optimizer update with one or two optimizer states.
@@ -1572,6 +1573,8 @@ def optimizer_update_32bit(
         The maximum update norm relative to the weight norm.
     skip_zeros : bool
         Whether to skip zero-valued gradients or not (default: False).
+    return_updates: Optional[torch.Tensor]
+        When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
     """

     param_norm = 0.0
@@ -1595,6 +1598,7 @@ def optimizer_update_32bit(
     optim_func(
         get_ptr(g),
         get_ptr(p),
+        get_ptr(return_updates),
         get_ptr(state1),
         get_ptr(state2),
         get_ptr(unorm_vec),
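The resulting calling convention for the 32-bit path, as a minimal sketch: it assumes the new kernels implement the documented behavior (update written to the buffer, `p` left untouched), and the tensor shapes and the final manual apply are illustrative only.

    import torch
    import bitsandbytes.functional as F

    p = torch.randn(1024, 1024, device="cuda")   # parameter
    g = torch.randn_like(p)                      # gradient
    state1 = torch.zeros_like(p)                 # Adam first moment
    state2 = torch.zeros_like(p)                 # Adam second moment
    updates = torch.zeros_like(p)                # receives the raw update

    F.optimizer_update_32bit(
        "adam", g, p, state1,
        beta1=0.9, eps=1e-8, step=1, lr=1e-3,
        state2=state2, beta2=0.999,
        return_updates=updates,  # new: capture the update instead of applying it
    )

    # The caller decides what happens to the captured update, e.g. project it
    # back to the full parameter space first (the GaLore use case):
    p.add_(updates)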
@@ -1615,25 +1619,26 @@ def optimizer_update_32bit(

 def optimizer_update_8bit(
     optimizer_name: str,
-    g: Tensor,
-    p: Tensor,
-    state1: Tensor,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
     state2: Optional[torch.Tensor],
     beta1: float,
     beta2: float,
     eps: float,
     step: int,
     lr: float,
-    qmap1: Tensor,
+    qmap1: torch.Tensor,
     qmap2: Optional[torch.Tensor],
-    max1: Tensor,
+    max1: torch.Tensor,
     max2: Optional[torch.Tensor],
-    new_max1: Tensor,
+    new_max1: torch.Tensor,
     new_max2: Optional[torch.Tensor],
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     """
     Performs an inplace Adam update.
@@ -1683,6 +1688,8 @@ def optimizer_update_8bit(
         The tensor for the update norm.
     max_unorm : float
         The maximum update norm relative to the weight norm.
+    return_updates: Optional[torch.Tensor]
+        When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
     """

     param_norm = 0.0
@@ -1695,6 +1702,7 @@ def optimizer_update_8bit(
         str2optimizer8bit[optimizer_name][0](
             get_ptr(p),
             get_ptr(g),
+            get_ptr(return_updates),
             get_ptr(state1),
             get_ptr(state2),
             get_ptr(unorm_vec),
@@ -1719,6 +1727,7 @@ def optimizer_update_8bit(
         str2optimizer8bit[optimizer_name][1](
             get_ptr(p),
             get_ptr(g),
+            get_ptr(return_updates),
             get_ptr(state1),
             get_ptr(state2),
             get_ptr(unorm_vec),
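Both 8-bit dispatch branches (`str2optimizer8bit[optimizer_name][0]` and `[1]`) hand the new buffer to C the same way as every other optional tensor: through `get_ptr`, which maps `None` to a null pointer, so the CUDA side can branch on whether a `return_updates` buffer was provided. Roughly, paraphrasing the existing helper in `bitsandbytes/functional.py`:

    import ctypes as ct
    from typing import Optional
    import torch

    def get_ptr(A: Optional[torch.Tensor]) -> Optional[ct.c_void_p]:
        # None becomes NULL at the ctypes boundary; otherwise pass the
        # raw device pointer of the tensor's storage.
        if A is None:
            return None
        return ct.c_void_p(A.data.data_ptr())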
@@ -1764,6 +1773,7 @@ def optimizer_update_8bit_blockwise(
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     skip_zeros=False,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     optim_func = None
     prev_device = pre_call(g.device)
@@ -1790,6 +1800,7 @@ def optimizer_update_8bit_blockwise(
     optim_func(
         get_ptr(p),
         get_ptr(g),
+        get_ptr(return_updates),
         get_ptr(state1),
         get_ptr(state2),
         ct.c_float(beta1),
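All three entry points (32-bit, 8-bit, and blockwise 8-bit) now share one contract. In plain PyTorch the intended semantics look roughly like the sketch below; this is inferred from the new docstrings, since the actual branching happens inside the CUDA kernels, fused with the state updates and quantization.

    import torch
    from typing import Optional

    def apply_or_capture(p: torch.Tensor, update: torch.Tensor,
                         return_updates: Optional[torch.Tensor] = None) -> None:
        # Reference semantics only, not the real kernel code.
        if return_updates is None:
            p.add_(update)                 # classic in-place behavior
        else:
            return_updates.copy_(update)   # capture; leave p untouched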

bitsandbytes/optim/optimizer.py

Lines changed: 13 additions & 2 deletions
@@ -5,6 +5,7 @@
 from collections import abc as container_abcs, defaultdict
 from copy import deepcopy
 from itertools import chain
+from typing import Any, Dict, Optional

 import torch

@@ -313,7 +314,7 @@ def get_config(self, gindex, pindex, group):
     def init_state(self, group, p, gindex, pindex):
         raise NotImplementedError("init_state method needs to be overridden")

-    def update_step(self, group, p, gindex, pindex):
+    def update_step(self, group, p, gindex, pindex, return_updates):
         raise NotImplementedError("The update_step method needs to be overridden")

     def get_state_buffer(self, p, dtype=torch.float32):
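Note the asymmetry: the abstract `update_step` gains `return_updates` without a default, while the concrete override in the next hunk gives it a default of `None`, so existing call sites keep working. Any subclass overriding `update_step` must accept the new parameter; a hypothetical minimal override (the SGD-like update body is illustrative only, not part of this commit):

    import torch
    from bitsandbytes.optim.optimizer import Optimizer8bit

    class MyOptimizer(Optimizer8bit):
        @torch.no_grad()
        def update_step(self, group, p, gindex, pindex, return_updates=None):
            u = -group["lr"] * p.grad          # toy update rule
            if return_updates is not None:
                return_updates.copy_(u)        # capture for the caller
            else:
                p.add_(u)                      # usual in-place apply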
@@ -473,7 +474,14 @@ def init_state(self, group, p, gindex, pindex):
         state["unorm_vec"] = torch.zeros((1,), device=p.device)

     @torch.no_grad()
-    def update_step(self, group, p, gindex, pindex):
+    def update_step(
+        self,
+        group: Dict[str, Any],
+        p: torch.Tensor,
+        gindex: int,
+        pindex: int,
+        return_updates: Optional[torch.Tensor] = None,
+    ):
         state = self.state[p]
         grad = p.grad

@@ -509,6 +517,7 @@ def update_step(self, group, p, gindex, pindex):
                 state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                 max_unorm=config["max_unorm"],
                 skip_zeros=config["skip_zeros"],
+                return_updates=return_updates,
             )

         elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
@@ -533,6 +542,7 @@ def update_step(self, group, p, gindex, pindex):
                 gnorm_scale=gnorm_scale,
                 unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                 max_unorm=config["max_unorm"],
+                return_updates=return_updates,
             )

             # swap maxes
@@ -557,6 +567,7 @@ def update_step(self, group, p, gindex, pindex):
                 config["weight_decay"],
                 gnorm_scale=gnorm_scale,
                 skip_zeros=config["skip_zeros"],
+                return_updates=return_updates,
             )
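End to end, a caller can now route a bitsandbytes optimizer's update into a buffer instead of the weights. A sketch assuming this commit's documented semantics (the kernels leave `p` unmodified when a buffer is passed); state is normally created lazily inside `step()`, so it is initialized by hand here:

    import torch
    import bitsandbytes as bnb

    p = torch.nn.Parameter(torch.randn(256, 256, device="cuda"))
    p.grad = torch.randn_like(p)

    opt = bnb.optim.Adam8bit([p], lr=1e-3)
    group = opt.param_groups[0]
    opt.init_state(group, p, 0, 0)   # normally done lazily in opt.step()

    updates = torch.zeros_like(p)
    opt.update_step(group, p, 0, 0, return_updates=updates)

    # A GaLore-style caller would transform `updates` here (e.g. project a
    # low-rank update back to full rank) before applying it:
    p.data.add_(updates)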