Skip to content

Commit 49c667d

Browse files
authored
Try to fix fma on windows (#43530)
* enable fma on Windows
1 parent bff3eb6 commit 49c667d

File tree

3 files changed

+129
-26
lines changed

3 files changed

+129
-26
lines changed

src/llvm-cpufeatures.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,7 @@ Optional<bool> always_have_fma(Function &intr) {
3636
auto intr_name = intr.getName();
3737
auto typ = intr_name.substr(strlen("julia.cpu.have_fma."));
3838

39-
#if defined(_OS_WINDOWS_)
40-
// FMA on Windows is weirdly broken (#43088)
41-
return false;
42-
#elif defined(_CPU_AARCH64_)
39+
#if defined(_CPU_AARCH64_)
4340
return typ == "f32" || typ == "f64";
4441
#else
4542
(void)typ;

src/runtime_intrinsics.c

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,8 +1175,112 @@ bi_fintrinsic(div,div_float)
11751175
bi_fintrinsic(frem,rem_float)
11761176

11771177
// ternary operators //
1178+
// runtime fma is broken on windows, define julia_fma(f) ourself with fma_emulated as reference.
1179+
#if defined(_OS_WINDOWS_)
1180+
// reinterpret(UInt64, ::Float64)
1181+
uint64_t bitcast_d2u(double d) {
1182+
uint64_t r;
1183+
memcpy(&r, &d, 8);
1184+
return r;
1185+
}
1186+
// reinterpret(Float64, ::UInt64)
1187+
double bitcast_u2d(uint64_t d) {
1188+
double r;
1189+
memcpy(&r, &d, 8);
1190+
return r;
1191+
}
1192+
// Base.splitbits(::Float64)
1193+
void splitbits(double *hi, double *lo, double d) {
1194+
*hi = bitcast_u2d(bitcast_d2u(d) & 0xfffffffff8000000);
1195+
*lo = d - *hi;
1196+
}
1197+
// Base.exponent(::Float64)
1198+
int exponent(double a) {
1199+
int e;
1200+
frexp(a, &e);
1201+
return e - 1;
1202+
}
1203+
// Base.fma_emulated(::Float32, ::Float32, ::Float32)
1204+
float julia_fmaf(float a, float b, float c) {
1205+
double ab, res;
1206+
ab = (double)a * b;
1207+
res = ab + (double)c;
1208+
if ((bitcast_d2u(res) & 0x1fffffff) == 0x10000000){
1209+
double reslo = fabsf(c) > fabs(ab) ? ab-(res - c) : c-(res - ab);
1210+
if (reslo != 0)
1211+
res = nextafter(res, copysign(1.0/0.0, reslo));
1212+
}
1213+
return (float)res;
1214+
}
1215+
// Base.twomul(::Float64, ::Float64)
1216+
void two_mul(double *abhi, double *ablo, double a, double b) {
1217+
double ahi, alo, bhi, blo, blohi, blolo;
1218+
splitbits(&ahi, &alo, a);
1219+
splitbits(&bhi, &blo, b);
1220+
splitbits(&blohi, &blolo, blo);
1221+
*abhi = a*b;
1222+
*ablo = alo*blohi - (((*abhi - ahi*bhi) - alo*bhi) - ahi*blo) + blolo*alo;
1223+
}
1224+
// Base.issubnormal(::Float64) (Win32's fpclassify seems broken)
1225+
int issubnormal(double d) {
1226+
uint64_t y = bitcast_d2u(d);
1227+
return ((y & 0x7ff0000000000000) == 0) & ((y & 0x000fffffffffffff) != 0);
1228+
}
1229+
#if defined(_WIN32)
1230+
// Win32 needs volatile (avoid over optimization?)
1231+
#define VDOUBLE volatile double
1232+
#else
1233+
#define VDOUBLE double
1234+
#endif
1235+
1236+
// Base.fma_emulated(::Float64, ::Float64, ::Float64)
1237+
double julia_fma(double a, double b, double c) {
1238+
double abhi, ablo, r, s;
1239+
two_mul(&abhi, &ablo, a, b);
1240+
if (!isfinite(abhi+c) || fabs(abhi) < 2.0041683600089732e-292 ||
1241+
issubnormal(a) || issubnormal(b)) {
1242+
int aandbfinite = isfinite(a) && isfinite(b);
1243+
if (!(aandbfinite && isfinite(c)))
1244+
return aandbfinite ? c : abhi+c;
1245+
if (a == 0 || b == 0)
1246+
return abhi+c;
1247+
int bias = exponent(a) + exponent(b);
1248+
VDOUBLE c_denorm = ldexp(c, -bias);
1249+
if (isfinite(c_denorm)) {
1250+
if (issubnormal(a))
1251+
a *= 4.503599627370496e15;
1252+
if (issubnormal(b))
1253+
b *= 4.503599627370496e15;
1254+
a = bitcast_u2d((bitcast_d2u(a) & 0x800fffffffffffff) | 0x3ff0000000000000);
1255+
b = bitcast_u2d((bitcast_d2u(b) & 0x800fffffffffffff) | 0x3ff0000000000000);
1256+
c = c_denorm;
1257+
two_mul(&abhi, &ablo, a, b);
1258+
r = abhi+c;
1259+
s = (fabs(abhi) > fabs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo);
1260+
double sumhi = r+s;
1261+
if (issubnormal(ldexp(sumhi, bias))) {
1262+
double sumlo = r-sumhi+s;
1263+
int bits_lost = -bias-exponent(sumhi)-1022;
1264+
if ((bits_lost != 1) ^ ((bitcast_d2u(sumhi)&1) == 1))
1265+
if (sumlo != 0)
1266+
sumhi = nextafter(sumhi, copysign(1.0/0.0, sumlo));
1267+
}
1268+
return ldexp(sumhi, bias);
1269+
}
1270+
if (isinf(abhi) && signbit(c) == signbit(a*b))
1271+
return abhi;
1272+
}
1273+
r = abhi+c;
1274+
s = (fabs(abhi) > fabs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo);
1275+
return r+s;
1276+
}
1277+
#define fma(a, b, c) \
1278+
sizeof(a) == sizeof(float) ? julia_fmaf(a, b, c) : julia_fma(a, b, c)
1279+
#else // On other systems use fma(f) directly
11781280
#define fma(a, b, c) \
11791281
sizeof(a) == sizeof(float) ? fmaf(a, b, c) : fma(a, b, c)
1282+
#endif
1283+
11801284
#define muladd(a, b, c) a * b + c
11811285
ter_fintrinsic(fma,fma_float)
11821286
ter_fintrinsic(muladd,muladd_float)

test/math.jl

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,29 +1289,31 @@ end
12891289
end
12901290

12911291
@testset "fma" begin
1292-
if !(@static Sys.iswindows() && Int===Int64) # windows fma currently seems broken somehow.
1293-
for func in (fma, Base.fma_emulated)
1294-
@test func(nextfloat(1.),nextfloat(1.),-1.0) === 4.440892098500626e-16
1295-
@test func(nextfloat(1f0),nextfloat(1f0),-1f0) === 2.3841858f-7
1296-
@testset "$T" for T in (Float32, Float64)
1297-
@test func(floatmax(T), T(2), -floatmax(T)) === floatmax(T)
1298-
@test func(floatmax(T), T(1), eps(floatmax((T)))) === T(Inf)
1299-
@test func(T(Inf), T(Inf), T(Inf)) === T(Inf)
1300-
@test func(floatmax(T), floatmax(T), -T(Inf)) === -T(Inf)
1301-
@test func(floatmax(T), -floatmax(T), T(Inf)) === T(Inf)
1302-
@test isnan_type(T, func(T(Inf), T(1), -T(Inf)))
1303-
@test isnan_type(T, func(T(Inf), T(0), -T(0)))
1304-
@test func(-zero(T), zero(T), -zero(T)) === -zero(T)
1305-
for _ in 1:2^18
1306-
a, b, c = reinterpret.(T, rand(Base.uinttype(T), 3))
1307-
@test isequal(func(a, b, c), fma(a, b, c)) || (a,b,c)
1308-
end
1292+
fma_list = (fma, Base.fma_emulated)
1293+
if !(Sys.islinux() && Int == Int32) # test runtime fma (skip linux32)
1294+
fma_list = (fma_list..., Base.fma_float)
1295+
end
1296+
for func in fma_list
1297+
@test func(nextfloat(1.),nextfloat(1.),-1.0) === 4.440892098500626e-16
1298+
@test func(nextfloat(1f0),nextfloat(1f0),-1f0) === 2.3841858f-7
1299+
@testset "$T" for T in (Float32, Float64)
1300+
@test func(floatmax(T), T(2), -floatmax(T)) === floatmax(T)
1301+
@test func(floatmax(T), T(1), eps(floatmax((T)))) === T(Inf)
1302+
@test func(T(Inf), T(Inf), T(Inf)) === T(Inf)
1303+
@test func(floatmax(T), floatmax(T), -T(Inf)) === -T(Inf)
1304+
@test func(floatmax(T), -floatmax(T), T(Inf)) === T(Inf)
1305+
@test isnan_type(T, func(T(Inf), T(1), -T(Inf)))
1306+
@test isnan_type(T, func(T(Inf), T(0), -T(0)))
1307+
@test func(-zero(T), zero(T), -zero(T)) === -zero(T)
1308+
for _ in 1:2^18
1309+
a, b, c = reinterpret.(T, rand(Base.uinttype(T), 3))
1310+
@test isequal(func(a, b, c), fma(a, b, c)) || (a,b,c)
13091311
end
1310-
@test func(floatmax(Float64), nextfloat(1.0), -floatmax(Float64)) === 3.991680619069439e292
1311-
@test func(floatmax(Float32), nextfloat(1f0), -floatmax(Float32)) === 4.0564817f31
1312-
@test func(1.6341681540852291e308, -2., floatmax(Float64)) == -1.4706431733081426e308 # case where inv(a)*c*a == Inf
1313-
@test func(-2., 1.6341681540852291e308, floatmax(Float64)) == -1.4706431733081426e308 # case where inv(b)*c*b == Inf
1314-
@test func(-1.9369631f13, 2.1513551f-7, -1.7354427f-24) == -4.1670958f6
13151312
end
1313+
@test func(floatmax(Float64), nextfloat(1.0), -floatmax(Float64)) === 3.991680619069439e292
1314+
@test func(floatmax(Float32), nextfloat(1f0), -floatmax(Float32)) === 4.0564817f31
1315+
@test func(1.6341681540852291e308, -2., floatmax(Float64)) == -1.4706431733081426e308 # case where inv(a)*c*a == Inf
1316+
@test func(-2., 1.6341681540852291e308, floatmax(Float64)) == -1.4706431733081426e308 # case where inv(b)*c*b == Inf
1317+
@test func(-1.9369631f13, 2.1513551f-7, -1.7354427f-24) == -4.1670958f6
13161318
end
13171319
end

0 commit comments

Comments
 (0)