Skip to content
This repository was archived by the owner on Apr 28, 2025. It is now read-only.

Commit 3289b05

Browse files
committed
WIP f16 fma
1 parent 122ba48 commit 3289b05

File tree

14 files changed

+198
-6
lines changed

14 files changed

+198
-6
lines changed

crates/libm-macros/src/shared.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ const ALL_OPERATIONS_NESTED: &[(FloatTy, Signature, Option<Signature>, &[&str])]
9292
None,
9393
&["copysignf128", "fdimf128"],
9494
),
95+
(
96+
// `(f16, f16, f16) -> f16`
97+
FloatTy::F16,
98+
Signature { args: &[Ty::F16, Ty::F16, Ty::F16], returns: &[Ty::F16] },
99+
None,
100+
&["fmaf16"],
101+
),
95102
(
96103
// `(f32, f32, f32) -> f32`
97104
FloatTy::F32,

crates/libm-test/src/mpfloat.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ libm_macros::for_each_function! {
188188
expm1 | expm1f => exp_m1,
189189
fabs | fabsf => abs,
190190
fdim | fdimf | fdimf16 | fdimf128 => positive_diff,
191-
fma | fmaf => mul_add,
191+
fma | fmaf | fmaf16 => mul_add,
192192
fmax | fmaxf => max,
193193
fmin | fminf => min,
194194
lgamma | lgammaf => ln_gamma,

crates/libm-test/src/precision.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,9 @@ fn int_float_common<F1: Float, F2: Float>(
588588
impl MaybeOverride<(f32, i32)> for SpecialCase {}
589589
impl MaybeOverride<(f64, i32)> for SpecialCase {}
590590

591+
#[cfg(f16_enabled)]
592+
impl MaybeOverride<(f16, f16, f16)> for SpecialCase {}
593+
591594
impl MaybeOverride<(f32, f32, f32)> for SpecialCase {
592595
fn check_float<F: Float>(
593596
input: (f32, f32, f32),
@@ -609,6 +612,9 @@ impl MaybeOverride<(f64, f64, f64)> for SpecialCase {
609612
}
610613
}
611614

615+
#[cfg(f128_enabled)]
616+
impl MaybeOverride<(f128, f128, f128)> for SpecialCase {}
617+
612618
// F1 and F2 are always the same type, this is just to please generics
613619
fn ternop_common<F1: Float, F2: Float>(
614620
input: (F1, F1, F1),

crates/libm-test/tests/compare_built_musl.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ libm_macros::for_each_function! {
8989
fdimf16,
9090
floorf128,
9191
floorf16,
92+
fmaf16,
9293
rintf128,
9394
rintf16,
9495
sqrtf128,

etc/function-definitions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,12 @@
376376
],
377377
"type": "f32"
378378
},
379+
"fmaf16": {
380+
"sources": [
381+
"src/math/fmaf16.rs"
382+
],
383+
"type": "f16"
384+
},
379385
"fmax": {
380386
"sources": [
381387
"src/libm_helper.rs",

etc/function-list.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ floorf128
5353
floorf16
5454
fma
5555
fmaf
56+
fmaf16
5657
fmax
5758
fmaxf
5859
fmin

src/math/fmaf.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ use super::fenv::{
4747
/// according to the rounding mode characterized by the value of FLT_ROUNDS.
4848
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
4949
pub fn fmaf(x: f32, y: f32, mut z: f32) -> f32 {
50+
if true {
51+
return super::generic::fma_big::<f32, f64>(x, y, z);
52+
}
53+
5054
let xy: f64;
5155
let mut result: f64;
5256
let mut ui: u64;

src/math/fmaf16.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
2+
pub fn fmaf16(x: f16, y: f16, z: f16) -> f16 {
3+
super::generic::fma_big::<f16, f32>(x, y, z)
4+
}

src/math/generic/fma.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use super::super::fenv::{
2+
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
3+
};
4+
use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, IntTy, MinInt};
5+
6+
/// FMA implementation when a hardware-backed larger float type is available.
7+
pub fn fma_big<F, B>(x: F, y: F, z: F) -> F
8+
where
9+
F: Float + HFloat<D = B>,
10+
B: Float + DFloat<H = F>,
11+
// F: Float + CastInto<B>,
12+
// B: Float + CastInto<F> + CastFrom<F>,
13+
B::Int: CastInto<i32>,
14+
i32: CastFrom<i32>,
15+
{
16+
let one = IntTy::<B>::ONE;
17+
18+
let xy: B;
19+
let mut result: B;
20+
let mut ui: B::Int;
21+
let e: i32;
22+
23+
xy = x.widen() * y.widen();
24+
result = xy + z.widen();
25+
ui = result.to_bits();
26+
e = i32::cast_from(ui >> F::SIG_BITS) & F::EXP_MAX as i32;
27+
let zb: B = z.widen();
28+
29+
let prec_diff = B::SIG_BITS - F::SIG_BITS;
30+
let excess_prec = ui & ((one << prec_diff) - one);
31+
let x = one << (prec_diff - 1);
32+
33+
// Common case: the larger precision is fine
34+
if excess_prec != x
35+
|| e == i32::cast_from(F::EXP_MAX)
36+
|| (result - xy == zb && result - zb == xy)
37+
|| fegetround() != FE_TONEAREST
38+
{
39+
// TODO: feclearexcept
40+
41+
return result.narrow();
42+
}
43+
44+
let neg = ui & B::SIGN_MASK > IntTy::<B>::ZERO;
45+
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
46+
if neg == (err < B::ZERO) {
47+
ui += one;
48+
} else {
49+
ui -= one;
50+
}
51+
52+
B::from_bits(ui).narrow()
53+
}

src/math/generic/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mod copysign;
33
mod fabs;
44
mod fdim;
55
mod floor;
6+
mod fma;
67
mod rint;
78
mod sqrt;
89
mod trunc;
@@ -12,6 +13,7 @@ pub use copysign::copysign;
1213
pub use fabs::fabs;
1314
pub use fdim::fdim;
1415
pub use floor::floor;
16+
pub use fma::fma_big;
1517
pub use rint::rint;
1618
pub use sqrt::sqrt;
1719
pub use trunc::trunc;

0 commit comments

Comments
 (0)