diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh new file mode 100644 index 0000000000000..e7a0bd2f1a077 --- /dev/null +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -0,0 +1,251 @@ +#pragma once + +#include "ggml-common.h" + +static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) { + *dst = *src; +} + +static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) { + *dst = __float2half(*src); +} + +static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) { + *dst = *src; +} + +static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) { + *dst = *src; +} + +static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) { + *dst = *src; +} + +static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_0; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = d ? 
1.0f/d : 0.0f; + + y->d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK4_0/2 + j]*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); + + y->qs[j] = xi0; + y->qs[j] |= xi1 << 4; + } +} + +static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) { + float vmin = FLT_MAX; + float vmax = -FLT_MAX; + + for (int j = 0; j < QK4_1; ++j) { + const float v = x[j]; + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y->dm.x = d; + y->dm.y = vmin; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (x[0 + j] - vmin)*id; + const float x1 = (x[QK4_1/2 + j] - vmin)*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); + + y->qs[j] = xi0; + y->qs[j] |= xi1 << 4; + } +} + +static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK5_0; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK5_0/2 + j]*id; + + const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f)); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + memcpy(y->qh, &qh, sizeof(qh)); +} + +static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) { + float min = x[0]; + float max = x[0]; + + for (int j = 1; j < QK5_1; ++j) { + const float v = x[j]; + min = v < min ? 
v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + y->dm.x = d; + y->dm.y = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (x[0 + j] - min)*id; + const float x1 = (x[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + memcpy(y->qh, &qh, sizeof(qh)); +} + +static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[j]; + amax = fmaxf(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[j]*id; + y->qs[j] = roundf(x0); + } +} + +static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_NL; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK4_NL/2 + j]*id; + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1); + y->qs[j] = xi0 | (xi1 << 4); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = x[0 + j]*x[0 + j]; + const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j]; + sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + y->d = sumq2 > 0 ? 
sumqx/sumq2 : d; +} + +// Wrapper functions for cpy.cu compatibility +static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti); +} + +static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { + quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { + quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti); +} + +static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { + quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti); +} + +static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { + convert_f32_f32((const float *)cxi, (float *)cdsti); +} + +static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { + convert_f32_f16((const float *)cxi, (half *)cdsti); +} + +static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) { + convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti); +} + +static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { + convert_f16_f16((const half *)cxi, (half *)cdsti); +} + +static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { + convert_f16_f32((const half *)cxi, (float *)cdsti); +} diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 2c55d2149b2d3..e7d0da087056b 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,46 +1,12 @@ #include "cpy.cuh" #include "dequantize.cuh" +#include "cpy-utils.cuh" #ifdef GGML_USE_MUSA #include "ggml-musa/mudnn.cuh" #endif // GGML_USE_MUSA typedef void 
(*cpy_kernel_t)(const char * cx, char * cdst); -static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - half * dsti = (half *) cdsti; - - *dsti = __float2half(*xi); -} - -static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { - const half * xi = (const half *) cxi; - half * dsti = (half *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { - const half * xi = (const half *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - template <cpy_kernel_t cpy_1> static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -71,29 +37,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in cpy_1(cx + x_offset, cdst + dst_offset); } -static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q8_0 * dsti = (block_q8_0 *) cdsti; - - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = xi[j]; - amax = fmaxf(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ?
1.0f/d : 0.0f; - - dsti->d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = xi[j]*id; - - dsti->qs[j] = roundf(x0); - } -} - static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -106,139 +49,6 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { } } -static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q4_0 * dsti = (block_q4_0 *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK4_0; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } - - const float d = vmax / -8; - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK4_0/2 + j]*id; - - const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); - - dsti->qs[j] = xi0; - dsti->qs[j] |= xi1 << 4; - } -} - -static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q4_1 * dsti = (block_q4_1 *) cdsti; - - float vmin = FLT_MAX; - float vmax = -FLT_MAX; - - for (int j = 0; j < QK4_1; ++j) { - const float v = xi[j]; - - if (v < vmin) vmin = v; - if (v > vmax) vmax = v; - } - - const float d = (vmax - vmin) / ((1 << 4) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - dsti->dm.x = d; - dsti->dm.y = vmin; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (xi[0 + j] - vmin)*id; - const float x1 = (xi[QK4_1/2 + j] - vmin)*id; - - const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); - - dsti->qs[j] = xi0; - dsti->qs[j] |= xi1 << 4; - } -} - -static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q5_0 * dsti = (block_q5_0 *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK5_0; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } - - const float d = vmax / -16; - const float id = d ? 1.0f/d : 0.0f; - - dsti->d = d; - - uint32_t qh = 0; - for (int j = 0; j < QK5_0/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK5_0/2 + j]*id; - - const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f)); - - dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); - } - memcpy(dsti->qh, &qh, sizeof(qh)); -} - -static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_q5_1 * dsti = (block_q5_1 *) cdsti; - - float min = xi[0]; - float max = xi[0]; - - for (int j = 1; j < QK5_1; ++j) { - const float v = xi[j]; - min = v < min ? v : min; - max = v > max ? v : max; - } - - const float d = (max - min) / 31; - const float id = d ? 
1.0f/d : 0.0f; - - dsti->dm.x = d; - dsti->dm.y = min; - - uint32_t qh = 0; - for (int j = 0; j < QK5_1/2; ++j) { - const float x0 = (xi[0 + j] - min)*id; - const float x1 = (xi[QK5_1/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); - } - memcpy(dsti->qh, &qh, sizeof(qh)); -} - template static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { float * cdstf = (float *)(cdsti); @@ -252,53 +62,6 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } } -static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { - if (x <= val[0]) return 0; - if (x >= val[n-1]) return n-1; - int ml = 0, mu = n-1; - while (mu-ml > 1) { - int mav = (ml+mu)/2; - if (x < val[mav]) mu = mav; else ml = mav; - } - return x - val[mu-1] < val[mu] - x ? mu-1 : mu; -} - -static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - block_iq4_nl * dsti = (block_iq4_nl *) cdsti; - - float amax = 0.0f; - float vmax = 0.0f; - - for (int j = 0; j < QK4_NL; ++j) { - const float v = xi[j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - vmax = v; - } - } - - float d = vmax / kvalues_iq4nl[0]; - const float id = d ? 
1.0f/d : 0.0f; - - float sumqx = 0, sumq2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - const float x0 = xi[0 + j]*id; - const float x1 = xi[QK4_NL/2 + j]*id; - const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0); - const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1); - dsti->qs[j] = xi0 | (xi1 << 4); - const float v0 = kvalues_iq4nl[xi0]; - const float v1 = kvalues_iq4nl[xi1]; - const float w0 = xi[0 + j]*xi[0 + j]; - const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j]; - sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j]; - sumq2 += w0*v0*v0 + w1*v1*v1; - } - - dsti->d = sumq2 > 0 ? sumqx/sumq2 : d; -} - template static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8015b0d4e8d92..e450e0ae2e55c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3226,8 +3226,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_SET_ROWS: { -#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") - return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) && + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || + op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || + op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64; } break; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 58cee9244018f..560604d095f3b 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -1,4 +1,5 @@ #include "set-rows.cuh" +#include
"cpy-utils.cuh" typedef void (*set_rows_kernel_t)(const char * src, char * dst); @@ -10,17 +11,93 @@ __device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { template<> __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) { - *dst_h = __float2half(*src_f); + convert_f32_f16(src_f, dst_h); } template<> __device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) { - *dst_b = *src_f; + convert_f32_bf16(src_f, dst_b); } template<> __device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) { - *dst_f = *src_f; + convert_f32_f32(src_f, dst_f); +} + +// Generic quantized set_rows kernel template +template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)> +static __global__ void k_set_rows_quant( + const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s10, const int64_t s11, const int64_t s12, + const int64_t s1, const int64_t s2, const int64_t s3) { + + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk; + + if (i >= ne_total) { + return; + } + + const int64_t i_base = i * qk; + const int64_t i03 = i_base / (ne00 * ne01 * ne02); + const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + + const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type); + +
const float * src_block = src0_row + i00; + block_type * dst_block = dst_row_ptr + i00 / qk; + + quantize_func(src_block, dst_block); +} + +// Template dispatch function for quantized set_rows +template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)> +static void set_rows_cuda_quant( + const float * src0_d, const int64_t * src1_d, block_type * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + + GGML_ASSERT(ne00 % qk == 0); + const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); + const dim3 grid_size(num_blocks); + + const int64_t s01 = nb01/sizeof(float); + const int64_t s02 = nb02/sizeof(float); + const int64_t s03 = nb03/sizeof(float); + const int64_t s10 = nb10/sizeof(int64_t); + const int64_t s11 = nb11/sizeof(int64_t); + const int64_t s12 = nb12/sizeof(int64_t); + const int64_t s1 = nb1; + const int64_t s2 = nb2; + const int64_t s3 = nb3; + + if (ne_total > 0) { + k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + s01, s02, s03, + s10, s11, s12, + s1, s2, s3); + } } template<typename src_t, typename dst_t> @@ -145,7 +222,67 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { nb1, nb2, nb3, stream ); + } else if (dst->type == GGML_TYPE_Q4_0) { + set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>( + src0_d, src1_d, (block_q4_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q4_1) { + set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>( + src0_d, src1_d, (block_q4_1*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01,
nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q5_0) { + set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>( + src0_d, src1_d, (block_q5_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q5_1) { + set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>( + src0_d, src1_d, (block_q5_1*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q8_0) { + set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>( + src0_d, src1_d, (block_q8_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_IQ4_NL) { + set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>( + src0_d, src1_d, (block_iq4_nl*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); } else { - GGML_ABORT("unsupported type"); + GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); } }