@@ -1520,9 +1520,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
1520
1520
1521
1521
def optimizer_update_32bit (
1522
1522
optimizer_name : str ,
1523
- g : Tensor ,
1524
- p : Tensor ,
1525
- state1 : Tensor ,
1523
+ g : torch . Tensor ,
1524
+ p : torch . Tensor ,
1525
+ state1 : torch . Tensor ,
1526
1526
beta1 : float ,
1527
1527
eps : float ,
1528
1528
step : int ,
@@ -1534,6 +1534,7 @@ def optimizer_update_32bit(
1534
1534
unorm_vec : Optional [torch .Tensor ] = None ,
1535
1535
max_unorm : float = 0.0 ,
1536
1536
skip_zeros = False ,
1537
+ return_updates : Optional [torch .Tensor ] = None ,
1537
1538
) -> None :
1538
1539
"""
1539
1540
Performs an inplace optimizer update with one or two optimizer states.
@@ -1572,6 +1573,8 @@ def optimizer_update_32bit(
1572
1573
The maximum update norm relative to the weight norm.
1573
1574
skip_zeros : bool
1574
1575
Whether to skip zero-valued gradients or not (default: False).
1576
+ return_updates: Optional[torch.Tensor]
1577
+ When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
1575
1578
"""
1576
1579
1577
1580
param_norm = 0.0
@@ -1595,6 +1598,7 @@ def optimizer_update_32bit(
1595
1598
optim_func (
1596
1599
get_ptr (g ),
1597
1600
get_ptr (p ),
1601
+ get_ptr (return_updates ),
1598
1602
get_ptr (state1 ),
1599
1603
get_ptr (state2 ),
1600
1604
get_ptr (unorm_vec ),
@@ -1615,25 +1619,26 @@ def optimizer_update_32bit(
1615
1619
1616
1620
def optimizer_update_8bit (
1617
1621
optimizer_name : str ,
1618
- g : Tensor ,
1619
- p : Tensor ,
1620
- state1 : Tensor ,
1622
+ g : torch . Tensor ,
1623
+ p : torch . Tensor ,
1624
+ state1 : torch . Tensor ,
1621
1625
state2 : Optional [torch .Tensor ],
1622
1626
beta1 : float ,
1623
1627
beta2 : float ,
1624
1628
eps : float ,
1625
1629
step : int ,
1626
1630
lr : float ,
1627
- qmap1 : Tensor ,
1631
+ qmap1 : torch . Tensor ,
1628
1632
qmap2 : Optional [torch .Tensor ],
1629
- max1 : Tensor ,
1633
+ max1 : torch . Tensor ,
1630
1634
max2 : Optional [torch .Tensor ],
1631
- new_max1 : Tensor ,
1635
+ new_max1 : torch . Tensor ,
1632
1636
new_max2 : Optional [torch .Tensor ],
1633
1637
weight_decay : float = 0.0 ,
1634
1638
gnorm_scale : float = 1.0 ,
1635
1639
unorm_vec : Optional [torch .Tensor ] = None ,
1636
1640
max_unorm : float = 0.0 ,
1641
+ return_updates : Optional [torch .Tensor ] = None ,
1637
1642
) -> None :
1638
1643
"""
1639
1644
Performs an inplace Adam update.
@@ -1683,6 +1688,8 @@ def optimizer_update_8bit(
1683
1688
The tensor for the update norm.
1684
1689
max_unorm : float
1685
1690
The maximum update norm relative to the weight norm.
1691
+ return_updates: Optional[torch.Tensor]
1692
+ When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
1686
1693
"""
1687
1694
1688
1695
param_norm = 0.0
@@ -1695,6 +1702,7 @@ def optimizer_update_8bit(
1695
1702
str2optimizer8bit [optimizer_name ][0 ](
1696
1703
get_ptr (p ),
1697
1704
get_ptr (g ),
1705
+ get_ptr (return_updates ),
1698
1706
get_ptr (state1 ),
1699
1707
get_ptr (state2 ),
1700
1708
get_ptr (unorm_vec ),
@@ -1719,6 +1727,7 @@ def optimizer_update_8bit(
1719
1727
str2optimizer8bit [optimizer_name ][1 ](
1720
1728
get_ptr (p ),
1721
1729
get_ptr (g ),
1730
+ get_ptr (return_updates ),
1722
1731
get_ptr (state1 ),
1723
1732
get_ptr (state2 ),
1724
1733
get_ptr (unorm_vec ),
@@ -1764,6 +1773,7 @@ def optimizer_update_8bit_blockwise(
1764
1773
weight_decay : float = 0.0 ,
1765
1774
gnorm_scale : float = 1.0 ,
1766
1775
skip_zeros = False ,
1776
+ return_updates : Optional [torch .Tensor ] = None ,
1767
1777
) -> None :
1768
1778
optim_func = None
1769
1779
prev_device = pre_call (g .device )
@@ -1790,6 +1800,7 @@ def optimizer_update_8bit_blockwise(
1790
1800
optim_func (
1791
1801
get_ptr (p ),
1792
1802
get_ptr (g ),
1803
+ get_ptr (return_updates ),
1793
1804
get_ptr (state1 ),
1794
1805
get_ptr (state2 ),
1795
1806
ct .c_float (beta1 ),
0 commit comments