Skip to content

Commit f9a31ee

Browse files
authored
CUDA: set_rows + cpy.cu refactor (#14712)
1 parent 8f974bc commit f9a31ee

File tree

4 files changed

+396
-244
lines changed

4 files changed

+396
-244
lines changed

ggml/src/ggml-cuda/cpy-utils.cuh

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
#pragma once
2+
3+
#include "ggml-common.h"
4+
5+
static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
6+
*dst = *src;
7+
}
8+
9+
static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
10+
*dst = __float2half(*src);
11+
}
12+
13+
static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
14+
*dst = *src;
15+
}
16+
17+
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
18+
*dst = *src;
19+
}
20+
21+
static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
22+
*dst = *src;
23+
}
24+
25+
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
26+
if (x <= val[0]) return 0;
27+
if (x >= val[n-1]) return n-1;
28+
int ml = 0, mu = n-1;
29+
while (mu-ml > 1) {
30+
int mav = (ml+mu)/2;
31+
if (x < val[mav]) mu = mav; else ml = mav;
32+
}
33+
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
34+
}
35+
36+
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
37+
float amax = 0.0f;
38+
float vmax = 0.0f;
39+
40+
for (int j = 0; j < QK4_0; ++j) {
41+
const float v = x[j];
42+
if (amax < fabsf(v)) {
43+
amax = fabsf(v);
44+
vmax = v;
45+
}
46+
}
47+
48+
const float d = vmax / -8;
49+
const float id = d ? 1.0f/d : 0.0f;
50+
51+
y->d = d;
52+
53+
for (int j = 0; j < QK4_0/2; ++j) {
54+
const float x0 = x[0 + j]*id;
55+
const float x1 = x[QK4_0/2 + j]*id;
56+
57+
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
58+
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
59+
60+
y->qs[j] = xi0;
61+
y->qs[j] |= xi1 << 4;
62+
}
63+
}
64+
65+
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
66+
float vmin = FLT_MAX;
67+
float vmax = -FLT_MAX;
68+
69+
for (int j = 0; j < QK4_1; ++j) {
70+
const float v = x[j];
71+
if (v < vmin) vmin = v;
72+
if (v > vmax) vmax = v;
73+
}
74+
75+
const float d = (vmax - vmin) / ((1 << 4) - 1);
76+
const float id = d ? 1.0f/d : 0.0f;
77+
78+
y->dm.x = d;
79+
y->dm.y = vmin;
80+
81+
for (int j = 0; j < QK4_1/2; ++j) {
82+
const float x0 = (x[0 + j] - vmin)*id;
83+
const float x1 = (x[QK4_1/2 + j] - vmin)*id;
84+
85+
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
86+
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
87+
88+
y->qs[j] = xi0;
89+
y->qs[j] |= xi1 << 4;
90+
}
91+
}
92+
93+
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
94+
float amax = 0.0f;
95+
float vmax = 0.0f;
96+
97+
for (int j = 0; j < QK5_0; ++j) {
98+
const float v = x[j];
99+
if (amax < fabsf(v)) {
100+
amax = fabsf(v);
101+
vmax = v;
102+
}
103+
}
104+
105+
const float d = vmax / -16;
106+
const float id = d ? 1.0f/d : 0.0f;
107+
108+
y->d = d;
109+
110+
uint32_t qh = 0;
111+
for (int j = 0; j < QK5_0/2; ++j) {
112+
const float x0 = x[0 + j]*id;
113+
const float x1 = x[QK5_0/2 + j]*id;
114+
115+
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
116+
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
117+
118+
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
119+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
120+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
121+
}
122+
memcpy(y->qh, &qh, sizeof(qh));
123+
}
124+
125+
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
126+
float min = x[0];
127+
float max = x[0];
128+
129+
for (int j = 1; j < QK5_1; ++j) {
130+
const float v = x[j];
131+
min = v < min ? v : min;
132+
max = v > max ? v : max;
133+
}
134+
135+
const float d = (max - min) / 31;
136+
const float id = d ? 1.0f/d : 0.0f;
137+
138+
y->dm.x = d;
139+
y->dm.y = min;
140+
141+
uint32_t qh = 0;
142+
for (int j = 0; j < QK5_1/2; ++j) {
143+
const float x0 = (x[0 + j] - min)*id;
144+
const float x1 = (x[QK5_1/2 + j] - min)*id;
145+
146+
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
147+
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
148+
149+
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
150+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
151+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
152+
}
153+
memcpy(y->qh, &qh, sizeof(qh));
154+
}
155+
156+
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
157+
float amax = 0.0f; // absolute max
158+
159+
for (int j = 0; j < QK8_0; j++) {
160+
const float v = x[j];
161+
amax = fmaxf(amax, fabsf(v));
162+
}
163+
164+
const float d = amax / ((1 << 7) - 1);
165+
const float id = d ? 1.0f/d : 0.0f;
166+
167+
y->d = d;
168+
169+
for (int j = 0; j < QK8_0; ++j) {
170+
const float x0 = x[j]*id;
171+
y->qs[j] = roundf(x0);
172+
}
173+
}
174+
175+
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
176+
float amax = 0.0f;
177+
float vmax = 0.0f;
178+
179+
for (int j = 0; j < QK4_NL; ++j) {
180+
const float v = x[j];
181+
if (amax < fabsf(v)) {
182+
amax = fabsf(v);
183+
vmax = v;
184+
}
185+
}
186+
187+
float d = vmax / kvalues_iq4nl[0];
188+
const float id = d ? 1.0f/d : 0.0f;
189+
190+
float sumqx = 0, sumq2 = 0;
191+
for (int j = 0; j < QK4_NL/2; ++j) {
192+
const float x0 = x[0 + j]*id;
193+
const float x1 = x[QK4_NL/2 + j]*id;
194+
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
195+
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
196+
y->qs[j] = xi0 | (xi1 << 4);
197+
const float v0 = kvalues_iq4nl[xi0];
198+
const float v1 = kvalues_iq4nl[xi1];
199+
const float w0 = x[0 + j]*x[0 + j];
200+
const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
201+
sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
202+
sumq2 += w0*v0*v0 + w1*v1*v1;
203+
}
204+
205+
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
206+
}
207+
208+
// Wrapper functions for cpy.cu compatibility
209+
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
210+
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
211+
}
212+
213+
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
214+
quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
215+
}
216+
217+
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
218+
quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
219+
}
220+
221+
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
222+
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
223+
}
224+
225+
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
226+
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
227+
}
228+
229+
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
230+
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
231+
}
232+
233+
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
234+
convert_f32_f32((const float *)cxi, (float *)cdsti);
235+
}
236+
237+
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
238+
convert_f32_f16((const float *)cxi, (half *)cdsti);
239+
}
240+
241+
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
242+
convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
243+
}
244+
245+
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
246+
convert_f16_f16((const half *)cxi, (half *)cdsti);
247+
}
248+
249+
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
250+
convert_f16_f32((const half *)cxi, (float *)cdsti);
251+
}

0 commit comments

Comments
 (0)