Commit b85c660

am17an authored and Minh141120 committed
CUDA: add conv_2d_dw (ggml-org#14265)
* CUDA: add conv_2d_dw
* better naming
* simplify using template
* Review: fix operation ordering in ggml-cuda, use __forceinline__, use more const
1 parent ff19f0c commit b85c660
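
The commit title and review notes refer to a depthwise 2D convolution kernel, templated over the memory layout and using __forceinline__ device helpers with const-qualified parameters. The kernel itself sits behind ggml-cuda/conv2d-dw.cuh and is dispatched through ggml_cuda_op_conv2d_dw, neither of which appears in this diff beyond the include and the switch case. As a rough, hypothetical sketch only (illustrative names, NCHW layout, single batch, no dilation, one thread per output element), a depthwise conv2d kernel of this shape can look like:

// Hypothetical sketch, not the code added by this commit.
__device__ __forceinline__ float conv2d_dw_pixel(
        const float * input, const float * kern,
        const int c, const int in_w, const int in_h,
        const int k_w, const int k_h, const int stride, const int pad,
        const int out_x, const int out_y) {
    float sum = 0.0f;
    for (int ky = 0; ky < k_h; ++ky) {
        for (int kx = 0; kx < k_w; ++kx) {
            const int ix = out_x*stride + kx - pad;
            const int iy = out_y*stride + ky - pad;
            if (ix >= 0 && ix < in_w && iy >= 0 && iy < in_h) {
                // depthwise: output channel c only ever reads input channel c
                sum += input[(c*in_h + iy)*in_w + ix] * kern[(c*k_h + ky)*k_w + kx];
            }
        }
    }
    return sum;
}

__global__ void conv2d_dw_sketch(
        const float * input, const float * kern, float * output,
        const int channels, const int in_w, const int in_h,
        const int out_w, const int out_h,
        const int k_w, const int k_h, const int stride, const int pad) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= channels*out_w*out_h) {
        return;
    }
    const int out_x = i % out_w;
    const int out_y = (i / out_w) % out_h;
    const int c     = i / (out_w*out_h);
    output[i] = conv2d_dw_pixel(input, kern, c, in_w, in_h, k_w, k_h, stride, pad, out_x, out_y);
}

Launched as conv2d_dw_sketch<<<(channels*out_w*out_h + 255)/256, 256>>>(...), this covers every output element with one thread; the real kernel presumably also handles additional layouts and per-dimension stride, padding and dilation parameters.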

File tree: 1 file changed (+12, -46 lines)


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 12 additions & 46 deletions
@@ -12,7 +12,6 @@
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/conv2d-dw.cuh"
-#include "ggml-cuda/conv2d-transpose.cuh"
 #include "ggml-cuda/convert.cuh"
 #include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
@@ -1842,27 +1841,16 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;

-    const cuda_t * src0_ptr = nullptr;
-    const cuda_t * src1_ptr = nullptr;
-
-    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
-
-    // Handle src0
-    src0_ptr = (const cuda_t *) src0->data;
-
-    // Handle src1 - convert if necessary
-    if (src1->type == src0_type) {
-        src1_ptr = (const cuda_t *) src1->data;
-    } else {
-        // Convert src1 to target type using traits conversion functions
+    // convert src1 to fp16
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
         const int64_t ne_src1 = ggml_nelements(src1);
-        src1_alloc.alloc(ne_src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_cuda != nullptr);

-        const auto convert_func = traits::get_nc_converter(src1->type);
-        GGML_ASSERT(convert_func != nullptr);
-        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
-        src1_ptr = src1_alloc.get();
+        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+
+        src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
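
Reading the added lines of this hunk as one block: when src1 is not already F16 it is converted into a packed fp16 scratch buffer, and the strides used by the batched GEMM are recomputed because that copy is contiguous. Restated with comments (src1_f16 and src1_f16_alloc are declared earlier in the function, outside this hunk):

// src1 must be fp16 for this cuBLAS path; convert it when necessary
if (src1->type != GGML_TYPE_F16) {
    const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
    GGML_ASSERT(to_fp16_cuda != nullptr);

    const int64_t ne_src1 = ggml_nelements(src1);
    src1_f16_alloc.alloc(ne_src1);                 // pool-backed fp16 scratch buffer

    // non-contiguous conversion: reads src1 with its original strides s11/s12/s13
    to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
    src1_f16 = src1_f16_alloc.get();

    // the converted copy is densely packed, so the strides collapse to element counts
    s11 = ne10;       // elements per row
    s12 = ne11*s11;   // elements per matrix
    s13 = ne12*s12;   // elements per batch
}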
@@ -1959,29 +1947,11 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
+#endif

-    // Convert output back to F32 if needed
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
-        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
-    }
-}
-
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
-            break;
-        case GGML_TYPE_BF16:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
-            break;
-        case GGML_TYPE_F16:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
-            break;
-        default:
-            GGML_ABORT("Unsupported type");
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
 }
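
The tail of the function is reduced to a single conversion step: when the op runs with GGML_PREC_DEFAULT and the batched GEMM produced fp16 output (cu_data_type == CUDA_R_16F), that temporary buffer is widened back to the fp32 destination. With comments (dst_f16, dst_ddf and ne_dst are defined earlier in the function, outside this hunk):

// the GEMM wrote half-precision results into dst_f16; the graph expects fp32 in dst_ddf
if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
    to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);   // half -> float, ne_dst elements
}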

@@ -2411,9 +2381,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CONV_2D_DW:
             ggml_cuda_op_conv2d_dw(ctx, dst);
             break;
-        case GGML_OP_CONV_TRANSPOSE_2D:
-            ggml_cuda_conv_2d_transpose_p0(ctx, dst);
-            break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cuda_op_conv_transpose_1d(ctx,dst);
             break;
@@ -3340,7 +3307,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         }
         case GGML_OP_IM2COL:
         case GGML_OP_CONV_2D_DW:
-        case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:

0 commit comments