-
Notifications
You must be signed in to change notification settings - Fork 12.4k
Add Pad Reflect 1D CUDA support #14659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
YavorGIvanov
wants to merge
2
commits into
ggml-org:master
Choose a base branch
from
YavorGIvanov:feature/pad-reflect-cuda-support
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+92
−0
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#include "pad_reflect_1d.cuh" | ||
|
||
static __global__ void pad_reflect_1d_kernel_f32( | ||
const void * __restrict__ src0, | ||
void * __restrict__ dst, | ||
const int64_t ne0, | ||
const int64_t ne00, | ||
const int64_t ne01, | ||
const int64_t ne02, | ||
const int64_t ne03, | ||
const int64_t nb00, | ||
const int64_t nb01, | ||
const int64_t nb02, | ||
const int64_t nb03, | ||
const int64_t nb0, | ||
const int64_t nb1, | ||
const int64_t nb2, | ||
const int64_t nb3, | ||
const int p0, | ||
const int p1) { | ||
|
||
const int64_t i3 = blockIdx.z; | ||
const int64_t i2 = blockIdx.y; | ||
const int64_t i1 = blockIdx.x; | ||
|
||
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { | ||
return; | ||
} | ||
|
||
const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01; | ||
char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1; | ||
|
||
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { | ||
float value; | ||
|
||
if (i0 < p0) { | ||
// Left padding - reflect | ||
value = *(const float *)(src0_ptr + (p0 - i0) * nb00); | ||
} else if (i0 < ne0 - p1) { | ||
// Middle - copy | ||
value = *(const float *)(src0_ptr + (i0 - p0) * nb00); | ||
} else { | ||
// Right padding - reflect | ||
int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1; | ||
value = *(const float *)(src0_ptr + src_idx * nb00); | ||
} | ||
|
||
*(float *)(dst_ptr + i0 * nb0) = value; | ||
} | ||
} | ||
|
||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
const ggml_tensor * src0 = dst->src[0]; | ||
cudaStream_t stream = ctx.stream(); | ||
|
||
GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
GGML_ASSERT(dst->type == GGML_TYPE_F32); | ||
|
||
const int32_t * opts = (const int32_t *) dst->op_params; | ||
const int p0 = opts[0]; | ||
const int p1 = opts[1]; | ||
|
||
const int64_t ne00 = src0->ne[0]; | ||
const int64_t ne01 = src0->ne[1]; | ||
const int64_t ne02 = src0->ne[2]; | ||
const int64_t ne03 = src0->ne[3]; | ||
|
||
const int64_t ne0 = dst->ne[0]; | ||
|
||
GGML_ASSERT(ne0 == ne00 + p0 + p1); | ||
|
||
const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1); | ||
const dim3 grid_dims(ne01, ne02, ne03); | ||
|
||
pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>( | ||
src0->data, dst->data, | ||
ne0, ne00, ne01, ne02, ne03, | ||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], | ||
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], | ||
p0, p1 | ||
); | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#include "common.cuh" | ||
|
||
#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256 | ||
|
||
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to produce correct results but generally speaking you will get much better performance if each thread just works on a single value instead of looping over
ne0
. However, it would also be fine to just merge it as-is and maybe change this later if it ever becomes relevant for end-to-end performance.