@@ -1617,20 +1617,10 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
1617
1617
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
1618
1618
__launch_bounds__ (256 , 3 )
1619
1619
__global__ void
1620
- <<<<<<< HEAD
1621
- kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, T* return_updates,
1622
- unsigned char * state1, unsigned char * state2,
1623
- const float beta1, const float beta2,
1624
- const float eps, const int step, const float lr,
1625
- float * __restrict__ const quantiles1, float * __restrict__ const quantiles2,
1626
- float * absmax1, float * absmax2,
1627
- float weight_decay,
1628
- const float gnorm_scale, const bool skip_zeros, const int n)
1629
- {
1630
- =======
1631
1620
kOptimizerStatic8bit2StateBlockwise(
1632
- T* p,
1621
+ T* __restrict__ p,
1633
1622
T* __restrict__ const g,
1623
+ T* __restrict__ return_updates,
1634
1624
unsigned char * state1,
1635
1625
unsigned char * state2,
1636
1626
const float beta1,
@@ -1649,7 +1639,6 @@ kOptimizerStatic8bit2StateBlockwise(
1649
1639
const bool skip_zeros,
1650
1640
const int n
1651
1641
) {
1652
- >>>>>>> d964546 (Add AdEMAMix optimizer (#1360 ))
1653
1642
1654
1643
// const int n_full = n + (n%BLOCK_SIZE);
1655
1644
const int n_full = gridDim .x * BLOCK_SIZE;
@@ -1834,28 +1823,22 @@ kOptimizerStatic8bit2StateBlockwise(
1834
1823
// if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
1835
1824
if (!isnan ((float )g_vals[j]) && !isinf ((float )g_vals[j]))
1836
1825
{
1837
- <<<<<<< HEAD
1838
- if (return_updates == nullptr ) {
1839
- p_vals[j] = (T)(((float )p_vals[j]) + ((step_size*(__fdividef (s1_vals[j],(sqrtf (s2_vals[j])+(correction2*eps)))))));
1840
- if (weight_decay > 0 .0f )
1841
- p_vals[j] = ((float )p_vals[j])*(1 .0f -(lr*weight_decay));
1842
- } else {
1843
- p_vals[j] = (T)(step_size*(__fdividef (s1_vals[j],(sqrtf (s2_vals[j])+(correction2*eps)))));
1844
- }
1845
- =======
1846
1826
if (OPTIMIZER == ADEMAMIX) {
1847
1827
p_vals[j] = T ((float )p_vals[j] - lr * (
1848
1828
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
1849
1829
(sqrtf (s2_vals[j]) / correction2) + eps
1850
1830
)
1851
1831
));
1852
1832
} else {
1853
- p_vals[j] = (T)(((float )p_vals[j]) + ((step_size*(__fdividef (s1_vals[j],(sqrtf (s2_vals[j])+(correction2*eps)))))));
1833
+ if (return_updates == nullptr ) {
1834
+ p_vals[j] = (T)(((float )p_vals[j]) + ((step_size*(__fdividef (s1_vals[j],(sqrtf (s2_vals[j])+(correction2*eps)))))));
1835
+ } else {
1836
+ p_vals[j] = (T)(step_size*(__fdividef (s1_vals[j],(sqrtf (s2_vals[j])+(correction2*eps)))));
1837
+ }
1854
1838
}
1855
1839
1856
- if ( weight_decay > 0 .0f )
1840
+ if (return_updates == nullptr && weight_decay > 0 .0f )
1857
1841
p_vals[j] = ((float )p_vals[j])*(1 .0f -(lr*weight_decay));
1858
- >>> >>> > d964546 (Add AdEMAMix optimizer (#1360 ))
1859
1842
}
1860
1843
}
1861
1844
@@ -3813,7 +3796,7 @@ MAKE_Optimizer32bit1State(ADAGRAD, float)
3813
3796
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)
3814
3797
3815
3798
#define MAKE_PreconditionOptimizer32bit2State (oname, gtype ) \
3816
- template __global__ void kPreconditionOptimizer32bit2State <gtype, oname, 4096 , 8 >(gtype* g, gtype* p, \
3799
+ template __global__ void kPreconditionOptimizer32bit2State <gtype, oname, 4096 , 8 >(gtype* g, gtype* p, \
3817
3800
float * state1, float * state2, float *unorm, \
3818
3801
const float beta1, const float beta2, const float eps, const float weight_decay, \
3819
3802
const int step, const float lr, const float gnorm_scale, const int n); \
@@ -3825,28 +3808,19 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
3825
3808
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
3826
3809
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)
3827
3810
3828
- <<<<<<< HEAD
3829
3811
template __global__ void kOptimizer32bit2State<float, ADAM>(float * g, float * p, float * return_updates, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3830
- const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3831
- template __global__ void kOptimizer32bit2State <half, ADAM>(half* g, half* p, half* return_updates, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3832
- const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3833
- template __global__ void kOptimizer32bit2State <__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3834
- const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3835
- =======
3836
- template __global__ void kOptimizer32bit2State <float , ADAM>(float * g, float * p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3837
3812
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3838
- template __global__ void kOptimizer32bit2State <half, ADAM>(half* g, half* p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3813
+ template __global__ void kOptimizer32bit2State <half, ADAM>(half* g, half* p, half* return_updates, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3839
3814
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3840
- template __global__ void kOptimizer32bit2State <__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3815
+ template __global__ void kOptimizer32bit2State <__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3841
3816
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3842
- template __global__ void kOptimizer32bit2State <float , ADEMAMIX>(float * g, float * p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3817
+ template __global__ void kOptimizer32bit2State <float , ADEMAMIX>(float * g, float * p, float * return_updates, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3843
3818
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3844
- template __global__ void kOptimizer32bit2State <half, ADEMAMIX>(half* g, half* p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3819
+ template __global__ void kOptimizer32bit2State <half, ADEMAMIX>(half* g, half* p, half* return_updates, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3845
3820
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3846
- template __global__ void kOptimizer32bit2State <__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3821
+ template __global__ void kOptimizer32bit2State <__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float * state1, float * state2, float *unorm, const float max_unorm, const float param_norm,
3847
3822
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3848
3823
3849
- >>> >>> > d964546 (Add AdEMAMix optimizer (#1360 ))
3850
3824
3851
3825
#define MAKE_PreconditionStatic8bit1State (oname, gtype ) \
3852
3826
template __global__ void kPreconditionOptimizerStatic8bit1State <gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char *__restrict__ const state1, \
@@ -4006,14 +3980,9 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General
4006
3980
template __global__ void kDequantizeBlockwise <__nv_bfloat16, 512 , 64 , 8 , NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
4007
3981
4008
3982
#define MAKE_OptimizerStatic8bit2StateBlockwise (oname, gtype, block_size, num_per_thread ) \
4009
- <<<<<<< HEAD
4010
3983
template __global__ void kOptimizerStatic8bit2StateBlockwise <gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, gtype* return_updates, \
4011
3984
unsigned char * state1, unsigned char * state2, \
4012
- const float beta1, const float beta2, \
4013
- =======
4014
- template __global__ void kOptimizerStatic8bit2StateBlockwise <gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char * state1, unsigned char * state2, \
4015
3985
const float beta1, const float beta2, const float beta3, const float alpha, \
4016
- >>> >>> > d964546 (Add AdEMAMix optimizer (#1360 ))
4017
3986
const float eps, const int step, const float lr, \
4018
3987
float * __restrict__ const quantiles1, float * __restrict__ const quantiles2, \
4019
3988
float * absmax1, float * absmax2, \
0 commit comments