Skip to content

Commit 0c6dda0

Browse files
committed
Mark some optimizer update arguments as Noneable (they were being called with Nones)
1 parent 3ec3dd2 commit 0c6dda0

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

bitsandbytes/functional.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,18 +1618,18 @@ def optimizer_update_8bit(
16181618
g: Tensor,
16191619
p: Tensor,
16201620
state1: Tensor,
1621-
state2: Tensor,
1621+
state2: Optional[torch.Tensor],
16221622
beta1: float,
16231623
beta2: float,
16241624
eps: float,
16251625
step: int,
16261626
lr: float,
16271627
qmap1: Tensor,
1628-
qmap2: Tensor,
1628+
qmap2: Optional[torch.Tensor],
16291629
max1: Tensor,
1630-
max2: Tensor,
1630+
max2: Optional[torch.Tensor],
16311631
new_max1: Tensor,
1632-
new_max2: Tensor,
1632+
new_max2: Optional[torch.Tensor],
16331633
weight_decay: float = 0.0,
16341634
gnorm_scale: float = 1.0,
16351635
unorm_vec: Optional[torch.Tensor] = None,
@@ -1751,16 +1751,16 @@ def optimizer_update_8bit_blockwise(
17511751
g: Tensor,
17521752
p: Tensor,
17531753
state1: Tensor,
1754-
state2: Tensor,
1754+
state2: Optional[torch.Tensor],
17551755
beta1: float,
17561756
beta2: float,
17571757
eps: float,
17581758
step: int,
17591759
lr: float,
17601760
qmap1: Tensor,
1761-
qmap2: Tensor,
1761+
qmap2: Optional[torch.Tensor],
17621762
absmax1: Tensor,
1763-
absmax2: Tensor,
1763+
absmax2: Optional[torch.Tensor],
17641764
weight_decay: float = 0.0,
17651765
gnorm_scale: float = 1.0,
17661766
skip_zeros=False,

0 commit comments

Comments
 (0)