
Commit 280cfd9

am17an authored and qnixsynapse committed
CUDA: add softmax broadcast (ggml-org#14475)
* CUDA: add softmax broadcast
* Pass by const ref
* Review: Use blockDims for indexing, remove designated initializers
* Add TODO for noncontiguous input/output
1 parent d3984d9 commit 280cfd9
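
The review note about using blockDims for indexing concerns how a broadcast mask is addressed inside the kernel. The block below is only a minimal sketch of that idea, not the ggml kernel: the struct bcast_params, the kernel name, and the assumption that the mask shares dims 0 and 1 with the input are all invented for illustration. One block handles one row of the input, and the mask row index wraps with a modulo wherever the mask's extent in dim 2 or 3 is 1, so a [ncols, ne01, 1, 1] mask is reused for every head/batch of a [ncols, ne01, ne02, ne03] input.

// Minimal broadcast-softmax sketch (illustrative only, not the ggml kernel).
// Launch with grid = dim3(ne01, ne02, ne03) and any block size >= 1.
#include <cstdint>
#include <math.h>
#include <cuda_runtime.h>

struct bcast_params {
    int ncols;              // row length (dim 0)
    int ne01, ne02, ne03;   // input extents in dims 1..3
    int nm2, nm3;           // mask extents in dims 2..3 (1 => broadcast)
};

static __global__ void soft_max_bcast_sketch(
        const float * x, const float * mask, float * dst, const bcast_params p) {
    const int i1 = blockIdx.x;
    const int i2 = blockIdx.y;
    const int i3 = blockIdx.z;

    // row index of the input/output, and the (possibly wrapped) mask row
    const int64_t row  = ((int64_t) i3*p.ne02 + i2)*p.ne01 + i1;
    const int64_t mrow = ((int64_t) (i3 % p.nm3)*p.nm2 + (i2 % p.nm2))*p.ne01 + i1;

    const float * xr = x + row*p.ncols;
    const float * mr = mask ? mask + mrow*p.ncols : nullptr;
    float       * dr = dst + row*p.ncols;

    // single-threaded per-row softmax, just to keep the sketch short;
    // the real kernel reduces across the whole block
    if (threadIdx.x != 0) {
        return;
    }
    float maxv = -INFINITY;
    for (int c = 0; c < p.ncols; ++c) {
        maxv = fmaxf(maxv, xr[c] + (mr ? mr[c] : 0.0f));
    }
    float sum = 0.0f;
    for (int c = 0; c < p.ncols; ++c) {
        const float e = expf(xr[c] + (mr ? mr[c] : 0.0f) - maxv);
        dr[c] = e;
        sum  += e;
    }
    for (int c = 0; c < p.ncols; ++c) {
        dr[c] /= sum;
    }
}

The real soft_max_f32 kernel additionally applies the scale and max_bias (ALiBi) parameters and parallelizes the row reduction; the sketch only shows the broadcast indexing.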

2 files changed: 41 additions & 45 deletions

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 7 deletions

@@ -3329,13 +3329,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+            return true;
         case GGML_OP_SOFT_MAX_BACK: {
             float max_bias = 0.0f;
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
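
For context, ggml_backend_cuda_device_supports_op decides which graph nodes the CUDA backend claims, so dropping the two TODO checks is what lets batched and broadcast soft-max nodes reach the kernel at all. A hedged sketch of such a node using the public ggml API follows; the helper name and the shapes are invented for illustration.

#include "ggml.h"

// Build a soft-max node whose mask broadcasts over dims 2 and 3.
// Shapes are made up: logits [256, 64, 8, 4], mask [256, 64, 1, 1].
static struct ggml_tensor * build_broadcast_softmax(struct ggml_context * ctx) {
    struct ggml_tensor * logits = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 64, 8, 4);
    struct ggml_tensor * mask   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 64, 1, 1);

    // scale = 1.0f, max_bias = 0.0f (no ALiBi); before this change the check
    // above returned false for such a node because logits->ne[3] != 1
    return ggml_soft_max_ext(ctx, logits, mask, 1.0f, 0.0f);
}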

ggml/src/ggml-cuda/softmax.cu

Lines changed: 40 additions & 38 deletions

@@ -2,7 +2,6 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
-#include <utility>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -182,37 +181,6 @@ static __global__ void soft_max_back_f32(
     }
 }
 
-template<int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
-        const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
-{
-    const int id = ggml_cuda_get_device();
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-
-    auto launch_kernel = [=](auto I) -> bool {
-        constexpr int ncols = decltype(I)::value;
-        constexpr int block = (ncols > 1024 ? 1024 : ncols);
-
-        if (p.ncols == ncols) {
-            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
-            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, mask, dst, p);
-            return true;
-        }
-        return false;
-    };
-
-    // unary fold over launch_kernel
-    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
-        return;
-    }
-
-    //default case
-    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
-    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
-}
-
-
 template<typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -225,12 +193,46 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
 
-    const int id = ggml_cuda_get_device();
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-
-
-    if (nbytes_shared <= smpbo) {
-        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
+                    (x, mask, dst, params);
+                break;
+        }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
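
The FIXME on the new branch refers to the default shared-memory-per-block limit (smpb): rows whose buffer does not fit take the nbytes_shared_low fallback, even though newer GPUs (the FIXME mentions Ampere) expose a larger opt-in limit, the smpbo value that the removed launch_soft_max_kernels helper raised via CUDA_SET_SHARED_MEMORY_LIMIT. A rough sketch of that opt-in with the plain CUDA runtime API, using a hypothetical kernel and helper:

#include <cstdint>
#include <cuda_runtime.h>

// Placeholder kernel that stages each row through dynamic shared memory.
static __global__ void rows_kernel(const float * x, float * dst, int ncols) {
    extern __shared__ float buf[];
    for (int c = threadIdx.x; c < ncols; c += blockDim.x) {
        buf[c] = x[(int64_t) blockIdx.x*ncols + c];
    }
    __syncthreads();
    for (int c = threadIdx.x; c < ncols; c += blockDim.x) {
        dst[(int64_t) blockIdx.x*ncols + c] = buf[c];
    }
}

// Returns true if nbytes_shared of dynamic shared memory can be used,
// raising the per-kernel limit to the opt-in maximum when necessary.
static bool ensure_shared_mem(int device, size_t nbytes_shared) {
    int smpb = 0, smpbo = 0;
    cudaDeviceGetAttribute(&smpb,  cudaDevAttrMaxSharedMemoryPerBlock,      device);
    cudaDeviceGetAttribute(&smpbo, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    if (nbytes_shared <= (size_t) smpb) {
        return true;   // fits the default limit, no opt-in needed
    }
    if (nbytes_shared <= (size_t) smpbo) {
        // opt in to the larger limit for this specific kernel
        return cudaFuncSetAttribute(rows_kernel,
                   cudaFuncAttributeMaxDynamicSharedMemorySize,
                   (int) nbytes_shared) == cudaSuccess;
    }
    return false;      // exceeds even the opt-in limit: use a fallback path
}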
