Skip to content

Commit a8e7e6f

Browse files
committed
add FP8 support to gguf/llama:
E5M2 & E4M3: for use with FP8 distributed models. E4M3_Q & E3M4_Q: for gguf quantized models. The E5M2 and E4M3 types are used like native FP16 / BF16. E4M3_Q and E3M4_Q are defined like Q8_0, with a block size of 256 (like QK_K).
1 parent 6687503 commit a8e7e6f

File tree

11 files changed

+480
-65
lines changed

11 files changed

+480
-65
lines changed

Makefile

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ GGML_NO_OPENMP := 1
138138
DEPRECATE_WARNING := 1
139139
endif
140140

141+
ifdef LLAMA_NO_OPENMP_SIMD
142+
GGML_NO_OPENMP_SIMD := 1
143+
endif
144+
141145
ifdef LLAMA_NO_METAL
142146
GGML_NO_METAL := 1
143147
DEPRECATE_WARNING := 1
@@ -548,6 +552,12 @@ ifndef GGML_NO_OPENMP
548552
endif # GGML_MUSA
549553
endif # GGML_NO_OPENMP
550554

555+
ifndef GGML_NO_OPENMP_SIMD
556+
MK_CPPFLAGS += -DGGML_USE_OPENMP_SIMD
557+
MK_CFLAGS += -fopenmp-simd
558+
MK_CXXFLAGS += -fopenmp-simd
559+
endif # GGML_NO_OPENMP_SIMD
560+
551561
ifdef GGML_OPENBLAS
552562
MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas)
553563
MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas)
@@ -918,7 +928,8 @@ OBJ_GGML += \
918928
ggml/src/ggml-alloc.o \
919929
ggml/src/ggml-backend.o \
920930
ggml/src/ggml-quants.o \
921-
ggml/src/ggml-aarch64.o
931+
ggml/src/ggml-aarch64.o \
932+
ggml/src/ggml-fp8.o
922933

923934
OBJ_LLAMA = \
924935
src/llama.o \
@@ -1074,6 +1085,12 @@ ggml/src/ggml-aarch64.o: \
10741085
ggml/src/ggml-common.h
10751086
$(CC) $(CFLAGS) -c $< -o $@
10761087

1088+
ggml/src/ggml-fp8.o: \
1089+
ggml/src/ggml-fp8.cpp \
1090+
ggml/src/ggml-fp8.h \
1091+
ggml/src/ggml-common.h
1092+
$(CXX) $(CXXFLAGS) -std=c++17 -c $< -o $@
1093+
10771094
ggml/src/ggml-blas.o: \
10781095
ggml/src/ggml-blas.cpp \
10791096
ggml/include/ggml-blas.h

examples/quantize/quantize.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
5151
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
5252
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
5353
{ "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
54+
55+
{ "E4M3_Q", LLAMA_FTYPE_MOSTLY_E4M3_Q, "12,21G, 0.0050 kld @ Mistral-Nemo", },
56+
{ "E3M4_Q", LLAMA_FTYPE_MOSTLY_E3M4_Q, "12,21G, 0.0016 kld @ Mistral-Nemo", },
57+
5458
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, +0.0020 ppl @ Mistral-7B", },
5559
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
5660
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },

ggml/include/ggml.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,11 @@ extern "C" {
390390
GGML_TYPE_Q4_0_8_8 = 33,
391391
GGML_TYPE_TQ1_0 = 34,
392392
GGML_TYPE_TQ2_0 = 35,
393+
GGML_TYPE_E5M2 = 36,
394+
GGML_TYPE_E4M3 = 37,
395+
GGML_TYPE_E4M3_Q = 38,
396+
GGML_TYPE_E3M4_Q = 39,
397+
// E5M6 => 12 bits vs 16 bits for BF16 = E8M7 / FP16 = E5M10
393398
GGML_TYPE_COUNT,
394399
};
395400

@@ -434,6 +439,10 @@ extern "C" {
434439
GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
435440
GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
436441
GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
442+
GGML_FTYPE_MOSTLY_E5M2 = 28, // except 1d tensors
443+
GGML_FTYPE_MOSTLY_E4M3 = 29, // except 1d tensors
444+
GGML_FTYPE_MOSTLY_E4M3_Q = 30, // except 1d tensors
445+
GGML_FTYPE_MOSTLY_E3M4_Q = 31, // except 1d tensors
437446
};
438447

