From 12f5f7c85807573b1e0324a25b59f2f00cb7674c Mon Sep 17 00:00:00 2001 From: YavorGIvanov Date: Sun, 13 Jul 2025 00:45:34 +0000 Subject: [PATCH 1/2] Add Pad Reflect 1D CUDA support --- ggml/src/ggml-cuda/ggml-cuda.cu | 5 ++ ggml/src/ggml-cuda/pad_reflect_1d.cu | 82 +++++++++++++++++++++++++++ ggml/src/ggml-cuda/pad_reflect_1d.cuh | 5 ++ 3 files changed, 92 insertions(+) create mode 100644 ggml/src/ggml-cuda/pad_reflect_1d.cu create mode 100644 ggml/src/ggml-cuda/pad_reflect_1d.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 88b17dd682c95..65637b2f7b0e9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -44,6 +44,7 @@ #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" #include "ggml-cuda/set-rows.cuh" +#include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml.h" #include @@ -2346,6 +2347,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_PAD: ggml_cuda_op_pad(ctx, dst); break; + case GGML_OP_PAD_REFLECT_1D: + ggml_cuda_op_pad_reflect_1d(ctx, dst); + break; case GGML_OP_ARANGE: ggml_cuda_op_arange(ctx, dst); break; @@ -3386,6 +3390,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: case GGML_OP_PAD: + case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cu b/ggml/src/ggml-cuda/pad_reflect_1d.cu new file mode 100644 index 0000000000000..3e014d7d60d42 --- /dev/null +++ b/ggml/src/ggml-cuda/pad_reflect_1d.cu @@ -0,0 +1,82 @@ +#include "pad_reflect_1d.cuh" + +static __global__ void pad_reflect_1d_kernel_f32( + const void * src0, + void * dst, + const int64_t ne0, + const int64_t ne00, + const int64_t ne01, + const int64_t ne02, + const int64_t ne03, + const int64_t nb00, + const int64_t nb01, + const int64_t nb02, + const int64_t nb03, + const int64_t nb0, + const int64_t nb1, + const int64_t nb2, + const int64_t nb3, + const int p0, + const int p1) { + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { + return; + } + + const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01; + char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1; + + for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + float value; + + if (i0 < p0) { + // Left padding - reflect + value = *(const float *)(src0_ptr + (p0 - i0) * nb00); + } else if (i0 < ne0 - p1) { + // Middle - copy + value = *(const float *)(src0_ptr + (i0 - p0) * nb00); + } else { + // Right padding - reflect + int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1; + value = *(const float *)(src0_ptr + src_idx * nb00); + } + + *(float *)(dst_ptr + i0 * nb0) = value; + } +} + +void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int32_t * opts = (const int32_t *) dst->op_params; + const int p0 = opts[0]; + const int p1 = opts[1]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + + GGML_ASSERT(ne0 == ne00 + p0 + p1); + + const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1); + const dim3 grid_dims(ne01, ne02, ne03); + + pad_reflect_1d_kernel_f32<<>>( + src0->data, dst->data, + ne0, ne00, ne01, ne02, ne03, + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + p0, p1 + ); +} diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cuh b/ggml/src/ggml-cuda/pad_reflect_1d.cuh new file mode 100644 index 0000000000000..15f2ed1737b1a --- /dev/null +++ b/ggml/src/ggml-cuda/pad_reflect_1d.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256 + +void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From f908dce12391e32e9682abcf2fd4f06aedc84176 Mon Sep 17 00:00:00 2001 From: Yavor Ivanov Date: Wed, 16 Jul 2025 21:22:54 -0700 Subject: [PATCH 2/2] Update ggml/src/ggml-cuda/pad_reflect_1d.cu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/pad_reflect_1d.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cu b/ggml/src/ggml-cuda/pad_reflect_1d.cu index 3e014d7d60d42..4ed34aec3d331 100644 --- a/ggml/src/ggml-cuda/pad_reflect_1d.cu +++ b/ggml/src/ggml-cuda/pad_reflect_1d.cu @@ -1,8 +1,8 @@ #include "pad_reflect_1d.cuh" static __global__ void pad_reflect_1d_kernel_f32( - const void * src0, - void * dst, + const void * __restrict__ src0, + void * __restrict__ dst, const int64_t ne0, const int64_t ne00, const int64_t ne01,