
Commit 6c55697

am17an authored and qnixsynapse committed
CUDA: add dynamic shared mem to softmax, refactor general usage (ggml-org#14497)
1 parent 8274c5a commit 6c55697
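
The "dynamic shared mem" in the title refers to opting the soft_max kernels into more dynamic shared memory than the default per-block limit: the diff below gates the specialized kernels on the device's smpbo value (presumably the shared-memory-per-block opt-in maximum) and invokes a CUDA_SET_SHARED_MEMORY_LIMIT macro before each launch. A minimal sketch of the underlying CUDA mechanism such a macro would plausibly wrap is shown here; the kernel and helper names are illustrative only, not part of the commit:

#include <cuda_runtime.h>

// Illustrative kernel: uses dynamically sized shared memory declared extern.
__global__ void copy_through_smem(const float * x, float * dst, int n) {
    extern __shared__ float buf[];                // size fixed at launch time
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        buf[threadIdx.x] = x[i];
        dst[i] = buf[threadIdx.x];
    }
}

// Raise the kernel's dynamic shared memory ceiling before launching with a
// large allocation. By default a kernel may request at most 48 KiB of dynamic
// shared memory; Volta and newer GPUs expose a larger opt-in maximum, queryable
// via cudaDevAttrMaxSharedMemoryPerBlockOptin.
static void launch_with_large_smem(const float * x, float * dst, int n,
                                   size_t nbytes_shared, cudaStream_t stream) {
    cudaFuncSetAttribute(copy_through_smem,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         (int) nbytes_shared);
    const int block = 256;
    const int grid  = (n + block - 1)/block;
    copy_through_smem<<<grid, block, nbytes_shared, stream>>>(x, dst, n);
}

Without such an opt-in, launches requesting more than 48 KiB of dynamic shared memory fail, which is why the old code below only used the specialized kernels under the smaller smpb limit.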

File tree: 2 files changed, +39 / -40 lines

ggml/src/ggml-cuda/softmax.cu

Lines changed: 38 additions & 40 deletions
@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "softmax.cuh"
 #include <cstdint>
+#include <utility>
 
 template <typename T>
 static __device__ __forceinline__ float t2f32(T val) {
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
     }
 }
 
+template<int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+    const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    auto launch_kernel = [=](auto I) -> bool {
+        constexpr int ncols = decltype(I)::value;
+        constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+        if (p.ncols == ncols) {
+            CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+            soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, mask, dst, p);
+            return true;
+        }
+        return false;
+    };
+
+    // unary fold over launch_kernel
+    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+        return;
+    }
+
+    //default case
+    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+}
+
+
 template<typename T>
 static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
 
-    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, params);
-                break;
-        }
+    const int id = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+
+    if (nbytes_shared <= smpbo) {
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
         soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
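
The launch_soft_max_kernels helper added above replaces the old switch over ncols_x with a C++17 fold expression over a compile-time list of column counts: each candidate is tested in turn, the first match launches its specialized instantiation, and anything else falls through to the generic soft_max_f32<true, 0, 0> kernel. Below is a minimal, host-only sketch of that dispatch pattern, assuming nothing from ggml; dispatch_ncols, run_specialized, and the printf bodies are illustrative stand-ins for the real kernel launches.

#include <cstdio>
#include <type_traits>

// Stand-in for launching soft_max_f32<true, ncols, block> for one compile-time ncols.
template <int ncols>
static void run_specialized(int n) {
    std::printf("specialized path for ncols = %d (runtime n = %d)\n", ncols, n);
}

template <int... Ns>
static void dispatch_ncols(int n) {
    // Generic lambda: the integral_constant argument carries ncols as a
    // compile-time value, mirroring launch_kernel in the diff above.
    auto try_one = [&](auto I) -> bool {
        constexpr int ncols = decltype(I)::value;
        if (n == ncols) {
            run_specialized<ncols>(n);
            return true;
        }
        return false;
    };

    // Unary fold over ||: candidates are tried left to right and the fold
    // short-circuits as soon as one try_one call returns true.
    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }

    // Default case, analogous to soft_max_f32<true, 0, 0> in the diff.
    std::printf("generic fallback for n = %d\n", n);
}

int main() {
    dispatch_ncols<32, 64, 128, 256, 512, 1024, 2048, 4096>(256);   // matches the 256 specialization
    dispatch_ncols<32, 64, 128, 256, 512, 1024, 2048, 4096>(12888); // no match: generic fallback
    return 0;
}

Compared with the old switch, adding another specialized column count only means appending it to the template argument list at the call site.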

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
@@ -4932,6 +4932,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

0 commit comments

Comments
 (0)