From 9e36db6eec8405b0ecf09bf058a5e73fdc51fe6e Mon Sep 17 00:00:00 2001 From: YavorGIvanov Date: Sat, 12 Jul 2025 23:57:41 +0000 Subject: [PATCH] Add ELU CUDA support --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++++ ggml/src/ggml-cuda/unary.cu | 7 +++++++ ggml/src/ggml-cuda/unary.cuh | 2 ++ 3 files changed, 13 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 88b17dd682c95..4815a010e852c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2303,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_cuda_op_exp(ctx, dst); break; + case GGML_UNARY_OP_ELU: + ggml_cuda_op_elu(ctx, dst); + break; default: return false; } @@ -3116,6 +3119,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_ELU: return ggml_is_contiguous(op->src[0]); default: return false; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index f9c7b83c40d1b..91c830c4dacc3 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) { return logf(x); } +static __device__ __forceinline__ float op_elu(float x) { + return (x > 0.f) ? x : expm1f(x); +} + template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} /* gated ops */ template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 289d690e5cff6..cb14d16f8f3f5 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);