#include "ggml.h"
#include "softmax.cuh"
#include <cstdint>
- #include <utility>

template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
@@ -182,37 +181,6 @@ static __global__ void soft_max_back_f32(
    }
}

- template <int... Ns, typename T>
- static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
-         const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
- {
-     const int id = ggml_cuda_get_device();
-     const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-
-     auto launch_kernel = [=](auto I) -> bool {
-         constexpr int ncols = decltype(I)::value;
-         constexpr int block = (ncols > 1024 ? 1024 : ncols);
-
-         if (p.ncols == ncols) {
-             CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
-             soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
-                 (x, mask, dst, p);
-             return true;
-         }
-         return false;
-     };
-
-     // unary fold over launch_kernel
-     if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
-         return;
-     }
-
-     // default case
-     CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
-     soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
- }
-
-

template <typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
    int nth = WARP_SIZE;
@@ -225,12 +193,46 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");


-     const int id = ggml_cuda_get_device();
-     const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-
-
-     if (nbytes_shared <= smpbo) {
-         launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+         switch (ncols_x) {
+             case 32:
+                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+             case 64:
+                 soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+             case 128:
+                 soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+             case 256:
+                 soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+             case 512:
+                 soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+             case 1024:
+                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+             case 2048:
+                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+             case 4096:
+                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+             default:
+                 soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
+                     (x, mask, dst, params);
+                 break;
+         }
    } else {
        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
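For reference, the launch_soft_max_kernels helper removed in the first hunk selects a compile-time column count by trying each candidate in Ns with a unary fold over || and launching the first matching specialization. Below is a minimal host-only sketch of that dispatch pattern, assuming nothing from ggml; dispatch_ncols and run_for are hypothetical stand-ins for the launcher and the kernel launch.

// Illustrative sketch only: dispatch_ncols and run_for are placeholder names.
// It mirrors the unary-fold dispatch of the removed launch_soft_max_kernels:
// try each candidate column count in turn, stop at the first compile-time match,
// otherwise fall back to a generic version.
#include <cstdio>
#include <type_traits>

template <int ncols>
static void run_for(int n) {
    // Stand-in for a kernel launch specialized on a compile-time column count.
    std::printf("ncols == %d specialization selected for n = %d\n", ncols, n);
}

template <int... Ns>
static void dispatch_ncols(int n) {
    auto try_one = [&](auto I) -> bool {
        constexpr int ncols = decltype(I)::value;
        if (n == ncols) {
            run_for<ncols>(n);
            return true;
        }
        return false;
    };
    // Unary fold over ||: short-circuits after the first candidate that matches.
    if ((try_one(std::integral_constant<int, Ns>{}) || ...)) {
        return;
    }
    run_for<0>(n); // generic fallback, analogous to soft_max_f32<true, 0, 0>
}

int main() {
    dispatch_ncols<32, 64, 128, 256, 512, 1024, 2048, 4096>(256);  // hits the 256 candidate
    dispatch_ncols<32, 64, 128, 256, 512, 1024, 2048, 4096>(1000); // falls through to the fallback
}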
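The FIXME in the second hunk concerns the per-kernel dynamic shared memory limit: the switch-based path compares against smpb (the default per-block limit, typically 48 KiB), whereas the removed path compared against smpbo (the opt-in limit) and called CUDA_SET_SHARED_MEMORY_LIMIT, which presumably boils down to the runtime attribute shown below. A hedged sketch with a placeholder kernel, not ggml's actual macro:

// Hypothetical, self-contained sketch: example_kernel and raise_dynamic_smem_limit
// are placeholder names, not part of ggml. Opting a kernel in to a higher dynamic
// shared memory limit is what lets nbytes_shared exceed the default 48 KiB on
// Volta/Ampere and newer.
#include <cuda_runtime.h>
#include <cstddef>

__global__ void example_kernel(const float * x, float * dst) {
    extern __shared__ float buf[]; // dynamic shared memory, sized at launch time
    (void) x; (void) dst; (void) buf;
}

// Raise the kernel's dynamic shared memory cap to `nbytes`; launches that request
// more than the default limit fail unless this attribute has been set.
static cudaError_t raise_dynamic_smem_limit(size_t nbytes) {
    return cudaFuncSetAttribute(example_kernel,
                                cudaFuncAttributeMaxDynamicSharedMemorySize,
                                (int) nbytes);
}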