Skip to content

Commit b1fb85b

Browse files
Support eturn_outputs buffer option for 1-state optimizers
1 parent 31854da commit b1fb85b

File tree

3 files changed

+33
-33
lines changed

3 files changed

+33
-33
lines changed

csrc/kernels.cu

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
10111011

10121012
template<typename T, int OPTIMIZER>
10131013
__launch_bounds__(TH, 1)
1014-
__global__ void kOptimizer32bit1State(T *g, T *p,
1014+
__global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates,
10151015
float *state1, float *unorm, const float max_unorm, const float param_norm,
10161016
const float beta1, const float beta2, const float eps, const float weight_decay,
10171017
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -1057,13 +1057,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
10571057
__syncthreads();
10581058
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
10591059
__syncthreads();
1060-
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
1060+
Load(temp_storage.load).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
10611061

10621062
# pragma unroll 4
10631063
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
10641064
{
10651065
g_vals[j] = gnorm_scale*((float)g_vals[j]);
1066-
if(weight_decay > 0.0f)
1066+
if(weight_decay > 0.0f && return_updates == nullptr)
10671067
g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
10681068
}
10691069

@@ -1080,26 +1080,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
10801080
else
10811081
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
10821082

1083-
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
1083+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + update_scale*(-lr*(s1_vals[j]));
10841084
break;
10851085
case LION:
1086-
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
1086+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
10871087
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
10881088
break;
10891089
case RMSPROP:
10901090
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
1091-
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
1091+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
10921092
break;
10931093
case ADAGRAD:
10941094
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
1095-
p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
1095+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
10961096
break;
10971097
}
10981098
}
10991099
}
11001100

11011101
__syncthreads();
1102-
Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
1102+
Store(temp_storage.store).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
11031103
__syncthreads();
11041104
StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
11051105
}
@@ -1447,7 +1447,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
14471447
template<typename T, int OPTIMIZER>
14481448
__global__ void
14491449
__launch_bounds__(1024, 1)
1450-
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
1450+
kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
14511451
const float *unorm, const float max_unorm, const float param_norm,
14521452
const float beta1, const float beta2,
14531453
const float eps, const int step, const float lr,
@@ -1503,7 +1503,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15031503
__syncthreads();
15041504
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
15051505
__syncthreads();
1506-
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
1506+
LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
15071507

15081508
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
15091509

@@ -1513,7 +1513,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15131513
g_val = float(g_vals[j]);
15141514
g_val *= gnorm_scale;
15151515

1516-
if(weight_decay > 0.0f) {
1516+
if(weight_decay > 0.0f && return_updates == nullptr) {
15171517
switch(OPTIMIZER) {
15181518
case ADAGRAD:
15191519
case MOMENTUM:
@@ -1536,15 +1536,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15361536
else
15371537
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
15381538

1539-
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
1539+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + (-lr*update_scale*(s1_vals[j]));
15401540
break;
15411541
case LION:
1542-
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
1542+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
15431543
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
15441544
break;
15451545
case RMSPROP:
15461546
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
1547-
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
1547+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
15481548
break;
15491549
}
15501550

@@ -1560,7 +1560,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
15601560
}
15611561
}
15621562

1563-
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
1563+
StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
15641564
__syncthreads();
15651565
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
15661566
__syncthreads();
@@ -1893,7 +1893,7 @@ kOptimizerStatic8bit2StateBlockwise(
18931893
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
18941894
__launch_bounds__(256, 3)
18951895
__global__ void
1896-
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
1896+
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
18971897
const float beta1, const float beta2,
18981898
const float eps, const int step, const float lr,
18991899
float* __restrict__ const quantiles1,
@@ -1957,7 +1957,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
19571957
__syncthreads();
19581958
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
19591959
__syncthreads();
1960-
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
1960+
LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items, (T)0.0f);
19611961

19621962
new_local_abs_max1 = -FLT_MAX;
19631963

