@@ -12,7 +12,6 @@
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/conv2d-dw.cuh"
-#include "ggml-cuda/conv2d-transpose.cuh"
 #include "ggml-cuda/convert.cuh"
 #include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
@@ -1842,27 +1841,16 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     int64_t s12 = nb12 / ts_src1;
     int64_t s13 = nb13 / ts_src1;
 
-    const cuda_t * src0_ptr = nullptr;
-    const cuda_t * src1_ptr = nullptr;
-
-    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
-    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
-
-    // Handle src0
-    src0_ptr = (const cuda_t *) src0->data;
-
-    // Handle src1 - convert if necessary
-    if (src1->type == src0_type) {
-        src1_ptr = (const cuda_t *) src1->data;
-    } else {
-        // Convert src1 to target type using traits conversion functions
+    // convert src1 to fp16
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
         const int64_t ne_src1 = ggml_nelements(src1);
-        src1_alloc.alloc(ne_src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
 
-        const auto convert_func = traits::get_nc_converter(src1->type);
-        GGML_ASSERT(convert_func != nullptr);
-        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
-        src1_ptr = src1_alloc.get();
+        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+
+        src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
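
For context, the `to_fp16_nc_cuda_t` path restored here converts a possibly non-contiguous `src1` into a packed fp16 buffer before the cuBLAS GEMM. Below is a minimal sketch of what such a strided converter does, assuming only the call shape visible in the hunk (`ne10..ne13` extents, `s11..s13` element strides); the kernel name and flat launch scheme are hypothetical, not the ggml implementation:

```cuda
#include <cuda_fp16.h>

// Hypothetical strided f32 -> packed f16 converter mirroring the call
// to_fp16_cuda(src, dst, ne10, ne11, ne12, ne13, s11, s12, s13, stream).
static __global__ void k_f32_to_f16_nc(const float * x, half * y,
        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
        int64_t s1, int64_t s2, int64_t s3) {
    const int64_t i = blockIdx.x*(int64_t)blockDim.x + threadIdx.x;
    if (i >= ne0*ne1*ne2*ne3) {
        return;
    }
    // decompose the flat destination index into 4D coordinates
    const int64_t i0 =  i              % ne0;
    const int64_t i1 = (i / ne0)       % ne1;
    const int64_t i2 = (i / (ne0*ne1)) % ne2;
    const int64_t i3 =  i / (ne0*ne1*ne2);
    // gather from the strided source, write the packed destination
    y[i] = __float2half(x[i0 + i1*s1 + i2*s2 + i3*s3]);
}

static void f32_to_f16_nc_cuda(const float * x, half * y,
        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
        int64_t s1, int64_t s2, int64_t s3, cudaStream_t stream) {
    const int64_t n  = ne0*ne1*ne2*ne3;
    const int     nb = (int) ((n + 255)/256);
    k_f32_to_f16_nc<<<nb, 256, 0, stream>>>(x, y, ne0, ne1, ne2, ne3, s1, s2, s3);
}
```

Packing into contiguous fp16 is what lets the subsequent GEMM use the simple `s11 = ne10; s12 = ne11*s11; s13 = ne12*s12` strides set right after the conversion.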
@@ -1959,29 +1947,11 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
+#endif
 
-    // Convert output back to F32 if needed
-    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
-        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
-    }
-}
-
-static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
-            break;
-        case GGML_TYPE_BF16:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
-            break;
-        case GGML_TYPE_F16:
-            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
-            break;
-        default:
-            GGML_ABORT("Unsupported type");
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
 }
 
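
The restored epilogue converts the GEMM output back to fp32 only when the op ran at default precision with fp16 data (`cu_data_type == CUDA_R_16F`). For reference, here is a self-contained sketch of the underlying cuBLAS call pattern that the `cu_compute_type` / `CUBLAS_GEMM_DEFAULT_TENSOR_OP` arguments above feed into: fp16 operands with fp32 accumulation. The helper name and dimension layout are illustrative, not taken from ggml:

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// One strided-batched GEMM: C[b] = A[b] * B[b], fp16 storage, fp32 compute.
// Column-major, no transposes; A is m x k, B is k x n, C is m x n per batch.
static cublasStatus_t gemm_batched_f16(cublasHandle_t handle,
        const half * A, const half * B, half * C,
        int m, int n, int k, int batch) {
    const float alpha = 1.0f;
    const float beta  = 0.0f; // with CUBLAS_COMPUTE_32F, alpha/beta are float
    return cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, A, CUDA_R_16F, m, (long long) m*k,
                    B, CUDA_R_16F, k, (long long) k*n,
            &beta,  C, CUDA_R_16F, m, (long long) m*n,
            batch,
            CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```

Because the output buffer is fp16 here, a caller that needs fp32 results has to run a separate conversion pass afterwards, which is exactly what the `to_fp32_cuda(dst_f16.get(), dst_ddf, ...)` call in the hunk does.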
@@ -2411,9 +2381,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CONV_2D_DW:
             ggml_cuda_op_conv2d_dw(ctx, dst);
             break;
-        case GGML_OP_CONV_TRANSPOSE_2D:
-            ggml_cuda_conv_2d_transpose_p0(ctx, dst);
-            break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cuda_op_conv_transpose_1d(ctx,dst);
             break;
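
Each case in this switch forwards to a launcher with the uniform signature `(ggml_backend_cuda_context & ctx, ggml_tensor * dst)`; removing the `GGML_OP_CONV_TRANSPOSE_2D` case means that op now falls through to the default path. A rough sketch of the launcher shape behind each case label, using a hypothetical op `foo` standing in for the real launchers in `ggml-cuda/*.cu`:

```cpp
// Hypothetical launcher illustrating the dispatch pattern. Inputs hang off
// dst->src[]; the kernel is launched on the context's CUDA stream.
void ggml_cuda_op_foo(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float       *  dst_d = (float *) dst->data;
    cudaStream_t  stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    // ... launch the op-specific kernel on `stream` with src0_d / dst_d ...
}
```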
@@ -3340,7 +3307,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         }
         case GGML_OP_IM2COL:
         case GGML_OP_CONV_2D_DW:
-        case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
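
With the matching `supports_op` case removed as well, `ggml_backend_cuda_device_supports_op` now reports `GGML_OP_CONV_TRANSPOSE_2D` as unsupported, so the graph scheduler assigns such nodes to another backend (typically the CPU). A hedged sketch of how user code would observe this through the public `ggml-backend.h` API; the wrapper name is made up, and `node` is assumed to be the conv-transpose tensor from a built graph:

```cpp
#include "ggml-backend.h"

// Returns whether `backend` claims it can execute `node`. After this change,
// a CUDA backend returns false for GGML_OP_CONV_TRANSPOSE_2D nodes, and
// ggml_backend_sched places them on a backend that returns true instead.
static bool backend_can_run(ggml_backend_t backend, const struct ggml_tensor * node) {
    return ggml_backend_supports_op(backend, node);
}
```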