-
Notifications
You must be signed in to change notification settings - Fork 12.4k
CUDA: add softmax broadcast #14475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CUDA: add softmax broadcast #14475
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -13,6 +13,28 @@ __device__ float __forceinline__ t2f32<half>(half val) { | |||||||||||||||||
return __half2float(val); | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
struct soft_max_params { | ||||||||||||||||||
|
||||||||||||||||||
int64_t nheads; | ||||||||||||||||||
uint32_t n_head_log2; | ||||||||||||||||||
int64_t ncols; | ||||||||||||||||||
int64_t nrows_x; | ||||||||||||||||||
int64_t nrows_y; | ||||||||||||||||||
int64_t ne00; | ||||||||||||||||||
int64_t ne01; | ||||||||||||||||||
int64_t ne02; | ||||||||||||||||||
int64_t nb11; | ||||||||||||||||||
int64_t nb12; | ||||||||||||||||||
int64_t nb13; | ||||||||||||||||||
|
||||||||||||||||||
int64_t ne12; | ||||||||||||||||||
int64_t ne13; | ||||||||||||||||||
float scale; | ||||||||||||||||||
float max_bias; | ||||||||||||||||||
float m0; | ||||||||||||||||||
float m1; | ||||||||||||||||||
}; | ||||||||||||||||||
|
||||||||||||||||||
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. | ||||||||||||||||||
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. | ||||||||||||||||||
#ifdef __clang__ | ||||||||||||||||||
|
@@ -21,24 +43,30 @@ __device__ float __forceinline__ t2f32<half>(half val) { | |||||||||||||||||
#endif // __clang__ | ||||||||||||||||||
template <bool use_shared, int ncols_template, int block_size_template, typename T> | ||||||||||||||||||
static __global__ void soft_max_f32( | ||||||||||||||||||
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, | ||||||||||||||||||
const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { | ||||||||||||||||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template; | ||||||||||||||||||
const float * x, const T * mask, float * dst, const soft_max_params p) { | ||||||||||||||||||
const int ncols = ncols_template == 0 ? p.ncols : ncols_template; | ||||||||||||||||||
|
||||||||||||||||||
const int tid = threadIdx.x; | ||||||||||||||||||
const int rowx = blockIdx.x; | ||||||||||||||||||
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension | ||||||||||||||||||
|
||||||||||||||||||
const int64_t i03 = rowx / (p.ne01 * p.ne02); | ||||||||||||||||||
const int64_t i02 = (rowx % (p.ne01 * p.ne02)) / p.ne01; | ||||||||||||||||||
const int64_t i01 = rowx % p.ne01; | ||||||||||||||||||
|
||||||||||||||||||
const int64_t i11 = i01; | ||||||||||||||||||
const int64_t i12 = i02 % p.ne12; | ||||||||||||||||||
const int64_t i13 = i03 % p.ne13; | ||||||||||||||||||
|
||||||||||||||||||
x += int64_t(rowx)*ncols; | ||||||||||||||||||
mask += int64_t(rowy)*ncols * (mask != nullptr); | ||||||||||||||||||
mask += (int64_t(i11)*p.nb11 + int64_t(i12)*p.nb12 + int64_t(i13)*p.nb13) / sizeof(T) * (mask != nullptr); | ||||||||||||||||||
dst += int64_t(rowx)*ncols; | ||||||||||||||||||
|
||||||||||||||||||
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; | ||||||||||||||||||
|
||||||||||||||||||
const int warp_id = threadIdx.x / WARP_SIZE; | ||||||||||||||||||
const int lane_id = threadIdx.x % WARP_SIZE; | ||||||||||||||||||
|
||||||||||||||||||
const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1); | ||||||||||||||||||
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); | ||||||||||||||||||
|
||||||||||||||||||
extern __shared__ float data_soft_max_f32[]; | ||||||||||||||||||
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication | ||||||||||||||||||
|
@@ -55,7 +83,7 @@ static __global__ void soft_max_f32( | |||||||||||||||||
break; | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); | ||||||||||||||||||
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); | ||||||||||||||||||
|
||||||||||||||||||
vals[col] = val; | ||||||||||||||||||
max_val = max(max_val, val); | ||||||||||||||||||
|
@@ -151,63 +179,60 @@ static __global__ void soft_max_back_f32( | |||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
template<typename T> | ||||||||||||||||||
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { | ||||||||||||||||||
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, soft_max_params params, cudaStream_t stream) { | ||||||||||||||||||
am17an marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||
int nth = WARP_SIZE; | ||||||||||||||||||
const int64_t ncols_x = params.ncols; | ||||||||||||||||||
|
||||||||||||||||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; | ||||||||||||||||||
const dim3 block_dims(nth, 1, 1); | ||||||||||||||||||
const dim3 block_nums(nrows_x, 1, 1); | ||||||||||||||||||
const dim3 block_nums(params.nrows_x, 1, 1); | ||||||||||||||||||
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); | ||||||||||||||||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); | ||||||||||||||||||
|
||||||||||||||||||
const uint32_t n_head = nrows_x/nrows_y; | ||||||||||||||||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); | ||||||||||||||||||
|
||||||||||||||||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); | ||||||||||||||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); | ||||||||||||||||||
|
||||||||||||||||||
// FIXME: this limit could be raised by ~2-4x on Ampere or newer | ||||||||||||||||||
if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { | ||||||||||||||||||
Comment on lines
196
to
197
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you were asking me for things to do, consider also tackling this. It's not an immediate problem but if we ever want to do sampling on the GPU it will be. See llama.cpp/ggml/src/ggml-cuda/mmq.cuh Lines 3019 to 3026 in 343b6e9
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, I checked with NVIDIA engineers: it is not possible to raise the shared memory limit universally, you have to do it manually for each function. Yes, it's stupid. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me know if I understand this correctly - prior to calling any which uses shared mem, we should somehow call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This only matters for like 3 kernels that sometimes want to use more than 48 kB of shared memory. The desired behavior is to raise the shared memory limit once per function and device. Raising the limit multiple times does not result in an error but it makes the code slower (thus the static variable). A macro could conceivably be used to solve this issue. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll take a look at this |
||||||||||||||||||
switch (ncols_x) { | ||||||||||||||||||
case 32: | ||||||||||||||||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
case 64: | ||||||||||||||||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
case 128: | ||||||||||||||||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
case 256: | ||||||||||||||||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
case 512: | ||||||||||||||||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
case 1024: | ||||||||||||||||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
case 2048: | ||||||||||||||||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
case 4096: | ||||||||||||||||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
default: | ||||||||||||||||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>> | ||||||||||||||||||
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
(x, mask, dst, params); | ||||||||||||||||||
break; | ||||||||||||||||||
} | ||||||||||||||||||
} else { | ||||||||||||||||||
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); | ||||||||||||||||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); | ||||||||||||||||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params); | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -235,10 +260,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | |||||||||||||||||
|
||||||||||||||||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional | ||||||||||||||||||
|
||||||||||||||||||
const int64_t ne00 = src0->ne[0]; | ||||||||||||||||||
const int64_t nrows_x = ggml_nrows(src0); | ||||||||||||||||||
const int64_t nrows_y = src0->ne[1]; | ||||||||||||||||||
|
||||||||||||||||||
const int64_t ne00 = src0->ne[0]; | ||||||||||||||||||
|
||||||||||||||||||
float scale = 1.0f; | ||||||||||||||||||
float max_bias = 0.0f; | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -247,10 +273,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | |||||||||||||||||
|
||||||||||||||||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); | ||||||||||||||||||
|
||||||||||||||||||
const int64_t nb11 = src1 ? src1->nb[1] : 1; | ||||||||||||||||||
const int64_t nb12 = src1 ? src1->nb[2] : 1; | ||||||||||||||||||
const int64_t nb13 = src1 ? src1->nb[3] : 1; | ||||||||||||||||||
|
||||||||||||||||||
const int64_t ne12 = src1 ? src1->ne[2] : 1; | ||||||||||||||||||
const int64_t ne13 = src1 ? src1->ne[3] : 1; | ||||||||||||||||||
|
||||||||||||||||||
const uint32_t n_head = src0->ne[2]; | ||||||||||||||||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); | ||||||||||||||||||
|
||||||||||||||||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); | ||||||||||||||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
soft_max_params params = { | ||||||||||||||||||
.nheads = src0->ne[2], | ||||||||||||||||||
.n_head_log2 = n_head_log2, | ||||||||||||||||||
.ncols = ne00, | ||||||||||||||||||
.nrows_x = nrows_x, | ||||||||||||||||||
.nrows_y = nrows_y, | ||||||||||||||||||
.ne00 = src0->ne[0], | ||||||||||||||||||
.ne01 = src0->ne[1], | ||||||||||||||||||
.ne02 = src0->ne[2], | ||||||||||||||||||
.nb11 = nb11, | ||||||||||||||||||
.nb12 = nb12, | ||||||||||||||||||
.nb13 = nb13, | ||||||||||||||||||
.ne12 = ne12, | ||||||||||||||||||
.ne13 = ne13, | ||||||||||||||||||
.scale = scale, | ||||||||||||||||||
.max_bias = max_bias, | ||||||||||||||||||
.m0 = m0, | ||||||||||||||||||
.m1 = m1 | ||||||||||||||||||
}; | ||||||||||||||||||
am17an marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||
|
||||||||||||||||||
if (use_f16) { | ||||||||||||||||||
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); | ||||||||||||||||||
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream); | ||||||||||||||||||
} else { | ||||||||||||||||||
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); | ||||||||||||||||||
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream); | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.