Skip to content

Commit bb7cea0

Browse files
[libc][math][c++23] Add bfloat16 support in LLVM libc (#144463)
This PR enables support for BFloat16 type in LLVM libc along with support for testing BFloat16 functions via MPFR. --------- Signed-off-by: krishna2803 <kpandey81930@gmail.com> Signed-off-by: Krishna Pandey <kpandey81930@gmail.com> Co-authored-by: OverMighty <its.overmighty@gmail.com>
1 parent ff365ce commit bb7cea0

File tree

14 files changed

+298
-10
lines changed

14 files changed

+298
-10
lines changed

libc/src/__support/CPP/type_traits/is_floating_point.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ template <typename T> struct is_floating_point {
3636
,
3737
float128
3838
#endif
39-
>();
39+
,
40+
bfloat16>();
4041
};
4142
template <typename T>
4243
LIBC_INLINE_VAR constexpr bool is_floating_point_v =

libc/src/__support/FPUtil/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,4 +257,17 @@ add_header_library(
257257
libc.src.__support.macros.properties.types
258258
)
259259

260+
add_header_library(
261+
bfloat16
262+
HDRS
263+
bfloat16.h
264+
DEPENDS
265+
.cast
266+
.dyadic_float
267+
libc.src.__support.CPP.bit
268+
libc.src.__support.CPP.type_traits
269+
libc.src.__support.macros.config
270+
libc.src.__support.macros.properties.types
271+
)
272+
260273
add_subdirectory(generic)

libc/src/__support/FPUtil/FPBits.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ enum class FPType {
3838
IEEE754_Binary64,
3939
IEEE754_Binary128,
4040
X86_Binary80,
41+
BFloat16
4142
};
4243

4344
// The classes hierarchy is as follows:
@@ -138,6 +139,14 @@ template <> struct FPLayout<FPType::X86_Binary80> {
138139
LIBC_INLINE_VAR static constexpr int FRACTION_LEN = SIG_LEN - 1;
139140
};
140141

142+
// Bit layout of the bfloat16 format. The three field widths below add up to
// the 16 bits of StorageType: 1 sign bit + 8 exponent bits + 7 significand
// bits.
template <> struct FPLayout<FPType::BFloat16> {
143+
using StorageType = uint16_t;
144+
LIBC_INLINE_VAR static constexpr int SIGN_LEN = 1;
145+
LIBC_INLINE_VAR static constexpr int EXP_LEN = 8;
146+
LIBC_INLINE_VAR static constexpr int SIG_LEN = 7;
147+
// Every stored significand bit counts as a fraction bit.
LIBC_INLINE_VAR static constexpr int FRACTION_LEN = SIG_LEN;
148+
};
149+
141150
// FPStorage derives useful constants from the FPLayout above.
142151
template <FPType fp_type> struct FPStorage : public FPLayout<fp_type> {
143152
using UP = FPLayout<fp_type>;
@@ -801,6 +810,8 @@ template <typename T> LIBC_INLINE static constexpr FPType get_fp_type() {
801810
else if constexpr (cpp::is_same_v<UnqualT, float128>)
802811
return FPType::IEEE754_Binary128;
803812
#endif
813+
else if constexpr (cpp::is_same_v<UnqualT, bfloat16>)
814+
return FPType::BFloat16;
804815
else
805816
static_assert(cpp::always_false<UnqualT>, "Unsupported type");
806817
}

libc/src/__support/FPUtil/bfloat16.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===-- Definition of bfloat16 data type. -----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_BFLOAT16_H
10+
#define LLVM_LIBC_SRC___SUPPORT_FPUTIL_BFLOAT16_H
11+
12+
#include "src/__support/CPP/bit.h"
13+
#include "src/__support/CPP/type_traits.h"
14+
#include "src/__support/FPUtil/cast.h"
15+
#include "src/__support/FPUtil/dyadic_float.h"
16+
#include "src/__support/macros/config.h"
17+
#include "src/__support/macros/properties/types.h"
18+
19+
#include <stdint.h>
20+
21+
namespace LIBC_NAMESPACE_DECL {
22+
namespace fputil {
23+
24+
struct BFloat16 {
25+
uint16_t bits;
26+
27+
LIBC_INLINE BFloat16() = default;
28+
29+
LIBC_INLINE constexpr explicit BFloat16(uint16_t bits) : bits(bits) {}
30+
31+
template <typename T> LIBC_INLINE constexpr explicit BFloat16(T value) {
32+
if constexpr (cpp::is_floating_point_v<T>) {
33+
bits = fputil::cast<bfloat16>(value).bits;
34+
} else if constexpr (cpp::is_integral_v<T>) {
35+
Sign sign = Sign::POS;
36+
37+
if constexpr (cpp::is_signed_v<T>) {
38+
if (value < 0) {
39+
sign = Sign::NEG;
40+
value = -value;
41+
}
42+
}
43+
44+
fputil::DyadicFloat<cpp::numeric_limits<cpp::make_unsigned_t<T>>::digits>
45+
xd(sign, 0, value);
46+
bits = xd.template as<bfloat16, /*ShouldSignalExceptions=*/true>().bits;
47+
48+
} else {
49+
bits = fputil::cast<bfloat16>(static_cast<float>(value)).bits;
50+
}
51+
}
52+
53+
template <cpp::enable_if_t<fputil::get_fp_type<float>() ==
54+
fputil::FPType::IEEE754_Binary32,
55+
int> = 0>
56+
LIBC_INLINE constexpr operator float() const {
57+
uint32_t x_bits = static_cast<uint32_t>(bits) << 16U;
58+
return cpp::bit_cast<float>(x_bits);
59+
}
60+
}; // struct BFloat16
61+
62+
} // namespace fputil
63+
} // namespace LIBC_NAMESPACE_DECL
64+
65+
#endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_BFLOAT16_H

libc/src/__support/FPUtil/cast.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,18 @@ LIBC_INLINE constexpr cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
2626
cpp::is_floating_point_v<InType>,
2727
OutType>
2828
cast(InType x) {
29+
// Casting to the same type is a no-op.
30+
if constexpr (cpp::is_same_v<InType, OutType>)
31+
return x;
32+
33+
// bfloat16 is always defined (for now)
34+
if constexpr (cpp::is_same_v<OutType, bfloat16> ||
35+
cpp::is_same_v<InType, bfloat16>
2936
#if defined(LIBC_TYPES_HAS_FLOAT16) && !defined(__LIBC_USE_FLOAT16_CONVERSION)
30-
if constexpr (cpp::is_same_v<OutType, float16> ||
31-
cpp::is_same_v<InType, float16>) {
37+
|| cpp::is_same_v<OutType, float16> ||
38+
cpp::is_same_v<InType, float16>
39+
#endif
40+
) {
3241
using InFPBits = FPBits<InType>;
3342
using InStorageType = typename InFPBits::StorageType;
3443
using OutFPBits = FPBits<OutType>;
@@ -58,7 +67,6 @@ cast(InType x) {
5867
DyadicFloat<cpp::bit_ceil(MAX_FRACTION_LEN)> xd(x);
5968
return xd.template as<OutType, /*ShouldSignalExceptions=*/true>();
6069
}
61-
#endif
6270

6371
return static_cast<OutType>(x);
6472
}

libc/src/__support/FPUtil/dyadic_float.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,14 @@ template <size_t Bits> struct DyadicFloat {
411411
(FPBits<T>::FRACTION_LEN < Bits),
412412
void>>
413413
LIBC_INLINE constexpr T as() const {
414+
if constexpr (cpp::is_same_v<T, bfloat16>
414415
#if defined(LIBC_TYPES_HAS_FLOAT16) && !defined(__LIBC_USE_FLOAT16_CONVERSION)
415-
if constexpr (cpp::is_same_v<T, float16>)
416-
return generic_as<T, ShouldSignalExceptions>();
416+
|| cpp::is_same_v<T, float16>
417417
#endif
418-
return fast_as<T, ShouldSignalExceptions>();
418+
)
419+
return generic_as<T, ShouldSignalExceptions>();
420+
else
421+
return fast_as<T, ShouldSignalExceptions>();
419422
}
420423

421424
template <typename T,

libc/src/__support/macros/properties/types.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
#ifndef LLVM_LIBC_SRC___SUPPORT_MACROS_PROPERTIES_TYPES_H
1111
#define LLVM_LIBC_SRC___SUPPORT_MACROS_PROPERTIES_TYPES_H
1212

13-
#include "hdr/float_macros.h" // LDBL_MANT_DIG
13+
#include "hdr/float_macros.h" // LDBL_MANT_DIG
1414
#include "include/llvm-libc-macros/float16-macros.h" // LIBC_TYPES_HAS_FLOAT16
1515
#include "include/llvm-libc-types/float128.h" // float128
16+
#include "src/__support/macros/config.h" // LIBC_NAMESPACE_DECL
1617
#include "src/__support/macros/properties/architectures.h"
1718
#include "src/__support/macros/properties/compiler.h"
1819
#include "src/__support/macros/properties/cpu_features.h"
@@ -58,4 +59,14 @@ using float16 = _Float16;
5859
// LIBC_TYPES_HAS_FLOAT128 and 'float128' type are provided by
5960
// "include/llvm-libc-types/float128.h"
6061

62+
// -- bfloat16 support ---------------------------------------------------------
63+
64+
// Forward declaration only: this header introduces the name of the type
// without defining it.
namespace LIBC_NAMESPACE_DECL {
65+
namespace fputil {
66+
struct BFloat16;
67+
}
68+
} // namespace LIBC_NAMESPACE_DECL
69+
70+
// Global-scope alias so the type can be spelled `bfloat16`.
using bfloat16 = LIBC_NAMESPACE::fputil::BFloat16;
71+
6172
#endif // LLVM_LIBC_SRC___SUPPORT_MACROS_PROPERTIES_TYPES_H

libc/test/src/__support/FPUtil/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,14 @@ add_fp_unittest(
3838
DEPENDS
3939
libc.src.__support.FPUtil.rounding_mode
4040
)
41+
42+
add_fp_unittest(
43+
bfloat16_test
44+
NEED_MPFR
45+
SUITE
46+
libc-fputil-tests
47+
SRCS
48+
bfloat16_test.cpp
49+
DEPENDS
50+
libc.src.__support.FPUtil.bfloat16
51+
)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//===-- Unit tests for bfloat16 type --------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/__support/FPUtil/bfloat16.h"
10+
#include "test/UnitTest/FPMatcher.h"
11+
#include "test/UnitTest/Test.h"
12+
#include "utils/MPFRWrapper/MPCommon.h"
13+
14+
using BFloat16 = LIBC_NAMESPACE::fputil::BFloat16;
15+
using LlvmLibcBfloat16ConversionTest =
16+
LIBC_NAMESPACE::testing::FPTest<BFloat16>;
17+
18+
// range: [0, inf]
19+
static constexpr uint16_t POS_START = 0x0000U;
20+
static constexpr uint16_t POS_STOP = 0x7f80U;
21+
22+
// range: [-0, -inf]
23+
static constexpr uint16_t NEG_START = 0x8000U;
24+
static constexpr uint16_t NEG_STOP = 0xff80U;
25+
26+
using MPFRNumber = LIBC_NAMESPACE::testing::mpfr::MPFRNumber;
27+
28+
// Walks every bfloat16 encoding from POS_START through POS_STOP (inclusive)
// and checks both conversion directions against MPFR.
TEST_F(LlvmLibcBfloat16ConversionTest, ToFloatPositiveRange) {
  uint16_t encoding = POS_START;
  while (encoding <= POS_STOP) {
    BFloat16 input(encoding);

    // Widening: bfloat16 -> float, MPFR result is the reference.
    float expected_float = MPFRNumber(input).as<float>();
    EXPECT_FP_EQ_ALL_ROUNDING(expected_float, static_cast<float>(input));

    // Narrowing: float -> bfloat16, MPFR result is the reference.
    BFloat16 actual_bf16(expected_float);
    BFloat16 expected_bf16 = MPFRNumber(expected_float).as<BFloat16>();
    EXPECT_FP_EQ_ALL_ROUNDING(expected_bf16, actual_bf16);

    ++encoding;
  }
}
44+
45+
// Walks every bfloat16 encoding from NEG_START through NEG_STOP (inclusive)
// and checks both conversion directions against MPFR.
TEST_F(LlvmLibcBfloat16ConversionTest, ToFloatNegativeRange) {
  uint16_t encoding = NEG_START;
  while (encoding <= NEG_STOP) {
    BFloat16 input(encoding);

    // Widening: bfloat16 -> float, MPFR result is the reference.
    float expected_float = MPFRNumber(input).as<float>();
    EXPECT_FP_EQ_ALL_ROUNDING(expected_float, static_cast<float>(input));

    // Narrowing: float -> bfloat16, MPFR result is the reference.
    BFloat16 actual_bf16(expected_float);
    BFloat16 expected_bf16 = MPFRNumber(expected_float).as<BFloat16>();
    EXPECT_FP_EQ_ALL_ROUNDING(expected_bf16, actual_bf16);

    ++encoding;
  }
}
61+
62+
// Converts every int in [-100000, 100000] to bfloat16 and compares the result
// with the value MPFR produces for the same integer.
TEST_F(LlvmLibcBfloat16ConversionTest, FromInteger) {
  constexpr int LIMIT = 100'000;
  int n = -LIMIT;
  while (n <= LIMIT) {
    BFloat16 expected = MPFRNumber(n).as<BFloat16>();
    BFloat16 actual(n);
    EXPECT_FP_EQ_ALL_ROUNDING(expected, actual);
    ++n;
  }
}

libc/test/src/math/exhaustive/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,3 +535,19 @@ add_fp_unittest(
535535
LINK_LIBRARIES
536536
-lpthread
537537
)
538+
539+
add_fp_unittest(
540+
bfloat16_test
541+
NO_RUN_POSTBUILD
542+
NEED_MPFR
543+
SUITE
544+
libc_math_exhaustive_tests
545+
SRCS
546+
bfloat16_test.cpp
547+
DEPENDS
548+
.exhaustive_test
549+
libc.src.__support.FPUtil.bfloat16
550+
libc.src.__support.FPUtil.fp_bits
551+
LINK_LIBRARIES
552+
-lpthread
553+
)

0 commit comments

Comments
 (0)