Skip to content

Commit 61189fc

Browse files
Rebase on main - resolve conflicts
1 parent 59883ac commit 61189fc

File tree

6 files changed

+25
-110
lines changed

6 files changed

+25
-110
lines changed

bitsandbytes/optim/optimizer.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,7 @@
55
from collections import abc as container_abcs, defaultdict
66
from copy import deepcopy
77
from itertools import chain
8-
<<<<<<< HEAD
98
from typing import Any, Dict, Optional
10-
=======
11-
from typing import Optional
12-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
139

1410
import torch
1511

csrc/kernels.cu

Lines changed: 14 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -1617,20 +1617,10 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
16171617
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
16181618
__launch_bounds__(256, 3)
16191619
__global__ void
1620-
<<<<<<< HEAD
1621-
kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, T* return_updates,
1622-
unsigned char* state1, unsigned char* state2,
1623-
const float beta1, const float beta2,
1624-
const float eps, const int step, const float lr,
1625-
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
1626-
float* absmax1, float* absmax2,
1627-
float weight_decay,
1628-
const float gnorm_scale, const bool skip_zeros, const int n)
1629-
{
1630-
=======
16311620
kOptimizerStatic8bit2StateBlockwise(
1632-
T* p,
1621+
T* __restrict__ p,
16331622
T* __restrict__ const g,
1623+
T* __restrict__ return_updates,
16341624
unsigned char* state1,
16351625
unsigned char* state2,
16361626
const float beta1,
@@ -1649,7 +1639,6 @@ kOptimizerStatic8bit2StateBlockwise(
16491639
const bool skip_zeros,
16501640
const int n
16511641
) {
1652-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
16531642

16541643
//const int n_full = n + (n%BLOCK_SIZE);
16551644
const int n_full = gridDim.x * BLOCK_SIZE;
@@ -1834,28 +1823,22 @@ kOptimizerStatic8bit2StateBlockwise(
18341823
//if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
18351824
if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
18361825
{
1837-
<<<<<<< HEAD
1838-
if (return_updates == nullptr) {
1839-
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
1840-
if(weight_decay > 0.0f)
1841-
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
1842-
} else {
1843-
p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))));
1844-
}
1845-
=======
18461826
if (OPTIMIZER == ADEMAMIX) {
18471827
p_vals[j] = T((float)p_vals[j] - lr * (
18481828
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
18491829
(sqrtf(s2_vals[j]) / correction2) + eps
18501830
)
18511831
));
18521832
} else {
1853-
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
1833+
if (return_updates == nullptr) {
1834+
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
1835+
} else {
1836+
p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))));
1837+
}
18541838
}
18551839

1856-
if(weight_decay > 0.0f)
1840+
if (return_updates == nullptr && weight_decay > 0.0f)
18571841
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
1858-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
18591842
}
18601843
}
18611844

@@ -3813,7 +3796,7 @@ MAKE_Optimizer32bit1State(ADAGRAD, float)
38133796
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)
38143797

38153798
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
3816-
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
3799+
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
38173800
float* state1, float* state2, float *unorm, \
38183801
const float beta1, const float beta2, const float eps, const float weight_decay, \
38193802
const int step, const float lr, const float gnorm_scale, const int n); \
@@ -3825,28 +3808,19 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
38253808
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
38263809
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)
38273810

3828-
<<<<<<< HEAD
38293811
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3830-
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3831-
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3832-
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3833-
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3834-
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3835-
=======
3836-
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
38373812
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3838-
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3813+
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
38393814
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3840-
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3815+
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
38413816
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3842-
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3817+
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
38433818
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3844-
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3819+
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, half* return_updates,float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
38453820
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
3846-
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
3821+
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
38473822
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
38483823

3849-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
38503824

38513825
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
38523826
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
@@ -4006,14 +3980,9 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General
40063980
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
40073981

40083982
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
4009-
<<<<<<< HEAD
40103983
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, gtype* return_updates, \
40113984
unsigned char* state1, unsigned char* state2, \
4012-
const float beta1, const float beta2, \
4013-
=======
4014-
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
40153985
const float beta1, const float beta2, const float beta3, const float alpha, \
4016-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
40173986
const float eps, const int step, const float lr, \
40183987
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
40193988
float* absmax1, float* absmax2, \

csrc/kernels.cuh

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -89,13 +89,8 @@ kOptimizerStatic8bit2State(T* p, T* const g, T* return_updates, unsigned char* s
8989
float weight_decay, const float gnorm_scale, const int n);
9090

9191
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
92-
<<<<<<< HEAD
93-
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, unsigned char* state2,
94-
const float beta1, const float beta2, const float eps, const int step, const float lr,
95-
=======
96-
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
92+
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, unsigned char* state2,
9793
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr,
98-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
9994
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
10095
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
10196

csrc/ops.cu

Lines changed: 3 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -109,11 +109,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_up
109109
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
110110
CUDA_CHECK_RETURN(cudaPeekAtLastError());
111111
}
112-
<<<<<<< HEAD
113-
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
114-
=======
115-
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
116-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
112+
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
117113
CUDA_CHECK_RETURN(cudaPeekAtLastError());
118114
break;
119115
case MOMENTUM:
@@ -200,15 +196,10 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
200196
#define BLOCKSIZE_1STATE 256
201197
#define NUM_1STATE 1
202198

203-
<<<<<<< HEAD
204-
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates,
205-
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
206-
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
207-
{
208-
=======
209199
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
210200
T* p,
211201
T* g,
202+
T* return_updates,
212203
unsigned char* state1,
213204
unsigned char* state2,
214205
float beta1,
@@ -227,7 +218,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
227218
bool skip_zeros,
228219
int n
229220
) {
230-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
231221

232222
int num_blocks = 0;
233223
switch(OPTIMIZER)
@@ -236,16 +226,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
236226
case ADEMAMIX:
237227
num_blocks = n/BLOCKSIZE_2STATE;
238228
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
239-
<<<<<<< HEAD
240-
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, return_updates, state1, state2, beta1, beta2, eps, step, lr,
241-
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
242-
=======
243229
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(
244-
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
230+
p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
245231
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale,
246232
skip_zeros, n
247233
);
248-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
249234
CUDA_CHECK_RETURN(cudaPeekAtLastError());
250235
break;
251236
case MOMENTUM:
@@ -872,13 +857,8 @@ MAKE_optimizerStatic8bit(ADAGRAD, float)
872857

873858

874859
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
875-
<<<<<<< HEAD
876860
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, gtype* return_updates, \
877-
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
878-
=======
879-
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
880861
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
881-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
882862
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
883863

884864
MAKE_optimizerStatic8bitBlockwise(half, ADAM);

csrc/ops.cuh

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -163,13 +163,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
163163
float weight_decay,
164164
const float gnorm_scale, int n);
165165

166-
<<<<<<< HEAD
167166
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates,
168-
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
169-
=======
170-
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
171-
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,
172-
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
167+
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
168+
float eps, int step, float lr,
173169
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
174170
bool skip_zeros, int n);
175171

0 commit comments

Comments
 (0)