439448
// available tensor operations:

ggml/src/ggml-common.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,24 @@ typedef struct {
418418
} block_iq4_xs;
419419
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
420420

421+
// the fp8 types.
422+
typedef uint8_t ggml_e5m2;
423+
typedef uint8_t ggml_e4m3;
424+
typedef uint8_t ggml_e3m4;
425+
426+
// fp8 with bloc delta => 8.125 bpw
427+
typedef struct {
428+
float d; // delta
429+
ggml_e4m3 qs[QK_K];
430+
} block_e4m3_q;
431+
static_assert(sizeof(block_e4m3_q) == sizeof(float) + QK_K, "wrong block_e4m3_q block size/padding");
432+
433+
typedef struct {
434+
float d; // delta
435+
ggml_e3m4 qs[QK_K];
436+
} block_e3m4_q;
437+
static_assert(sizeof(block_e3m4_q) == sizeof(float) + QK_K, "wrong block_e3m4_q block size/padding");
438+
421439
#endif // GGML_COMMON_DECL
422440
#endif // GGML_COMMON_DECL
423441

ggml/src/ggml-fp8.cpp

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
#define GGML_COMMON_IMPL_C
2+
#include "ggml-common.h"
3+
4+
#include "ggml-fp8.h"
5+
6+
#include <cassert>
7+
8+
/*
9+
# ./llama-quantize --output-tensor-type fp8_e3m4_q ~/LLM/Mistral-Nemo-Instruct-2407.BF16.gguf ~/LLM/Mistral-Nemo-Instruct-2407.E3M4_Q.gguf E3M4_Q
10+
./llama-quantize ~/LLM/Mistral-Nemo-Instruct-2407.BF16.gguf ~/LLM/Mistral-Nemo-Instruct-2407.E3M4_Q.gguf E3M4_Q
11+
./llama-cli -c 1024 -m ~/LLM/Mistral-Nemo-Instruct-2407.E3M4_Q.gguf -p "[INST]bonjour a tu un nom. je ne sais pas comment t'appeler. Si tu n'en as pas je peux t'appeler TINTIN[/INST]" -s 42
12+
# ./llama-perplexity -f ~/LLM/wikitext-2-raw/wiki.test.raw -s 31337 -m ~/LLM/Mistral-Nemo-Instruct-2407.E3M4_Q.gguf
13+
./llama-perplexity --kl-divergence-base ~/LLM/Mistral-Nemo-Instruct-2407.BF16.kld --kl-divergence -s 31337 -m ~/LLM/Mistral-Nemo-Instruct-2407.E3M4_Q.gguf
14+
15+
*/
16+
17+
#include <iostream>
18+
#include <cstdint>
19+
#include <immintrin.h>
20+
21+
template<int N> constexpr float EXP2() {
22+
if constexpr (N==0) return 1;
23+
if constexpr (N>0) return EXP2<N-1>()*2;
24+
if constexpr (N<0) return EXP2<N+1>()/2;
25+
}
26+
27+
// 2^N avec N>0 en entier
28+
template<int N> constexpr int EXP_I2() {
29+
if constexpr (N==0) return 1;
30+
if constexpr (N>0) return EXP_I2<N-1>()*2;
31+
}
32+
33+
template<int _E> //, int M=7-E> 1.7 bits!
34+
struct FP8 {
35+
uint8_t bits;
36+
using type = FP8<_E>;
37+
static constexpr int E=_E;
38+
static constexpr int M=7-_E;
39+
static constexpr int E_BIAS=EXP2<_E-1>()-1;
40+
static constexpr float MAX() { return (2-EXP2<-M+1>())*EXP2<EXP_I2<_E-1>()>(); }
41+
static constexpr float MIN() { return EXP2<-M>()*EXP2<2-EXP_I2<_E-1>()>(); }
42+
//=============================================
43+
44+
#pragma omp declare simd
45+
void operator=(float value) {
46+
union {
47+
float f;
48+
uint32_t bits;
49+
} in = {value};
50+
// le signe:
51+
bits = (in.bits >> 24) & 0x80;
52+
// la valeur sans la signe!
53+
in.bits &= 0x7fffffff;
54+
//GGML_ASSERT(in.bits < 0x7f800000); // +/- infini ou NAN
55+
if (in.f >= MAX()) {
56+
bits |= 0x7E;
57+
} else if (in.f<MIN()) { // => 0.
58+
// OK: S.0000000
59+
} else {
60+
in.f *= EXP2<E_BIAS-127>();
61+
in.bits += 1<<(22-M); // for rounding
62+
bits |= (in.bits >> (23-M)) & 0x7F;
63+
}
64+
}
65+
66+
#pragma omp declare simd
67+
operator float () const {
68+
union {
69+
float f;
70+
uint32_t bits;
71+
} out = {0};
72+
// le signe:
73+
out.bits = bits & 0x80;
74+
out.bits <<= 24;
75+
uint32_t _bits = bits & 0x7F;
76+
_bits <<= (23-M);
77+
out.bits |= _bits;
78+
out.f *= EXP2<127-E_BIAS>();
79+
return out.f;
80+
}
81+
};
82+
83+
// block_e4m3_q
84+
//typedef struct {
85+
// float d; // delta
86+
// ggml_e4m3 qs[QK_K];
87+
//} block_e4m3_q;
88+
89+
template<int E>
90+
static inline void conv(const FP8<E>* x, float* y, int64_t size) {
91+
#pragma omp simd
92+
for (int64_t i=0; i<size; i++) {
93+
y[i] = (float) x[i];
94+
}
95+
}
96+
97+
template<int E>
98+
static inline void conv(const float* x, FP8<E>* y, int64_t size) {
99+
#pragma omp simd
100+
for (int64_t i=0; i<size; i++) {
101+
y[i] = x[i];
102+
}
103+
}
104+
105+
template<int E>
106+
static inline float dot(const FP8<E>* x, const float* y, int64_t size) {
107+
float z = 0;
108+
#pragma omp simd reduction(+:z)
109+
for (int64_t i=0; i<size; i++) {
110+
z += ((float)x[i])*y[i];
111+
}
112+
return z;
113+
}
114+
115+
template <int E, int QK>
116+
struct bloc_fp8 {
117+
float d;
118+
FP8<E> qs[QK];
119+
};
120+
121+
template <int E, int QK>
122+
static inline void conv(const bloc_fp8<E, QK>* x, float* y, int64_t size) {
123+
const auto qk_size = size / QK;
124+
for (int64_t q=0; q<qk_size; ++q) {
125+
#pragma omp simd
126+
for (int64_t i=0; i<QK; i++) {
127+
y[q*QK+i] = ((float) x[q].qs[i])*(x[q]).d;
128+
}
129+
}
130+
}
131+
132+
template <int E, int QK>
133+
static inline void conv(const float* x, bloc_fp8<E, QK>* y, int64_t size) {
134+
const auto qk_size = size / QK;
135+
for (int64_t q=0; q<qk_size; ++q) {
136+
float m = 0;
137+
#pragma omp simd reduction(max:m)
138+
for (int64_t i=0; i<QK; i++) {
139+
m = std::max(std::abs(x[q*QK+i]),m);
140+
}
141+
const float D = FP8<E>::MAX()/m;
142+
y[q].d = m/FP8<E>::MAX();
143+
#pragma omp simd
144+
for (int64_t i=0; i<QK; i++) {
145+
y[q].qs[i] = x[q*QK+i]*D;
146+
}
147+
}
148+
}
149+
150+
template <int E, int QK>
151+
static inline float dot(const bloc_fp8<E, QK>* x, const float* y, int64_t size) {
152+
float z = 0;
153+
const auto qk_size = size / QK;
154+
for (int64_t q=0; q<qk_size; ++q) {
155+
float z0 = 0;
156+
#pragma omp simd reduction(+:z0)
157+
for (int64_t i=0; i<QK; i++) {
158+
z0 += ((float)x[q].qs[i])*y[q*QK+i];
159+
}
160+
z += (x[q]).d * z0;
161+
}
162+
return z;
163+
}
164+
165+
// the C API.
166+
void ggml_e5m2_to_fp32_row(const ggml_e5m2_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
167+
conv(reinterpret_cast<const FP8<5>*>(x), y, k);
168+
}
169+
void ggml_fp32_to_e5m2_row(const float * GGML_RESTRICT x, ggml_e5m2_t * GGML_RESTRICT y, int64_t k) {
170+
conv(x, reinterpret_cast<FP8<5>*>(y), k);
171+
}
172+
void ggml_fp32_to_e5m2_row_ref(const float * GGML_RESTRICT x, ggml_e5m2_t * GGML_RESTRICT y, int64_t k) {
173+
for (int64_t i =0; i<k; ++i) {
174+
reinterpret_cast<FP8<5>*>(y)[i] = x[i];
175+
}
176+
}
177+
178+
void ggml_e4m3_to_fp32_row(const ggml_e4m3_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
179+
conv(reinterpret_cast<const FP8<4>*>(x), y, k);
180+
}
181+
void ggml_fp32_to_e4m3_row(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML_RESTRICT y, int64_t k) {
182+
conv(x, reinterpret_cast<FP8<4>*>(y), k);
183+
}
184+
void ggml_fp32_to_e4m3_row_ref(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML_RESTRICT y, int64_t k) {
185+
for (int64_t i =0; i<k; ++i) {
186+
reinterpret_cast<FP8<4>*>(y)[i] = x[i];
187+
}
188+
}
189+
190+
void dequantize_row_e4m3_q(const block_e4m3_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
191+
assert(k % QK_K == 0);
192+
conv(reinterpret_cast<const bloc_fp8<4, QK_K>*>(x), y, k);
193+
}
194+
void quantize_row_e4m3_q(const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k) {
195+
assert(k % QK_K == 0);
196+
conv(x, reinterpret_cast<bloc_fp8<4, QK_K>*>(y), k);
197+
}
198+
void quantize_row_e4m3_q_ref(const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k) {
199+
assert(k % QK_K == 0);
200+
conv(x, reinterpret_cast<bloc_fp8<4, QK_K>*>(y), k);
201+
}
202+
203+
void dequantize_row_e3m4_q(const block_e3m4_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
204+
assert(k % QK_K == 0);
205+
conv(reinterpret_cast<const bloc_fp8<3, QK_K>*>(x), y, k);
206+
}
207+
void quantize_row_e3m4_q(const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k) {
208+
assert(k % QK_K == 0);
209+
conv(x, reinterpret_cast<bloc_fp8<3, QK_K>*>(y), k);
210+
}
211+
void quantize_row_e3m4_q_ref(const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k) {
212+
assert(k % QK_K == 0);
213+
conv(x, reinterpret_cast<bloc_fp8<3, QK_K>*>(y), k);
214+
}
215+
216+
// the dot product for FP8 weight
217+
void ggml_vec_dot_e5m2(int n, float * GGML_RESTRICT s, size_t bs, const ggml_e5m2_t * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc) {
218+
assert(nrc == 1);
219+
GGML_UNUSED(nrc);
220+
GGML_UNUSED(bx);
221+
GGML_UNUSED(by);
222+
GGML_UNUSED(bs);
223+
*s = dot(reinterpret_cast<const FP8<5>*>(vx), vy, n);
224+
}
225+
226+
void ggml_vec_dot_e4m3(int n, float * GGML_RESTRICT s, size_t bs, const ggml_e4m3_t * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc) {
227+
assert(nrc == 1);
228+
GGML_UNUSED(nrc);
229+
GGML_UNUSED(bx);
230+
GGML_UNUSED(by);
231+
GGML_UNUSED(bs);
232+
*s = dot(reinterpret_cast<const FP8<4>*>(vx), vy, n);
233+
}
234+
235+
void ggml_vec_dot_e4m3_q(int n, float * GGML_RESTRICT s, size_t bs, const block_e4m3_q * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc) {
236+
assert(nrc == 1);
237+
GGML_UNUSED(nrc);
238+
GGML_UNUSED(bx);
239+
GGML_UNUSED(by);
240+
GGML_UNUSED(bs);
241+
*s = dot(reinterpret_cast<const bloc_fp8<4, QK_K>*>(vx), vy, n);
242+
}
243+
244+
void ggml_vec_dot_e3m4_q(int n, float * GGML_RESTRICT s, size_t bs, const block_e3m4_q * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc) {
245+
assert(nrc == 1);
246+
GGML_UNUSED(nrc);
247+
GGML_UNUSED(bx);
248+
GGML_UNUSED(by);
249+
GGML_UNUSED(bs);
250+
*s = dot(reinterpret_cast<const bloc_fp8<3, QK_K>*>(vx), vy, n);
251+
}

0 commit comments

Comments
 (0)