@@ -1555,9 +1555,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
1555
1555
1556
1556
def optimizer_update_32bit (
1557
1557
optimizer_name : str ,
1558
- g : Tensor ,
1559
- p : Tensor ,
1560
- state1 : Tensor ,
1558
+ g : torch . Tensor ,
1559
+ p : torch . Tensor ,
1560
+ state1 : torch . Tensor ,
1561
1561
beta1 : float ,
1562
1562
eps : float ,
1563
1563
step : int ,
@@ -1571,6 +1571,7 @@ def optimizer_update_32bit(
1571
1571
unorm_vec : Optional [torch .Tensor ] = None ,
1572
1572
max_unorm : float = 0.0 ,
1573
1573
skip_zeros = False ,
1574
+ return_updates : Optional [torch .Tensor ] = None ,
1574
1575
) -> None :
1575
1576
"""
1576
1577
Performs an inplace optimizer update with one or two optimizer states.
@@ -1613,6 +1614,8 @@ def optimizer_update_32bit(
1613
1614
The maximum update norm relative to the weight norm.
1614
1615
skip_zeros : bool
1615
1616
Whether to skip zero-valued gradients or not (default: False).
1617
+ return_updates: Optional[torch.Tensor]
1618
+ When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
1616
1619
"""
1617
1620
1618
1621
param_norm = 0.0
@@ -1636,6 +1639,7 @@ def optimizer_update_32bit(
1636
1639
optim_func (
1637
1640
get_ptr (g ),
1638
1641
get_ptr (p ),
1642
+ get_ptr (return_updates ),
1639
1643
get_ptr (state1 ),
1640
1644
get_ptr (state2 ),
1641
1645
get_ptr (unorm_vec ),
@@ -1658,25 +1662,26 @@ def optimizer_update_32bit(
1658
1662
1659
1663
def optimizer_update_8bit (
1660
1664
optimizer_name : str ,
1661
- g : Tensor ,
1662
- p : Tensor ,
1663
- state1 : Tensor ,
1665
+ g : torch . Tensor ,
1666
+ p : torch . Tensor ,
1667
+ state1 : torch . Tensor ,
1664
1668
state2 : Optional [torch .Tensor ],
1665
1669
beta1 : float ,
1666
1670
beta2 : float ,
1667
1671
eps : float ,
1668
1672
step : int ,
1669
1673
lr : float ,
1670
- qmap1 : Tensor ,
1674
+ qmap1 : torch . Tensor ,
1671
1675
qmap2 : Optional [torch .Tensor ],
1672
- max1 : Tensor ,
1676
+ max1 : torch . Tensor ,
1673
1677
max2 : Optional [torch .Tensor ],
1674
- new_max1 : Tensor ,
1678
+ new_max1 : torch . Tensor ,
1675
1679
new_max2 : Optional [torch .Tensor ],
1676
1680
weight_decay : float = 0.0 ,
1677
1681
gnorm_scale : float = 1.0 ,
1678
1682
unorm_vec : Optional [torch .Tensor ] = None ,
1679
1683
max_unorm : float = 0.0 ,
1684
+ return_updates : Optional [torch .Tensor ] = None ,
1680
1685
) -> None :
1681
1686
"""
1682
1687
Performs an inplace Adam update.
@@ -1726,6 +1731,8 @@ def optimizer_update_8bit(
1726
1731
The tensor for the update norm.
1727
1732
max_unorm : float
1728
1733
The maximum update norm relative to the weight norm.
1734
+ return_updates: Optional[torch.Tensor]
1735
+ When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
1729
1736
"""
1730
1737
1731
1738
param_norm = 0.0
@@ -1738,6 +1745,7 @@ def optimizer_update_8bit(
1738
1745
str2optimizer8bit [optimizer_name ][0 ](
1739
1746
get_ptr (p ),
1740
1747
get_ptr (g ),
1748
+ get_ptr (return_updates ),
1741
1749
get_ptr (state1 ),
1742
1750
get_ptr (state2 ),
1743
1751
get_ptr (unorm_vec ),
@@ -1762,6 +1770,7 @@ def optimizer_update_8bit(
1762
1770
str2optimizer8bit [optimizer_name ][1 ](
1763
1771
get_ptr (p ),
1764
1772
get_ptr (g ),
1773
+ get_ptr (return_updates ),
1765
1774
get_ptr (state1 ),
1766
1775
get_ptr (state2 ),
1767
1776
get_ptr (unorm_vec ),
@@ -1809,6 +1818,7 @@ def optimizer_update_8bit_blockwise(
1809
1818
weight_decay : float = 0.0 ,
1810
1819
gnorm_scale : float = 1.0 ,
1811
1820
skip_zeros = False ,
1821
+ return_updates : Optional [torch .Tensor ] = None ,
1812
1822
) -> None :
1813
1823
optim_func = None
1814
1824
prev_device = pre_call (g .device )
@@ -1835,6 +1845,7 @@ def optimizer_update_8bit_blockwise(
1835
1845
optim_func (
1836
1846
get_ptr (p ),
1837
1847
get_ptr (g ),
1848
+ get_ptr (return_updates ),
1838
1849
get_ptr (state1 ),
1839
1850
get_ptr (state2 ),
1840
1851
ct .c_float (beta1 ),
0 commit comments