
Commit 55e962a

Introduce bfloat16 support
Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as their canonical floating point format.

  ┌sign
  │
  │   ┌exponent
  │   │
  │   │      ┌mantissa
  │   │      │
  │┌──┴───┐┌─┴───┐
0b0000000000000000 brain16

This encoding has the same number of exponent bits as float32. That makes conversion relatively straightforward, even in the absence of hardware support. For example, converting brain16 to binary32 means simply shifting 16 bits to the left.

  ┌sign
  │
  │   ┌exponent
  │   │
  │   │      ┌mantissa
  │   │      │
  │┌──┴───┐┌─┴───────────────────┐
0b00000000000000000000000000000000 IEEE binary32

The issue is that converting bf16 to fp16 can result in information loss. Only 13% of bf16 numbers can be represented exactly in fp16; in practice that covers 99.71% of Mistral 7b v0.2's weights, but until now the only way to preserve the remaining values was to use fp32.

  ┌sign
  │
  │  ┌exponent
  │  │
  │  │    ┌mantissa
  │  │    │
  │┌─┴─┐┌─┴──────┐
0b0000000000000000 IEEE binary16

This change fixes that by adding a bf16 data type to GGML. Support for CPU inference has been implemented, along with optimizations for the AVX2, AVX512, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2 improves by somewhere between -0.0024 and -0.0046 ppl compared to fp16.
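
As a quick standalone illustration of that shift-based conversion (a sketch for orientation, not code from this diff; the helpers the commit actually adds appear in the ggml.h changes below and additionally round to nearest and handle NaN/subnormal inputs):

// Sketch: bf16 is simply the top 16 bits of an IEEE binary32 value.
#include <stdint.h>
#include <string.h>

static float bf16_bits_to_f32(uint16_t b) {
    uint32_t i = (uint32_t)b << 16;   // shift left: sign/exponent/mantissa line up with binary32
    float f;
    memcpy(&f, &i, sizeof f);         // reinterpret the bit pattern
    return f;
}

static uint16_t f32_to_bf16_bits_truncate(float f) {
    uint32_t i;
    memcpy(&i, &f, sizeof i);
    return (uint16_t)(i >> 16);       // drop the low 16 mantissa bits (truncation, no rounding)
}
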
1 parent c780e75 commit 55e962a

File tree

8 files changed, +1788 −227 lines


examples/finetune/finetune.cpp

Lines changed: 1 addition & 1 deletion
@@ -575,7 +575,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
 
     auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
+        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16) {
             return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
         } else if (a->type == GGML_TYPE_F32) {
             return ggml_add(ctx, a, b);

examples/quantize/quantize.cpp

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_K_M",  LLAMA_FTYPE_MOSTLY_Q5_K_M,  " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
     { "Q6_K",    LLAMA_FTYPE_MOSTLY_Q6_K,    " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
     { "Q8_0",    LLAMA_FTYPE_MOSTLY_Q8_0,    " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
-    { "F16",     LLAMA_FTYPE_MOSTLY_F16,     "13.00G              @ 7B", },
+    { "F16",     LLAMA_FTYPE_MOSTLY_F16,     "14.00G, -0.0020 ppl @ Mistral-7B", },
+    { "BF16",    LLAMA_FTYPE_MOSTLY_BF16,    "14.00G, -0.0050 ppl @ Mistral-7B", },
     { "F32",     LLAMA_FTYPE_ALL_F32,        "26.00G              @ 7B", },
     // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
     { "COPY",    LLAMA_FTYPE_ALL_F32,        "only copy tensors, no quantizing", },

ggml-impl.h

Lines changed: 3 additions & 0 deletions
@@ -518,6 +518,9 @@ size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml
 // return index, asserts if table is full
 size_t ggml_hash_find_or_insert( struct ggml_hash_set hash_set, struct ggml_tensor * key);
 
+#define GGML_FP32_TO_BF16(x) ggml_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_bf16_to_fp32(x)
+
 #ifdef __cplusplus
 }
 #endif
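
For reference, a minimal usage sketch for the new macros (not part of the commit; it assumes a translation unit that may include the internal ggml-impl.h header), showing the precision a bf16 round trip keeps:

#include <stdio.h>
#include "ggml-impl.h"

int main(void) {
    float x = 3.14159265f;                 // 0x40490fdb in IEEE binary32
    ggml_bf16_t b = GGML_FP32_TO_BF16(x);  // keeps 8 exponent bits and 7 mantissa bits
    float y = GGML_BF16_TO_FP32(b);        // widen back to binary32
    printf("%.8f -> 0x%04x -> %.8f\n", x, b.x, y);  // 3.14159274 -> 0x4049 -> 3.14062500
    return 0;
}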

ggml.c

Lines changed: 1675 additions & 223 deletions
Large diffs are not rendered by default.

ggml.h

Lines changed: 86 additions & 0 deletions
@@ -370,6 +370,7 @@ extern "C" {
         GGML_TYPE_I64     = 27,
         GGML_TYPE_F64     = 28,
         GGML_TYPE_IQ1_M   = 29,
+        GGML_TYPE_BF16    = 30,
         GGML_TYPE_COUNT,
     };
 
@@ -410,6 +411,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ2_S  = 21, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
+        GGML_FTYPE_MOSTLY_BF16   = 24, // except 1d tensors
     };
 
     // available tensor operations:
@@ -2390,6 +2392,90 @@ extern "C" {
     GGML_API int ggml_cpu_has_vsx        (void);
     GGML_API int ggml_cpu_has_matmul_int8(void);
 
+    /**
+     * Google Brain 16-bit floating point number.
+     *
+     *     ┌sign
+     *     │
+     *     │   ┌exponent
+     *     │   │
+     *     │   │      ┌mantissa
+     *     │   │      │
+     *     │┌──┴───┐┌─┴───┐
+     *   0b0000000000000000 brain16
+     *
+     * Since bf16 has the same number of exponent bits as a 32bit float,
+     * encoding and decoding numbers becomes relatively straightforward.
+     *
+     *     ┌sign
+     *     │
+     *     │   ┌exponent
+     *     │   │
+     *     │   │      ┌mantissa
+     *     │   │      │
+     *     │┌──┴───┐┌─┴───────────────────┐
+     *   0b00000000000000000000000000000000 IEEE binary32
+     *
+     * For comparison, the standard fp16 format has fewer exponent bits.
+     *
+     *     ┌sign
+     *     │
+     *     │  ┌exponent
+     *     │  │
+     *     │  │    ┌mantissa
+     *     │  │    │
+     *     │┌─┴─┐┌─┴──────┐
+     *   0b0000000000000000 IEEE binary16
+     *
+     * So be warned that converting between them, destroys several bits.
+     *
+     * @see IEEE 754-2008
+     */
+    typedef struct {
+        uint16_t x;
+    } ggml_bf16_t;
+
+    /**
+     * Converts brain16 to float32.
+     */
+    static inline float ggml_bf16_to_fp32(ggml_bf16_t h) {
+        union {
+            float f;
+            uint32_t i;
+        } u;
+        u.i = (uint32_t)h.x << 16;
+        return u.f;
+    }
+
+    /**
+     * Converts float32 to brain16.
+     *
+     * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
+     * Subnormals shall be flushed to zero, and NANs will be quiet.
+     * This code should vectorize nicely if using modern compilers.
+     */
+    static inline ggml_bf16_t ggml_fp32_to_bf16(float s) {
+        ggml_bf16_t h;
+        union {
+            float f;
+            uint32_t i;
+        } u;
+        u.f = s;
+        if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+            h.x = (u.i >> 16) | 64; /* force to quiet */
+            return h;
+        }
+        if (!(u.i & 0x7f800000)) { /* subnormal */
+            h.x = (u.i & 0x80000000) >> 16; /* flush to zero */
+            return h;
+        }
+        h.x = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+        return h;
+    }
+
+    GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int n);
+    GGML_API void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int n);
+
     //
     // Internal types and functions exposed for tests and benchmarks
     //
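
To make the rounding line easier to follow, here is a small trace of ggml_fp32_to_bf16 on concrete bit patterns (an illustration under my reading of the code above, not part of the diff; the helper restates the same logic on raw bits so the arithmetic is visible):

#include <stdint.h>
#include <stdio.h>

// Same three paths as ggml_fp32_to_bf16, but taking the binary32 bits directly.
static uint16_t fp32_bits_to_bf16_bits(uint32_t i) {
    if ((i & 0x7fffffff) > 0x7f800000) return (uint16_t)((i >> 16) | 64);         // NaN -> quiet NaN
    if (!(i & 0x7f800000))             return (uint16_t)((i & 0x80000000) >> 16); // subnormal -> signed zero
    // Round to nearest, ties to even: add 0x7fff plus the lowest kept bit, then drop 16 bits.
    return (uint16_t)((i + (0x7fff + ((i >> 16) & 1))) >> 16);
}

int main(void) {
    printf("0x%04x\n", fp32_bits_to_bf16_bits(0x3f800000u)); // 1.0f: low half is zero, stays 0x3f80
    printf("0x%04x\n", fp32_bits_to_bf16_bits(0x40490fdbu)); // pi: 0x0fdb < 0x8000, rounds down to 0x4049 (3.140625)
    printf("0x%04x\n", fp32_bits_to_bf16_bits(0x3f808000u)); // exact tie, kept lsb even: stays 0x3f80
    printf("0x%04x\n", fp32_bits_to_bf16_bits(0x3f818000u)); // exact tie, kept lsb odd: rounds up to 0x3f82
    return 0;
}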

gguf-py/gguf/constants.py

Lines changed: 2 additions & 0 deletions
@@ -817,6 +817,7 @@ class GGMLQuantizationType(IntEnum):
     I64     = 27
     F64     = 28
     IQ1_M   = 29
+    BF16    = 30
 
 
 class GGUFEndian(IntEnum):
@@ -862,6 +863,7 @@ def get_type(val: Any) -> GGUFValueType:
 GGML_QUANT_SIZES = {
     GGMLQuantizationType.F32:  (1, 4),
     GGMLQuantizationType.F16:  (1, 2),
+    GGMLQuantizationType.BF16: (1, 2),
     GGMLQuantizationType.Q4_0: (32, 2 + 16),
     GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
     GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),

llama.cpp

Lines changed: 18 additions & 2 deletions
@@ -3175,6 +3175,7 @@ struct llama_model_loader {
         switch (type_max) {
             case GGML_TYPE_F32:  ftype = LLAMA_FTYPE_ALL_F32;     break;
             case GGML_TYPE_F16:  ftype = LLAMA_FTYPE_MOSTLY_F16;  break;
+            case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break;
             case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break;
             case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break;
             case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break;
@@ -3666,6 +3667,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
     switch (ftype) {
         case LLAMA_FTYPE_ALL_F32:     return "all F32";
         case LLAMA_FTYPE_MOSTLY_F16:  return "F16";
+        case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
        case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
         case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
         case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
@@ -6129,6 +6131,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
             || !(
                 model.ftype == LLAMA_FTYPE_ALL_F32 ||
                 model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
+                model.ftype == LLAMA_FTYPE_MOSTLY_BF16 ||
                 model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
                 model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
             )
@@ -14158,13 +14161,16 @@ static void llama_tensor_dequantize_internal(
         if (qtype.to_float == NULL) {
             throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
         }
-    } else if (tensor->type != GGML_TYPE_F16) {
+    } else if (tensor->type != GGML_TYPE_F16 &&
+               tensor->type != GGML_TYPE_BF16) {
         throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
     }
 
     if (nthread < 2) {
         if (tensor->type == GGML_TYPE_F16) {
             ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
+        } else if (tensor->type == GGML_TYPE_BF16) {
+            ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
         } else if (ggml_is_quantized(tensor->type)) {
             qtype.to_float(tensor->data, f32_output, nelements);
         } else {
@@ -14173,7 +14179,14 @@ static void llama_tensor_dequantize_internal(
         return;
     }
 
-    size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type);
+    size_t block_size;
+    if (tensor->type == GGML_TYPE_F16 ||
+        tensor->type == GGML_TYPE_BF16) {
+        block_size = 1;
+    } else {
+        block_size = (size_t)ggml_blck_size(tensor->type);
+    }
+
     size_t block_size_bytes = ggml_type_size(tensor->type);
 
     GGML_ASSERT(nelements % block_size == 0);
@@ -14192,6 +14205,8 @@ static void llama_tensor_dequantize_internal(
     auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
         if (typ == GGML_TYPE_F16) {
             ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
+        } else if (typ == GGML_TYPE_BF16) {
+            ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
         } else {
             qtype.to_float(inbuf, outbuf, nels);
         }
@@ -14552,6 +14567,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
         case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
         case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
+        case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
         case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;
 
         // K-quants
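
For orientation, llama_tensor_dequantize_internal treats BF16 like F16: a block size of one element and a row-conversion helper. The definition of ggml_bf16_to_fp32_row lives in the ggml.c diff, which is not rendered above; a plausible scalar reading of it is the per-element loop sketched here (an illustration only, not the committed implementation, which also carries AVX2/AVX512/AVX512BF16 paths):

#include "ggml.h"

// Sketch: convert a row of bf16 weights to f32 one element at a time,
// using the ggml_bf16_to_fp32 helper declared in ggml.h.
static void bf16_row_to_f32(const ggml_bf16_t * x, float * y, int n) {
    for (int i = 0; i < n; i++) {
        y[i] = ggml_bf16_to_fp32(x[i]);  // widen: shift 16 bits left, reinterpret as float
    }
}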

llama.h

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ2_M  = 29, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M  = 31, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_BF16   = 32, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
