
Add CUDA non-contiguous Unary Ops support #14639


Open · wants to merge 1 commit into base: master
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3116,7 +3116,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]);
return true;
default:
return false;
}
54 changes: 48 additions & 6 deletions ggml/src/ggml-cuda/unary.cu
@@ -95,9 +95,51 @@ static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
}

template <float (*op)(float), typename T>
static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
static __global__ void unary_op_kernel_noncont(
const void * x, void * dst,
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
const int64_t nb0_x, const int64_t nb1_x, const int64_t nb2_x, const int64_t nb3_x,
const int64_t nb0_d, const int64_t nb1_d, const int64_t nb2_d, const int64_t nb3_d,
const int64_t k) {

const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}

const int64_t i3 = i / (ne2 * ne1 * ne0);
const int64_t i2 = (i / (ne1 * ne0)) % ne2;
const int64_t i1 = (i / ne0) % ne1;
const int64_t i0 = i % ne0;

const int64_t offset_x = i0*nb0_x + i1*nb1_x + i2*nb2_x + i3*nb3_x;
const int64_t offset_d = i0*nb0_d + i1*nb1_d + i2*nb2_d + i3*nb3_d;

const T * px = (const T *)((const char *)x + offset_x);
T * pd = (T *)((char *)dst + offset_d);

*pd = (T)op((float)*px);
}

template <float (*op)(float), typename T>
static void unary_cuda(const T * x, T * dst, const int k,
const ggml_tensor * src, const ggml_tensor * dst_tensor,
cudaStream_t stream) {
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);

if (ggml_is_contiguous(src) && ggml_is_contiguous(dst_tensor)) {
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
} else {
Comment on lines +131 to +133
Collaborator


Remove the contiguous path; it's no longer needed.

Contributor Author


I kept it because the performance of the simple contiguous kernel is clearly better, and I thought you might prefer to keep the optimal path for that case. I know that in the big scheme of things these unary operations are a very small part of inference time, but I think it is a good idea not to degrade contiguous performance here.

  ABS(type=f32,ne_a=[256,256,3,1],v=0):               532415 runs -     1.88 us/run -     1536 kB/run -  778.95 GB/s
  ABS(type=f32,ne_a=[256,256,3,1],v=1):               311220 runs -     3.24 us/run -     3070 kB/run -  903.14 GB/s

Here is an example perf test using test-backend-ops on an H100 SXM5, where v=0 means contiguous and v=1 means non-contiguous.

Let me know whether you still want the contiguous path removed, or whether you agree I should keep it for now.
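
For context on the v flag: a non-zero v makes the test run the op on a non-contiguous input. Below is a minimal sketch of what such a setup looks like with the ggml API (illustrative only; the tensor names and exact construction are assumptions, not the test-backend-ops code):

// Build an ABS op on a permuted (non-contiguous) f32 tensor -- the case this PR enables on CUDA.
struct ggml_init_params params = {
    /*.mem_size   =*/ 16*1024*1024,
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ true,   // graph construction only, no data allocation in this sketch
};
struct ggml_context * ctx = ggml_init(params);

struct ggml_tensor * a    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 256, 3, 1);
struct ggml_tensor * a_nc = ggml_permute(ctx, a, 1, 0, 2, 3); // swap dims 0/1 -> strided, non-contiguous view
GGML_ASSERT(!ggml_is_contiguous(a_nc));

// Previously ggml_backend_cuda_device_supports_op() rejected this case; with this PR it is
// handled by the new non-contiguous kernel.
struct ggml_tensor * out = ggml_abs(ctx, a_nc);

ggml_free(ctx);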

Collaborator


Sorry for the late reply. If you want to keep the contiguous path, add a template parameter to the non-contiguous kernel that lets it return early.

More generally, if you're concerned about performance, one thing you can try is replacing the byte offsets with logical offsets (calculate these in host code and pass them to the kernel). But I expect the impact on end-to-end performance to be negligible.
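
To make the suggestion concrete, here is a minimal sketch of a single kernel templated on contiguity (illustrative only, not part of this PR's diff; the name unary_op_kernel_any and the exact parameter list are assumptions). The contiguous instantiation takes the flat-index fast path and the compiler drops the stride arithmetic as dead code:

template <float (*op)(float), typename T, bool contiguous>
static __global__ void unary_op_kernel_any(
        const void * x, void * dst,
        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
        const int64_t nb0_x, const int64_t nb1_x, const int64_t nb2_x, const int64_t nb3_x,
        const int64_t nb0_d, const int64_t nb1_d, const int64_t nb2_d, const int64_t nb3_d,
        const int64_t k) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }

    if (contiguous) {
        // fast path: plain flat indexing, no per-element stride math
        ((T *) dst)[i] = (T) op((float) ((const T *) x)[i]);
        return;
    }

    // generic path: decompose the flat index and apply byte strides, as in the PR's kernel
    const int64_t i3 = i / (ne2*ne1*ne0);
    const int64_t i2 = (i / (ne1*ne0)) % ne2;
    const int64_t i1 = (i / ne0) % ne1;
    const int64_t i0 = i % ne0;

    const T * px = (const T *)((const char *) x + i0*nb0_x + i1*nb1_x + i2*nb2_x + i3*nb3_x);
    T * pd = (T *)((char *) dst + i0*nb0_d + i1*nb1_d + i2*nb2_d + i3*nb3_d);
    *pd = (T) op((float) *px);
}

The host side would then branch once on ggml_is_contiguous(src) && ggml_is_contiguous(dst_tensor) and launch the matching instantiation with the same grid, instead of keeping two separate kernels.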

unary_op_kernel_noncont<op, T><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(
(const void *)x, (void *)dst,
src->ne[0], src->ne[1], src->ne[2], src->ne[3],
src->nb[0], src->nb[1], src->nb[2], src->nb[3],
dst_tensor->nb[0], dst_tensor->nb[1],
dst_tensor->nb[2], dst_tensor->nb[3],
k
);
}
}

template <float (*op)(float)>
@@ -107,16 +149,16 @@ void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);

if (src0->type == GGML_TYPE_F16) {
unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0),
src0, dst, stream);
} else {
unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0),
src0, dst, stream);
}
}

3 changes: 3 additions & 0 deletions tests/test-backend-ops.cpp
@@ -5642,6 +5642,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {

test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));

test_cases.emplace_back(new test_unary((ggml_unary_op) GGML_UNARY_OP_ABS, GGML_TYPE_F32, {256, 256, 3, 1}, 0));
test_cases.emplace_back(new test_unary((ggml_unary_op) GGML_UNARY_OP_ABS, GGML_TYPE_F32, {256, 256, 3, 1}, 1));

return test_cases;
}
