File tree Expand file tree Collapse file tree 1 file changed +7
-7
lines changed Expand file tree Collapse file tree 1 file changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -1618,18 +1618,18 @@ def optimizer_update_8bit(
1618
1618
g : Tensor ,
1619
1619
p : Tensor ,
1620
1620
state1 : Tensor ,
1621
- state2 : Tensor ,
1621
+ state2 : Optional [ torch . Tensor ] ,
1622
1622
beta1 : float ,
1623
1623
beta2 : float ,
1624
1624
eps : float ,
1625
1625
step : int ,
1626
1626
lr : float ,
1627
1627
qmap1 : Tensor ,
1628
- qmap2 : Tensor ,
1628
+ qmap2 : Optional [ torch . Tensor ] ,
1629
1629
max1 : Tensor ,
1630
- max2 : Tensor ,
1630
+ max2 : Optional [ torch . Tensor ] ,
1631
1631
new_max1 : Tensor ,
1632
- new_max2 : Tensor ,
1632
+ new_max2 : Optional [ torch . Tensor ] ,
1633
1633
weight_decay : float = 0.0 ,
1634
1634
gnorm_scale : float = 1.0 ,
1635
1635
unorm_vec : Optional [torch .Tensor ] = None ,
@@ -1751,16 +1751,16 @@ def optimizer_update_8bit_blockwise(
1751
1751
g : Tensor ,
1752
1752
p : Tensor ,
1753
1753
state1 : Tensor ,
1754
- state2 : Tensor ,
1754
+ state2 : Optional [ torch . Tensor ] ,
1755
1755
beta1 : float ,
1756
1756
beta2 : float ,
1757
1757
eps : float ,
1758
1758
step : int ,
1759
1759
lr : float ,
1760
1760
qmap1 : Tensor ,
1761
- qmap2 : Tensor ,
1761
+ qmap2 : Optional [ torch . Tensor ] ,
1762
1762
absmax1 : Tensor ,
1763
- absmax2 : Tensor ,
1763
+ absmax2 : Optional [ torch . Tensor ] ,
1764
1764
weight_decay : float = 0.0 ,
1765
1765
gnorm_scale : float = 1.0 ,
1766
1766
skip_zeros = False ,
You can’t perform that action at this time.
0 commit comments