Skip to content

Add Pad Reflect 1D CUDA support #14659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml.h"

#include <algorithm>
Expand Down Expand Up @@ -2346,6 +2347,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_PAD:
ggml_cuda_op_pad(ctx, dst);
break;
case GGML_OP_PAD_REFLECT_1D:
ggml_cuda_op_pad_reflect_1d(ctx, dst);
break;
case GGML_OP_ARANGE:
ggml_cuda_op_arange(ctx, dst);
break;
Expand Down Expand Up @@ -3386,6 +3390,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
Expand Down
82 changes: 82 additions & 0 deletions ggml/src/ggml-cuda/pad_reflect_1d.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "pad_reflect_1d.cuh"

static __global__ void pad_reflect_1d_kernel_f32(
const void * __restrict__ src0,
void * __restrict__ dst,
const int64_t ne0,
const int64_t ne00,
const int64_t ne01,
const int64_t ne02,
const int64_t ne03,
const int64_t nb00,
const int64_t nb01,
const int64_t nb02,
const int64_t nb03,
const int64_t nb0,
const int64_t nb1,
const int64_t nb2,
const int64_t nb3,
const int p0,
const int p1) {

const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;

if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
return;
}

const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;

for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to produce correct results but generally speaking you will get much better performance if each thread just works on a single value instead of looping over ne0. However, it would also be fine to just merge it as-is and maybe change this later if it ever becomes relevant for end-to-end performance.

float value;

if (i0 < p0) {
// Left padding - reflect
value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
} else if (i0 < ne0 - p1) {
// Middle - copy
value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
} else {
// Right padding - reflect
int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
value = *(const float *)(src0_ptr + src_idx * nb00);
}

*(float *)(dst_ptr + i0 * nb0) = value;
}
}

void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

const int32_t * opts = (const int32_t *) dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];

const int64_t ne0 = dst->ne[0];

GGML_ASSERT(ne0 == ne00 + p0 + p1);

const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
const dim3 grid_dims(ne01, ne02, ne03);

pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
src0->data, dst->data,
ne0, ne00, ne01, ne02, ne03,
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
p0, p1
);
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/pad_reflect_1d.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256

void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
Loading