Skip to content

Commit 64be8c5

Browse files
committed
Add CUDA non-contiguous Unary ops implementation
1 parent c31e606 commit 64be8c5

File tree

3 files changed

+52
-7
lines changed

3 files changed

+52
-7
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3116,7 +3116,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31163116
case GGML_UNARY_OP_GELU_QUICK:
31173117
case GGML_UNARY_OP_TANH:
31183118
case GGML_UNARY_OP_EXP:
3119-
return ggml_is_contiguous(op->src[0]);
3119+
return true;
31203120
default:
31213121
return false;
31223122
}

ggml/src/ggml-cuda/unary.cu

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,51 @@ static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
9595
}
9696

9797
template <float (*op)(float), typename T>
98-
static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
98+
static __global__ void unary_op_kernel_noncont(
99+
const void * x, void * dst,
100+
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
101+
const int64_t nb0_x, const int64_t nb1_x, const int64_t nb2_x, const int64_t nb3_x,
102+
const int64_t nb0_d, const int64_t nb1_d, const int64_t nb2_d, const int64_t nb3_d,
103+
const int64_t k) {
104+
105+
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
106+
107+
if (i >= k) {
108+
return;
109+
}
110+
111+
const int64_t i3 = i / (ne2 * ne1 * ne0);
112+
const int64_t i2 = (i / (ne1 * ne0)) % ne2;
113+
const int64_t i1 = (i / ne0) % ne1;
114+
const int64_t i0 = i % ne0;
115+
116+
const int64_t offset_x = i0*nb0_x + i1*nb1_x + i2*nb2_x + i3*nb3_x;
117+
const int64_t offset_d = i0*nb0_d + i1*nb1_d + i2*nb2_d + i3*nb3_d;
118+
119+
const T * px = (const T *)((const char *)x + offset_x);
120+
T * pd = (T *)((char *)dst + offset_d);
121+
122+
*pd = (T)op((float)*px);
123+
}
124+
125+
template <float (*op)(float), typename T>
126+
static void unary_cuda(const T * x, T * dst, const int k,
127+
const ggml_tensor * src, const ggml_tensor * dst_tensor,
128+
cudaStream_t stream) {
99129
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
100-
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
130+
131+
if (ggml_is_contiguous(src) && ggml_is_contiguous(dst_tensor)) {
132+
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
133+
} else {
134+
unary_op_kernel_noncont<op, T><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(
135+
(const void *)x, (void *)dst,
136+
src->ne[0], src->ne[1], src->ne[2], src->ne[3],
137+
src->nb[0], src->nb[1], src->nb[2], src->nb[3],
138+
dst_tensor->nb[0], dst_tensor->nb[1],
139+
dst_tensor->nb[2], dst_tensor->nb[3],
140+
k
141+
);
142+
}
101143
}
102144

103145
template <float (*op)(float)>
@@ -107,16 +149,16 @@ void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
107149
void * dst_d = dst->data;
108150
cudaStream_t stream = ctx.stream();
109151

110-
GGML_ASSERT(ggml_is_contiguous(src0));
111-
112152
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
113153
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
114154
GGML_ASSERT(src0->type == dst->type);
115155

116156
if (src0->type == GGML_TYPE_F16) {
117-
unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
157+
unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0),
158+
src0, dst, stream);
118159
} else {
119-
unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
160+
unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0),
161+
src0, dst, stream);
120162
}
121163
}
122164

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5642,6 +5642,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
56425642

56435643
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
56445644

5645+
test_cases.emplace_back(new test_unary((ggml_unary_op) GGML_UNARY_OP_ABS, GGML_TYPE_F32, {256, 256, 3, 1}, 0));
5646+
test_cases.emplace_back(new test_unary((ggml_unary_op) GGML_UNARY_OP_ABS, GGML_TYPE_F32, {256, 256, 3, 1}, 1));
5647+
56455648
return test_cases;
56465649
}
56475650

0 commit comments

Comments
 (0)