@@ -1011,7 +1011,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
1011
1011
1012
1012
template <typename T, int OPTIMIZER>
1013
1013
__launch_bounds__ (TH, 1 )
1014
- __global__ void kOptimizer32bit1State(T *g, T *p,
1014
+ __global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates,
1015
1015
float *state1, float *unorm, const float max_unorm, const float param_norm,
1016
1016
const float beta1, const float beta2, const float eps, const float weight_decay,
1017
1017
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -1057,13 +1057,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
1057
1057
__syncthreads ();
1058
1058
LoadFloat (temp_storage.loadf ).Load (&(state1[i]), s1_vals, valid_items);
1059
1059
__syncthreads ();
1060
- Load (temp_storage.load ).Load (&(p[i]), p_vals, valid_items);
1060
+ Load (temp_storage.load ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
1061
1061
1062
1062
# pragma unroll 4
1063
1063
for (unsigned int j = 0 ; j < NUM_PER_THREAD; j++)
1064
1064
{
1065
1065
g_vals[j] = gnorm_scale*((float )g_vals[j]);
1066
- if (weight_decay > 0 .0f )
1066
+ if (weight_decay > 0 .0f && return_updates == nullptr )
1067
1067
g_vals[j] = (float )g_vals[j] + (((float )p_vals[j])*weight_decay);
1068
1068
}
1069
1069
@@ -1080,26 +1080,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
1080
1080
else
1081
1081
s1_vals[j] = s1_vals[j]*beta1 + ((float )g_vals[j]);
1082
1082
1083
- p_vals[j] = ((float )p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
1083
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) + update_scale*(-lr*(s1_vals[j]));
1084
1084
break ;
1085
1085
case LION:
1086
- p_vals[j] = ((float )p_vals[j]) - update_scale*(lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_vals[j]))));
1086
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - update_scale*(lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_vals[j]))));
1087
1087
s1_vals[j] = s1_vals[j]*beta2 + ((1 .0f -beta2)*((float )g_vals[j]));
1088
1088
break ;
1089
1089
case RMSPROP:
1090
1090
s1_vals[j] = s1_vals[j]*beta1 + ((1 .0f -beta1)*((float )g_vals[j])*((float )g_vals[j]));
1091
- p_vals[j] = ((float )p_vals[j]) - update_scale*(lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps));
1091
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - update_scale*(lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps));
1092
1092
break ;
1093
1093
case ADAGRAD:
1094
1094
s1_vals[j] = s1_vals[j] + ((float )g_vals[j])*((float )g_vals[j]);
1095
- p_vals[j] = ((float )p_vals[j]) - lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps);
1095
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*__fdividef ((float )g_vals[j],sqrtf ((float )s1_vals[j])+eps);
1096
1096
break ;
1097
1097
}
1098
1098
}
1099
1099
}
1100
1100
1101
1101
__syncthreads ();
1102
- Store (temp_storage.store ).Store (&(p[i]), p_vals, valid_items);
1102
+ Store (temp_storage.store ).Store (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
1103
1103
__syncthreads ();
1104
1104
StoreFloat (temp_storage.storef ).Store (&(state1[i]), s1_vals, valid_items);
1105
1105
}
@@ -1447,7 +1447,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
1447
1447
template <typename T, int OPTIMIZER>
1448
1448
__global__ void
1449
1449
__launch_bounds__ (1024 , 1 )
1450
- kOptimizerStatic8bit1State(T* p, T* const g, unsigned char * state1,
1450
+ kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char * state1,
1451
1451
const float *unorm, const float max_unorm, const float param_norm,
1452
1452
const float beta1, const float beta2,
1453
1453
const float eps, const int step, const float lr,
@@ -1503,7 +1503,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1503
1503
__syncthreads ();
1504
1504
LoadChar (temp_storage.loadc ).Load (&(state1[i]), c1s, valid_items, 128 );
1505
1505
__syncthreads ();
1506
- LoadT (temp_storage.loadh ).Load (&(p[i]), p_vals, valid_items);
1506
+ LoadT (temp_storage.loadh ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
1507
1507
1508
1508
if ((i + (threadIdx .x *NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue ; }
1509
1509
@@ -1513,7 +1513,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1513
1513
g_val = float (g_vals[j]);
1514
1514
g_val *= gnorm_scale;
1515
1515
1516
- if (weight_decay > 0 .0f ) {
1516
+ if (weight_decay > 0 .0f && return_updates == nullptr ) {
1517
1517
switch (OPTIMIZER) {
1518
1518
case ADAGRAD:
1519
1519
case MOMENTUM:
@@ -1536,15 +1536,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1536
1536
else
1537
1537
s1_vals[j] = s1_vals[j]*beta1 + ((float )g_vals[j]);
1538
1538
1539
- p_vals[j] = ((float )p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
1539
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) + (-lr*update_scale*(s1_vals[j]));
1540
1540
break ;
1541
1541
case LION:
1542
- p_vals[j] = ((float )p_vals[j]) - (lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_val))));
1542
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - (lr*sgn (((float )s1_vals[j])*beta1 + ((1 .0f -beta1)*((float )g_val))));
1543
1543
s1_vals[j] = s1_vals[j]*beta2 + ((1 .0f -beta2)*g_val);
1544
1544
break ;
1545
1545
case RMSPROP:
1546
1546
s1_vals[j] = s1_vals[j]*beta1 + ((1 .0f -beta1)*(g_val*g_val));
1547
- p_vals[j] = ((float )p_vals[j]) - (lr*__fdividef (g_val,sqrtf (s1_vals[j])+eps));
1547
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - (lr*__fdividef (g_val,sqrtf (s1_vals[j])+eps));
1548
1548
break ;
1549
1549
}
1550
1550
@@ -1560,7 +1560,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1560
1560
}
1561
1561
}
1562
1562
1563
- StoreT (temp_storage.storeh ).Store (&(p[i]), p_vals, valid_items);
1563
+ StoreT (temp_storage.storeh ).Store (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items);
1564
1564
__syncthreads ();
1565
1565
StoreChar (temp_storage.storec ).Store (&(state1[i]), c1s, valid_items);
1566
1566
__syncthreads ();
@@ -1893,7 +1893,7 @@ kOptimizerStatic8bit2StateBlockwise(
1893
1893
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
1894
1894
__launch_bounds__ (256 , 3 )
1895
1895
__global__ void
1896
- kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char * state1,
1896
+ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char * state1,
1897
1897
const float beta1, const float beta2,
1898
1898
const float eps, const int step, const float lr,
1899
1899
float * __restrict__ const quantiles1,
@@ -1957,7 +1957,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
1957
1957
__syncthreads ();
1958
1958
LoadChar (temp_storage.loadc ).Load (&(state1[i]), c1s, valid_items, 128 );
1959
1959
__syncthreads ();
1960
- LoadT (temp_storage.loadh ).Load (&(p[i]), p_vals, valid_items, (T)0 .0f );
1960
+ LoadT (temp_storage.loadh ).Load (return_updates == nullptr ? &(p[i]) : &(return_updates [i]), p_vals, valid_items, (T)0 .0f );
1961
1961
1962
1962
new_local_abs_max1 = -FLT_MAX;
1963
1963
@@ -1969,7 +1969,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
1969
1969
g_val *= gnorm_scale;
1970
1970
if (!skip_zeros || (skip_zeros && ((float )g_vals[j] != 0 .0f )))
1971
1971
{
1972
- if (weight_decay > 0 .0f ) {
1972
+ if (weight_decay > 0 .0f && return_updates == nullptr ) {
1973
1973
switch (OPTIMIZER) {
1974
1974
case MOMENTUM:
1975
1975
case ADAGRAD:
@@ -2032,18 +2032,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
2032
2032
switch (OPTIMIZER)
2033
2033
{
2034
2034
case MOMENTUM:
2035
- p_vals[j] = ((float )p_vals[j]) - lr*(s1_vals[j]);
2035
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(s1_vals[j]);
2036
2036
break ;
2037
2037
case LION:
2038
- p_vals[j] = ((float )p_vals[j]) - ((float )g_vals[j]);
2038
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - ((float )g_vals[j]);
2039
2039
break ;
2040
2040
case RMSPROP:
2041
2041
g_val = g_vals[j];
2042
- p_vals[j] = ((float )p_vals[j]) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
2042
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
2043
2043
break ;
2044
2044
case ADAGRAD:
2045
2045
g_val = g_vals[j];
2046
- p_vals[j] = ((float )p_vals[j]) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
2046
+ p_vals[j] = (return_updates == nullptr ? (float )p_vals[j] : 0 . 0f ) - lr*(__fdividef (g_val, sqrtf (s1_vals[j])+eps));
2047
2047
break ;
2048
2048
}
2049
2049
}
@@ -3782,7 +3782,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
3782
3782
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
3783
3783
3784
3784
#define MAKE_Optimizer32bit1State (oname, gtype ) \
3785
- template __global__ void kOptimizer32bit1State <gtype, oname>(gtype* g, gtype* p, float * state1, float *unorm, const float max_unorm, const float param_norm, \
3785
+ template __global__ void kOptimizer32bit1State <gtype, oname>(gtype* g, gtype* p, gtype* return_updates, float * state1, float *unorm, const float max_unorm, const float param_norm, \
3786
3786
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
3787
3787
3788
3788
MAKE_Optimizer32bit1State (MOMENTUM, half)
@@ -3847,7 +3847,7 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
3847
3847
MAKE_PreconditionStatic8bit1State(ADAGRAD, float )
3848
3848
3849
3849
#define MAKE_optimizerStatic8bit1State (oname, gtype ) \
3850
- template __global__ void kOptimizerStatic8bit1State <gtype, oname>(gtype* p, gtype* const g, unsigned char * state1, \
3850
+ template __global__ void kOptimizerStatic8bit1State <gtype, oname>(gtype* p, gtype* const g, gtype* return_updates, unsigned char * state1, \
3851
3851
const float *unorm, const float max_unorm, const float param_norm, \
3852
3852
const float beta1, \
3853
3853
const float beta2, \
@@ -4002,7 +4002,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)
4002
4002
4003
4003
#define MAKE_OptimizerStatic8bit1StateBlockwise (oname, gtype, block_size, num_per_thread ) \
4004
4004
template __global__ void kOptimizerStatic8bit1StateBlockwise <gtype, oname, block_size, num_per_thread>( \
4005
- gtype* p, gtype* __restrict__ const g, unsigned char * state1, \
4005
+ gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char * state1, \
4006
4006
const float beta1, const float beta2, \
4007
4007
const float eps, const int step, const float lr, \
4008
4008
float * __restrict__ const quantiles1, \
0 commit comments