Skip to content

Commit a5f9552

Browse files
Initial kernel changes for 2-state optimizers to support GaLore
1 parent 9568735 commit a5f9552

File tree

7 files changed

+114
-141
lines changed

7 files changed

+114
-141
lines changed

bitsandbytes/functional.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,9 +1555,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
15551555

15561556
def optimizer_update_32bit(
15571557
optimizer_name: str,
1558-
g: Tensor,
1559-
p: Tensor,
1560-
state1: Tensor,
1558+
g: torch.Tensor,
1559+
p: torch.Tensor,
1560+
state1: torch.Tensor,
15611561
beta1: float,
15621562
eps: float,
15631563
step: int,
@@ -1571,6 +1571,7 @@ def optimizer_update_32bit(
15711571
unorm_vec: Optional[torch.Tensor] = None,
15721572
max_unorm: float = 0.0,
15731573
skip_zeros=False,
1574+
return_updates: Optional[torch.Tensor] = None,
15741575
) -> None:
15751576
"""
15761577
Performs an inplace optimizer update with one or two optimizer states.
@@ -1613,6 +1614,8 @@ def optimizer_update_32bit(
16131614
The maximum update norm relative to the weight norm.
16141615
skip_zeros : bool
16151616
Whether to skip zero-valued gradients or not (default: False).
1617+
return_updates: Optional[torch.Tensor]
1618+
When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
16161619
"""
16171620

16181621
param_norm = 0.0
@@ -1636,6 +1639,7 @@ def optimizer_update_32bit(
16361639
optim_func(
16371640
get_ptr(g),
16381641
get_ptr(p),
1642+
get_ptr(return_updates),
16391643
get_ptr(state1),
16401644
get_ptr(state2),
16411645
get_ptr(unorm_vec),
@@ -1658,25 +1662,26 @@ def optimizer_update_32bit(
16581662

16591663
def optimizer_update_8bit(
16601664
optimizer_name: str,
1661-
g: Tensor,
1662-
p: Tensor,
1663-
state1: Tensor,
1665+
g: torch.Tensor,
1666+
p: torch.Tensor,
1667+
state1: torch.Tensor,
16641668
state2: Optional[torch.Tensor],
16651669
beta1: float,
16661670
beta2: float,
16671671
eps: float,
16681672
step: int,
16691673
lr: float,
1670-
qmap1: Tensor,
1674+
qmap1: torch.Tensor,
16711675
qmap2: Optional[torch.Tensor],
1672-
max1: Tensor,
1676+
max1: torch.Tensor,
16731677
max2: Optional[torch.Tensor],
1674-
new_max1: Tensor,
1678+
new_max1: torch.Tensor,
16751679
new_max2: Optional[torch.Tensor],
16761680
weight_decay: float = 0.0,
16771681
gnorm_scale: float = 1.0,
16781682
unorm_vec: Optional[torch.Tensor] = None,
16791683
max_unorm: float = 0.0,
1684+
return_updates: Optional[torch.Tensor] = None,
16801685
) -> None:
16811686
"""
16821687
Performs an inplace Adam update.
@@ -1726,6 +1731,8 @@ def optimizer_update_8bit(
17261731
The tensor for the update norm.
17271732
max_unorm : float
17281733
The maximum update norm relative to the weight norm.
1734+
return_updates: Optional[torch.Tensor]
1735+
When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
17291736
"""
17301737

17311738
param_norm = 0.0
@@ -1738,6 +1745,7 @@ def optimizer_update_8bit(
17381745
str2optimizer8bit[optimizer_name][0](
17391746
get_ptr(p),
17401747
get_ptr(g),
1748+
get_ptr(return_updates),
17411749
get_ptr(state1),
17421750
get_ptr(state2),
17431751
get_ptr(unorm_vec),
@@ -1762,6 +1770,7 @@ def optimizer_update_8bit(
17621770
str2optimizer8bit[optimizer_name][1](
17631771
get_ptr(p),
17641772
get_ptr(g),
1773+
get_ptr(return_updates),
17651774
get_ptr(state1),
17661775
get_ptr(state2),
17671776
get_ptr(unorm_vec),
@@ -1809,6 +1818,7 @@ def optimizer_update_8bit_blockwise(
18091818
weight_decay: float = 0.0,
18101819
gnorm_scale: float = 1.0,
18111820
skip_zeros=False,
1821+
return_updates: Optional[torch.Tensor] = None,
18121822
) -> None:
18131823
optim_func = None
18141824
prev_device = pre_call(g.device)
@@ -1835,6 +1845,7 @@ def optimizer_update_8bit_blockwise(
18351845
optim_func(
18361846
get_ptr(p),
18371847
get_ptr(g),
1848+
get_ptr(return_updates),
18381849
get_ptr(state1),
18391850
get_ptr(state2),
18401851
ct.c_float(beta1),

bitsandbytes/optim/optimizer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import abc as container_abcs, defaultdict
66
from copy import deepcopy
77
from itertools import chain
8-
from typing import Optional
8+
from typing import Any, Dict, Optional
99

1010
import torch
1111

@@ -320,7 +320,7 @@ def get_config(self, gindex, pindex, group):
320320
def init_state(self, group, p, gindex, pindex):
321321
raise NotImplementedError("init_state method needs to be overridden")
322322

323-
def update_step(self, group, p, gindex, pindex):
323+
def update_step(self, group, p, gindex, pindex, return_updates):
324324
raise NotImplementedError("The update_step method needs to be overridden")
325325

326326
def get_state_buffer(self, p, dtype=torch.float32):
@@ -494,7 +494,14 @@ def init_state(self, group, p, gindex, pindex):
494494
state["unorm_vec"] = torch.zeros((1,), device=p.device)
495495

496496
@torch.no_grad()
497-
def update_step(self, group, p, gindex, pindex):
497+
def update_step(
498+
self,
499+
group: Dict[str, Any],
500+
p: torch.Tensor,
501+
gindex: int,
502+
pindex: int,
503+
return_updates: Optional[torch.Tensor] = None,
504+
):
498505
# avoid update error from non-contiguous memory layout
499506
p.data = p.data.contiguous()
500507
p.grad = p.grad.contiguous()
@@ -536,6 +543,7 @@ def update_step(self, group, p, gindex, pindex):
536543
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
537544
max_unorm=config["max_unorm"],
538545
skip_zeros=config["skip_zeros"],
546+
return_updates=return_updates,
539547
)
540548

541549
elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
@@ -560,6 +568,7 @@ def update_step(self, group, p, gindex, pindex):
560568
gnorm_scale=gnorm_scale,
561569
unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
562570
max_unorm=config["max_unorm"],
571+
return_updates=return_updates,
563572
)
564573

565574
# swap maxes
@@ -586,6 +595,7 @@ def update_step(self, group, p, gindex, pindex):
586595
config["weight_decay"],
587596
gnorm_scale=gnorm_scale,
588597
skip_zeros=config["skip_zeros"],
598+
return_updates=return_updates,
589599
)
590600

591601

0 commit comments

Comments (0)