Skip to content

Commit 5c75115

Browse files
committed
CUDA: set_rows + cpy.cu refactor (ggml-org#14712)
Updated for IKL IQ4_NL and Q6_0. Original Author : Aman Gupta
1 parent f30e462 commit 5c75115

File tree

4 files changed

+475
-259
lines changed

4 files changed

+475
-259
lines changed

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
#pragma once
2+
3+
#include "ggml-common.h"
4+
5+
static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
6+
*dst = *src;
7+
}
8+
9+
static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
10+
*dst = __float2half(*src);
11+
}
12+
13+
static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
14+
*dst = *src;
15+
}
16+
17+
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
18+
*dst = *src;
19+
}
20+
21+
static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
22+
*dst = *src;
23+
}
24+
25+
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
26+
if (x <= val[0]) return 0;
27+
if (x >= val[n-1]) return n-1;
28+
int ml = 0, mu = n-1;
29+
while (mu-ml > 1) {
30+
int mav = (ml+mu)/2;
31+
if (x < val[mav]) mu = mav; else ml = mav;
32+
}
33+
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
34+
}
35+
36+
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
37+
float amax = 0.0f;
38+
float vmax = 0.0f;
39+
40+
for (int j = 0; j < QK4_0; ++j) {
41+
const float v = x[j];
42+
if (amax < fabsf(v)) {
43+
amax = fabsf(v);
44+
vmax = v;
45+
}
46+
}
47+
48+
const float d = vmax / -8;
49+
const float id = d ? 1.0f/d : 0.0f;
50+
51+
y->d = d;
52+
53+
for (int j = 0; j < QK4_0/2; ++j) {
54+
const float x0 = x[0 + j]*id;
55+
const float x1 = x[QK4_0/2 + j]*id;
56+
57+
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
58+
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
59+
60+
y->qs[j] = xi0;
61+
y->qs[j] |= xi1 << 4;
62+
}
63+
}
64+
65+
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
66+
float vmin = FLT_MAX;
67+
float vmax = -FLT_MAX;
68+
69+
for (int j = 0; j < QK4_1; ++j) {
70+
const float v = x[j];
71+
if (v < vmin) vmin = v;
72+
if (v > vmax) vmax = v;
73+
}
74+
75+
const float d = (vmax - vmin) / ((1 << 4) - 1);
76+
const float id = d ? 1.0f/d : 0.0f;
77+
78+
y->dm.x = d;
79+
y->dm.y = vmin;
80+
81+
for (int j = 0; j < QK4_1/2; ++j) {
82+
const float x0 = (x[0 + j] - vmin)*id;
83+
const float x1 = (x[QK4_1/2 + j] - vmin)*id;
84+
85+
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
86+
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
87+
88+
y->qs[j] = xi0;
89+
y->qs[j] |= xi1 << 4;
90+
}
91+
}
92+
93+
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
94+
float amax = 0.0f;
95+
float vmax = 0.0f;
96+
97+
for (int j = 0; j < QK5_0; ++j) {
98+
const float v = x[j];
99+
if (amax < fabsf(v)) {
100+
amax = fabsf(v);
101+
vmax = v;
102+
}
103+
}
104+
105+
const float d = vmax / -16;
106+
const float id = d ? 1.0f/d : 0.0f;
107+
108+
y->d = d;
109+
110+
uint32_t qh = 0;
111+
for (int j = 0; j < QK5_0/2; ++j) {
112+
const float x0 = x[0 + j]*id;
113+
const float x1 = x[QK5_0/2 + j]*id;
114+
115+
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
116+
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
117+
118+
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
119+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
120+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
121+
}
122+
memcpy(y->qh, &qh, sizeof(qh));
123+
}
124+
125+
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
126+
float min = x[0];
127+
float max = x[0];
128+
129+
for (int j = 1; j < QK5_1; ++j) {
130+
const float v = x[j];
131+
min = v < min ? v : min;
132+
max = v > max ? v : max;
133+
}
134+
135+
const float d = (max - min) / 31;
136+
const float id = d ? 1.0f/d : 0.0f;
137+
138+
y->dm.x = d;
139+
y->dm.y = min;
140+
141+
uint32_t qh = 0;
142+
for (int j = 0; j < QK5_1/2; ++j) {
143+
const float x0 = (x[0 + j] - min)*id;
144+
const float x1 = (x[QK5_1/2 + j] - min)*id;
145+
146+
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
147+
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
148+
149+
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
150+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
151+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
152+
}
153+
memcpy(y->qh, &qh, sizeof(qh));
154+
}
155+
156+
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
157+
float amax = 0.0f; // absolute max
158+
159+
for (int j = 0; j < QK8_0; j++) {
160+
const float v = x[j];
161+
amax = fmaxf(amax, fabsf(v));
162+
}
163+
164+
const float d = amax / ((1 << 7) - 1);
165+
const float id = d ? 1.0f/d : 0.0f;
166+
167+
y->d = d;
168+
169+
for (int j = 0; j < QK8_0; ++j) {
170+
const float x0 = x[j]*id;
171+
y->qs[j] = roundf(x0);
172+
}
173+
}
174+
175+
static __device__ const int8_t iq4nl_index[241] = {
176+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
177+
1, 17, 17, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 18, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
178+
3, 3, 3, 3, 3, 3, 19, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 20, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
179+
5, 5, 21, 21, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 22, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 23, 23, 8, 8, 8, 8,
180+
8, 8, 8, 8, 8, 8, 24, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26,
181+
11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13,
182+
13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
183+
14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15
184+
};
185+
static __device__ __forceinline__ int best_index_iq4nl(const int8_t * values, float x) {
186+
int ix = (int)x - values[0];
187+
if (ix < 0 || ix >= 241) return ix < 0 ? 0 : 15;
188+
ix = iq4nl_index[ix];
189+
return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15;
190+
}
191+
192+
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
193+
// const float * xi = (const float *) cxi;
194+
// block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
195+
196+
float amax = 0.0f;
197+
float vmax = 0.0f;
198+
199+
for (int j = 0; j < QK4_NL; ++j) {
200+
const float v = x[j];
201+
if (amax < fabsf(v)) {
202+
amax = fabsf(v);
203+
vmax = v;
204+
}
205+
}
206+
207+
float d = vmax / kvalues_iq4nl[0];
208+
const float id = d ? 1.0f/d : 0.0f;
209+
210+
//dsti->d = d;
211+
212+
float sumqx = 0, sumq2 = 0;
213+
for (int j = 0; j < QK4_NL/2; ++j) {
214+
const float x0 = x[0 + j]*id;
215+
const float x1 = x[QK4_NL/2 + j]*id;
216+
const uint8_t xi0 = best_index_iq4nl(kvalues_iq4nl, x0);
217+
const uint8_t xi1 = best_index_iq4nl(kvalues_iq4nl, x1);
218+
y->qs[j] = xi0 | (xi1 << 4);
219+
const float v0 = kvalues_iq4nl[xi0];
220+
const float v1 = kvalues_iq4nl[xi1];
221+
const float w0 = x[0 + j]*x[0 + j];
222+
const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
223+
sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
224+
sumq2 += w0*v0*v0 + w1*v1*v1;
225+
}
226+
227+
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
228+
}
229+
230+
static __device__ void quantize_f32_q6_0_block(const float * __restrict__ x, block_q6_0 * __restrict__ y) {
231+
// const float * xi = (const float *) cxi;
232+
// block_q6_0 * dsti = (block_q6_0 *) cdsti;
233+
234+
float amax = 0.0f;
235+
float vmax = 0.0f;
236+
237+
for (int j = 0; j < QK6_0; ++j) {
238+
const float v = x[j];
239+
const float av = fabsf(x[j]);
240+
if (amax < av) {
241+
amax = av;
242+
vmax = v;
243+
}
244+
}
245+
246+
const float d = vmax / -32;
247+
const float id = d ? 1.0f/d : 0.0f;
248+
249+
y->d = d;
250+
memset(y->qh, 0, QK6_0/4);
251+
252+
for (int j = 0; j < QK6_0/2; ++j) {
253+
const float x0 = x[0 + j]*id;
254+
const float x1 = x[QK4_0/2 + j]*id;
255+
256+
const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
257+
const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));
258+
259+
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
260+
const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
261+
y->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
262+
}
263+
}
264+
265+
// Wrapper functions for cpy.cu compatibility
266+
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
267+
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
268+
}
269+
270+
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
271+
quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
272+
}
273+
274+
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
275+
quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
276+
}
277+
278+
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
279+
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
280+
}
281+
282+
static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
283+
quantize_f32_q6_0_block((const float *)cxi, (block_q6_0 *)cdsti);
284+
}
285+
286+
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
287+
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
288+
}
289+
290+
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
291+
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
292+
}
293+
294+
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
295+
convert_f32_f32((const float *)cxi, (float *)cdsti);
296+
}
297+
298+
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
299+
convert_f32_f16((const float *)cxi, (half *)cdsti);
300+
}
301+
302+
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
303+
convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
304+
}
305+
306+
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
307+
convert_f16_f16((const half *)cxi, (half *)cdsti);
308+
}
309+
310+
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
311+
convert_f16_f32((const half *)cxi, (float *)cdsti);
312+
}

0 commit comments

Comments
 (0)