@@ -1969,7 +1969,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
19691969
g_val *= gnorm_scale;
19701970
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
19711971
{
1972-
if(weight_decay > 0.0f) {
1972+
if(weight_decay > 0.0f && return_updates == nullptr) {
19731973
switch(OPTIMIZER) {
19741974
case MOMENTUM:
19751975
case ADAGRAD:
@@ -2032,18 +2032,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
20322032
switch(OPTIMIZER)
20332033
{
20342034
case MOMENTUM:
2035-
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
2035+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(s1_vals[j]);
20362036
break;
20372037
case LION:
2038-
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
2038+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - ((float)g_vals[j]);
20392039
break;
20402040
case RMSPROP:
20412041
g_val = g_vals[j];
2042-
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
2042+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
20432043
break;
20442044
case ADAGRAD:
20452045
g_val = g_vals[j];
2046-
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
2046+
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
20472047
break;
20482048
}
20492049
}
@@ -3782,7 +3782,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
37823782
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
37833783

37843784
#define MAKE_Optimizer32bit1State(oname, gtype) \
3785-
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
3785+
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, gtype* return_updates, float* state1, float *unorm, const float max_unorm, const float param_norm, \
37863786
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
37873787

37883788
MAKE_Optimizer32bit1State(MOMENTUM, half)
@@ -3847,7 +3847,7 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
38473847
MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
38483848

38493849
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
3850-
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
3850+
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, gtype* return_updates, unsigned char* state1, \
38513851
const float *unorm, const float max_unorm, const float param_norm, \
38523852
const float beta1, \
38533853
const float beta2, \
@@ -4002,7 +4002,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)
40024002

40034003
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
40044004
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
4005-
gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
4005+
gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char* state1, \
40064006
const float beta1, const float beta2, \
40074007
const float eps, const int step, const float lr, \
40084008
float* __restrict__ const quantiles1, \

csrc/kernels.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
3838
const int step, const float lr, const float gnorm_scale, const int n);
3939

4040
template<typename T, int OPTIMIZER>
41-
__global__ void kOptimizer32bit1State(T* g, T* p,
41+
__global__ void kOptimizer32bit1State(T* g, T* p, T* return_updates,
4242
float* state1, float *unorm, const float max_unorm, const float param_norm,
4343
const float beta1, const float beta2, const float eps, const float weight_decay,
4444
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
@@ -57,7 +57,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
5757

5858
template<typename T, int OPTIMIZER>
5959
__global__ void
60-
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
60+
kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
6161
const float *unorm, const float max_unorm, const float param_norm,
6262
const float beta1, const float beta2,
6363
const float eps, const int step, const float lr,
@@ -96,7 +96,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
9696
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
9797

9898
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
99-
T* p, T* __restrict__ const g, unsigned char* state1,
99+
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
100100
const float beta1, const float beta2,
101101
const float eps, const int step, const float lr,
102102
float* __restrict__ const quantiles1,

csrc/ops.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,12 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_up
122122
CUDA_CHECK_RETURN(cudaPeekAtLastError());
123123
}
124124

125-
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
125+
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
126126
CUDA_CHECK_RETURN(cudaPeekAtLastError());
127127
break;
128128
case LION:
129129
// in lion, the momentum update after the parameter update
130-
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
130+
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
131131
CUDA_CHECK_RETURN(cudaPeekAtLastError());
132132

133133
if(max_unorm > 0.0f)
@@ -172,13 +172,13 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
172172
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
173173
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
174174
CUDA_CHECK_RETURN(cudaPeekAtLastError());
175-
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
175+
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
176176
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
177177
CUDA_CHECK_RETURN(cudaPeekAtLastError());
178178
break;
179179
case LION:
180180
// in lion, the momentum update happens after the parameter update
181-
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
181+
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
182182
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
183183
CUDA_CHECK_RETURN(cudaPeekAtLastError());
184184

@@ -239,7 +239,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
239239
case LION:
240240
num_blocks = n/BLOCKSIZE_1STATE;
241241
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
242-
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
242+
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, return_updates, state1, beta1, beta2, eps, step, lr,
243243
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
244244
CUDA_CHECK_RETURN(cudaPeekAtLastError());
245245
break;

0 commit comments

Comments
 (0)