Skip to content

Commit 892de95

Browse files
Pure Julia implementation of OpenLibm's lgamma, lgamma_r
Why bother with this port? The short answer is that for `Float64`, there are considerable performance gains to be had. (See https://github.com/andrewjradcliffe/OpenLibmPorts.jl for plots produced on my machine -- I encourage you to reproduce them for yourself). The long answer can be found in the linked package. To summarize for those unfamiliar with large probabilistic models, log-posterior computations often involve many unavoidable (i.e. cannot be eliminated by using the unnormalized density) special function calls; in particular, log-likelihoods involving the (log)gamma function are not uncommon, e.g. negative binomial, beta-binomial. Given such a likelihood function, each term in the log-likelihood sum would in necessitate a loggamma call, hence, any reduction in the latency in this particular special function translates can have a substantial impact. It is also a convenience to have a loggamma function which is differentiable using `Enzyme` (though, there are some gaps in the derivative -- at exactly 0.0, 1.0 and 2.0 -- but that is just a quirk of applying AD to OpenLibm's loggamma implementation). The author realizes that many things depend on `loggamma`, and my intention is not to cause headaches for others. Fortunately, by directly porting the OpenLibm implementation to Julia, we achieve the same approximation, albeit, with slightly different rounding due, presumably, to the difference between the implementations of `log` in Julia and OpenLibm.
1 parent 36c547b commit 892de95

File tree

3 files changed

+733
-12
lines changed

3 files changed

+733
-12
lines changed

src/e_lgamma_r.jl

Lines changed: 399 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,399 @@
1+
#=
2+
/* @(#)e_lgamma_r.c 1.3 95/01/18 */
3+
/*
4+
* ====================================================
5+
* Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
6+
*
7+
* Developed at SunSoft, a Sun Microsystems, Inc. business.
8+
* Permission to use, copy, modify, and distribute this
9+
* software is freely granted, provided that this notice
10+
* is preserved.
11+
* ====================================================
12+
*
13+
*/
14+
=#
15+
16+
#=
17+
18+
/* __ieee754_lgamma_r(x, signgamp)
19+
* Reentrant version of the logarithm of the Gamma function
20+
* with user provide pointer for the sign of Gamma(x).
21+
*
22+
* Method:
23+
* 1. Argument Reduction for 0 < x <= 8
24+
* Since gamma(1+s)=s*gamma(s), for x in [0,8], we may
25+
* reduce x to a number in [1.5,2.5] by
26+
* lgamma(1+s) = log(s) + lgamma(s)
27+
* for example,
28+
* lgamma(7.3) = log(6.3) + lgamma(6.3)
29+
* = log(6.3*5.3) + lgamma(5.3)
30+
* = log(6.3*5.3*4.3*3.3*2.3) + lgamma(2.3)
31+
* 2. Polynomial approximation of lgamma around its
32+
* minimun ymin=1.461632144968362245 to maintain monotonicity.
33+
* On [ymin-0.23, ymin+0.27] (i.e., [1.23164,1.73163]), use
34+
* Let z = x-ymin;
35+
* lgamma(x) = -1.214862905358496078218 + z^2*poly(z)
36+
* where
37+
* poly(z) is a 14 degree polynomial.
38+
* 2. Rational approximation in the primary interval [2,3]
39+
* We use the following approximation:
40+
* s = x-2.0;
41+
* lgamma(x) = 0.5*s + s*P(s)/Q(s)
42+
* with accuracy
43+
* |P/Q - (lgamma(x)-0.5s)| < 2**-61.71
44+
* Our algorithms are based on the following observation
45+
*
46+
* zeta(2)-1 2 zeta(3)-1 3
47+
* lgamma(2+s) = s*(1-Euler) + --------- * s - --------- * s + ...
48+
* 2 3
49+
*
50+
* where Euler = 0.5771... is the Euler constant, which is very
51+
* close to 0.5.
52+
*
53+
* 3. For x>=8, we have
54+
* lgamma(x)~(x-0.5)log(x)-x+0.5*log(2pi)+1/(12x)-1/(360x**3)+....
55+
* (better formula:
56+
* lgamma(x)~(x-0.5)*(log(x)-1)-.5*(log(2pi)-1) + ...)
57+
* Let z = 1/x, then we approximation
58+
* f(z) = lgamma(x) - (x-0.5)(log(x)-1)
59+
* by
60+
* 3 5 11
61+
* w = w0 + w1*z + w2*z + w3*z + ... + w6*z
62+
* where
63+
* |w - f(z)| < 2**-58.74
64+
*
65+
* 4. For negative x, since (G is gamma function)
66+
* -x*G(-x)*G(x) = pi/sin(pi*x),
67+
* we have
68+
* G(x) = pi/(sin(pi*x)*(-x)*G(-x))
69+
* since G(-x) is positive, sign(G(x)) = sign(sin(pi*x)) for x<0
70+
* Hence, for x<0, signgam = sign(sin(pi*x)) and
71+
* lgamma(x) = log(|Gamma(x)|)
72+
* = log(pi/(|x*sin(pi*x)|)) - lgamma(-x);
73+
* Note: one should avoid compute pi*(-x) directly in the
74+
* computation of sin(pi*(-x)).
75+
*
76+
* 5. Special Cases
77+
* lgamma(2+s) ~ s*(1-Euler) for tiny s
78+
* lgamma(1) = lgamma(2) = 0
79+
* lgamma(x) ~ -log(|x|) for tiny x
80+
* lgamma(0) = lgamma(neg.integer) = inf and raise divide-by-zero
81+
* lgamma(inf) = inf
82+
* lgamma(-inf) = inf (bug for bug compatible with C99!?)
83+
*
84+
*/
85+
86+
=#
87+
88+
const a0 = 7.72156649015328655494e-02 #= 0x3FB3C467, 0xE37DB0C8 =#
89+
const a1 = 3.22467033424113591611e-01 #= 0x3FD4A34C, 0xC4A60FAD =#
90+
const a2 = 6.73523010531292681824e-02 #= 0x3FB13E00, 0x1A5562A7 =#
91+
const a3 = 2.05808084325167332806e-02 #= 0x3F951322, 0xAC92547B =#
92+
const a4 = 7.38555086081402883957e-03 #= 0x3F7E404F, 0xB68FEFE8 =#
93+
const a5 = 2.89051383673415629091e-03 #= 0x3F67ADD8, 0xCCB7926B =#
94+
const a6 = 1.19270763183362067845e-03 #= 0x3F538A94, 0x116F3F5D =#
95+
const a7 = 5.10069792153511336608e-04 #= 0x3F40B6C6, 0x89B99C00 =#
96+
const a8 = 2.20862790713908385557e-04 #= 0x3F2CF2EC, 0xED10E54D =#
97+
const a9 = 1.08011567247583939954e-04 #= 0x3F1C5088, 0x987DFB07 =#
98+
const a10 = 2.52144565451257326939e-05 #= 0x3EFA7074, 0x428CFA52 =#
99+
const a11 = 4.48640949618915160150e-05 #= 0x3F07858E, 0x90A45837 =#
100+
const tc = 1.46163214496836224576e+00 #= 0x3FF762D8, 0x6356BE3F =#
101+
const tf = -1.21486290535849611461e-01 #= 0xBFBF19B9, 0xBCC38A42 =#
102+
#= tt = -(tail of tf) =#
103+
const tt = -3.63867699703950536541e-18 #= 0xBC50C7CA, 0xA48A971F =#
104+
const t0 = 4.83836122723810047042e-01 #= 0x3FDEF72B, 0xC8EE38A2 =#
105+
const t1 = -1.47587722994593911752e-01 #= 0xBFC2E427, 0x8DC6C509 =#
106+
const t2 = 6.46249402391333854778e-02 #= 0x3FB08B42, 0x94D5419B =#
107+
const t3 = -3.27885410759859649565e-02 #= 0xBFA0C9A8, 0xDF35B713 =#
108+
const t4 = 1.79706750811820387126e-02 #= 0x3F9266E7, 0x970AF9EC =#
109+
const t5 = -1.03142241298341437450e-02 #= 0xBF851F9F, 0xBA91EC6A =#
110+
const t6 = 6.10053870246291332635e-03 #= 0x3F78FCE0, 0xE370E344 =#
111+
const t7 = -3.68452016781138256760e-03 #= 0xBF6E2EFF, 0xB3E914D7 =#
112+
const t8 = 2.25964780900612472250e-03 #= 0x3F6282D3, 0x2E15C915 =#
113+
const t9 = -1.40346469989232843813e-03 #= 0xBF56FE8E, 0xBF2D1AF1 =#
114+
const t10 = 8.81081882437654011382e-04 #= 0x3F4CDF0C, 0xEF61A8E9 =#
115+
const t11 = -5.38595305356740546715e-04 #= 0xBF41A610, 0x9C73E0EC =#
116+
const t12 = 3.15632070903625950361e-04 #= 0x3F34AF6D, 0x6C0EBBF7 =#
117+
const t13 = -3.12754168375120860518e-04 #= 0xBF347F24, 0xECC38C38 =#
118+
const t14 = 3.35529192635519073543e-04 #= 0x3F35FD3E, 0xE8C2D3F4 =#
119+
const u0 = -7.72156649015328655494e-02 #= 0xBFB3C467, 0xE37DB0C8 =#
120+
const u1 = 6.32827064025093366517e-01 #= 0x3FE4401E, 0x8B005DFF =#
121+
const u2 = 1.45492250137234768737e+00 #= 0x3FF7475C, 0xD119BD6F =#
122+
const u3 = 9.77717527963372745603e-01 #= 0x3FEF4976, 0x44EA8450 =#
123+
const u4 = 2.28963728064692451092e-01 #= 0x3FCD4EAE, 0xF6010924 =#
124+
const u5 = 1.33810918536787660377e-02 #= 0x3F8B678B, 0xBF2BAB09 =#
125+
const v1 = 2.45597793713041134822e+00 #= 0x4003A5D7, 0xC2BD619C =#
126+
const v2 = 2.12848976379893395361e+00 #= 0x40010725, 0xA42B18F5 =#
127+
const v3 = 7.69285150456672783825e-01 #= 0x3FE89DFB, 0xE45050AF =#
128+
const v4 = 1.04222645593369134254e-01 #= 0x3FBAAE55, 0xD6537C88 =#
129+
const v5 = 3.21709242282423911810e-03 #= 0x3F6A5ABB, 0x57D0CF61 =#
130+
const s0 = -7.72156649015328655494e-02 #= 0xBFB3C467, 0xE37DB0C8 =#
131+
const s1 = 2.14982415960608852501e-01 #= 0x3FCB848B, 0x36E20878 =#
132+
const s2 = 3.25778796408930981787e-01 #= 0x3FD4D98F, 0x4F139F59 =#
133+
const s3 = 1.46350472652464452805e-01 #= 0x3FC2BB9C, 0xBEE5F2F7 =#
134+
const s4 = 2.66422703033638609560e-02 #= 0x3F9B481C, 0x7E939961 =#
135+
const s5 = 1.84028451407337715652e-03 #= 0x3F5E26B6, 0x7368F239 =#
136+
const s6 = 3.19475326584100867617e-05 #= 0x3F00BFEC, 0xDD17E945 =#
137+
const r1 = 1.39200533467621045958e+00 #= 0x3FF645A7, 0x62C4AB74 =#
138+
const r2 = 7.21935547567138069525e-01 #= 0x3FE71A18, 0x93D3DCDC =#
139+
const r3 = 1.71933865632803078993e-01 #= 0x3FC601ED, 0xCCFBDF27 =#
140+
const r4 = 1.86459191715652901344e-02 #= 0x3F9317EA, 0x742ED475 =#
141+
const r5 = 7.77942496381893596434e-04 #= 0x3F497DDA, 0xCA41A95B =#
142+
const r6 = 7.32668430744625636189e-06 #= 0x3EDEBAF7, 0xA5B38140 =#
143+
const w0 = 4.18938533204672725052e-01 #= 0x3FDACFE3, 0x90C97D69 =#
144+
const w1 = 8.33333333333329678849e-02 #= 0x3FB55555, 0x5555553B =#
145+
const w2 = -2.77777777728775536470e-03 #= 0xBF66C16C, 0x16B02E5C =#
146+
const w3 = 7.93650558643019558500e-04 #= 0x3F4A019F, 0x98CF38B6 =#
147+
const w4 = -5.95187557450339963135e-04 #= 0xBF4380CB, 0x8C0FE741 =#
148+
const w5 = 8.36339918996282139126e-04 #= 0x3F4B67BA, 0x4CDAD5D1 =#
149+
const w6 = -1.63092934096575273989e-03 #= 0xBF5AB89D, 0x0B9E43E4 =#
150+
151+
# Matches OpenLibm behavior exactly, including return of sign
152+
function _lgamma_r(x::Float64)
153+
u = reinterpret(UInt64, x)
154+
hx = (u >>> 32) % Int32
155+
lx = u % Int32
156+
157+
#= purge off +-inf, NaN, +-0, tiny and negative arguments =#
158+
signgamp = Int32(1)
159+
ix = signed(hx & 0x7fffffff)
160+
ix 0x7ff00000 && return x * x, signgamp
161+
ix | lx == 0 && return 1.0 / 0.0, signgamp
162+
if ix < 0x3b900000 #= |x|<2**-70, return -log(|x|) =#
163+
if hx < 0
164+
signgamp = Int32(-1)
165+
return -log(-x), signgamp
166+
else
167+
return -log(x), signgamp
168+
end
169+
end
170+
if hx < 0
171+
ix 0x43300000 && return 1.0 / 0.0, signgamp #= |x|>=2**52, must be -integer =#
172+
t = sinpi(x)
173+
t == 0.0 && return 1.0 / 0.0, signgamp #= -integer =#
174+
nadj = log/ abs(t * x))
175+
if t < 0.0; signgamp = Int32(-1); end
176+
x = -x
177+
end
178+
179+
#= purge off 1 and 2 =#
180+
if ((ix - 0x3ff00000) | lx) == 0 || ((ix - 0x40000000) | lx) == 0
181+
r = 0.0
182+
#= for x < 2.0 =#
183+
elseif ix < 0x40000000
184+
if ix 0x3feccccc #= lgamma(x) = lgamma(x+1)-log(x) =#
185+
r = -log(x)
186+
if ix 0x3FE76944
187+
y = 1.0 - x
188+
i = Int8(0)
189+
elseif ix 0x3FCDA661
190+
y = x - (tc - 1.0)
191+
i = Int8(1)
192+
else
193+
y = x
194+
i = Int8(2)
195+
end
196+
else
197+
r = 0.0
198+
if ix 0x3FFBB4C3 #= [1.7316,2] =#
199+
y = 2.0 - x
200+
i = Int8(0)
201+
elseif ix 0x3FF3B4C4 #= [1.23,1.73] =#
202+
y = x - tc
203+
i = Int8(1)
204+
else
205+
y = x - 1.0
206+
i = Int8(2)
207+
end
208+
end
209+
if i == Int8(0)
210+
z = y*y;
211+
p1 = a0+z*(a2+z*(a4+z*(a6+z*(a8+z*a10))));
212+
p2 = z*(a1+z*(a3+z*(a5+z*(a7+z*(a9+z*a11)))));
213+
p = y*p1+p2;
214+
r += (p-0.5*y);
215+
elseif i == Int8(1)
216+
z = y*y;
217+
w = z*y;
218+
p1 = t0+w*(t3+w*(t6+w*(t9 +w*t12))); #= parallel comp =#
219+
p2 = t1+w*(t4+w*(t7+w*(t10+w*t13)));
220+
p3 = t2+w*(t5+w*(t8+w*(t11+w*t14)));
221+
p = z*p1-(tt-w*(p2+y*p3));
222+
r += (tf + p)
223+
elseif i == Int8(2)
224+
p1 = y*(u0+y*(u1+y*(u2+y*(u3+y*(u4+y*u5)))));
225+
p2 = 1.0+y*(v1+y*(v2+y*(v3+y*(v4+y*v5))));
226+
r += (-0.5*y + p1/p2);
227+
end
228+
elseif ix < 0x40200000 #= x < 8.0 =#
229+
i = Base.unsafe_trunc(Int8, x)
230+
y = x - float(i)
231+
# If performed here, performance is 2x worse; hence, move it below.
232+
# p = y*(s0+y*(s1+y*(s2+y*(s3+y*(s4+y*(s5+y*s6))))));
233+
# q = 1.0+y*(r1+y*(r2+y*(r3+y*(r4+y*(r5+y*r6)))));
234+
# r = 0.5*y+p/q;
235+
z = 1.0; #= lgamma(1+s) = log(s) + lgamma(s) =#
236+
if i == Int8(7)
237+
z *= (y + 6.0)
238+
@goto case6
239+
elseif i == Int8(6)
240+
@label case6
241+
z *= (y + 5.0)
242+
@goto case5
243+
elseif i == Int8(5)
244+
@label case5
245+
z *= (y + 4.0)
246+
@goto case4
247+
elseif i == Int8(4)
248+
@label case4
249+
z *= (y + 3.0)
250+
@goto case3
251+
elseif i == Int8(3)
252+
@label case3
253+
z *= (y + 2.0)
254+
end
255+
# r += log(z)
256+
p = y*(s0+y*(s1+y*(s2+y*(s3+y*(s4+y*(s5+y*s6))))));
257+
q = 1.0+y*(r1+y*(r2+y*(r3+y*(r4+y*(r5+y*r6)))));
258+
r = log(z) + 0.5*y+p/q;
259+
#= 8.0 ≤ x < 2^58 =#
260+
elseif ix < 0x43900000
261+
t = log(x)
262+
z = 1.0 / x
263+
y = z * z
264+
w = w0+z*(w1+y*(w2+y*(w3+y*(w4+y*(w5+y*w6)))));
265+
r = (x-0.5)*(t-1.0)+w;
266+
else
267+
#= 2^58 ≤ x ≤ Inf =#
268+
r = x * (log(x) - 1.0)
269+
end
270+
if hx < 0
271+
r = nadj - r
272+
end
273+
return r, signgamp
274+
end
275+
276+
# Deviates from OpenLibm: throws instead of returning negative sign; approximately 25% faster
277+
# when sign is not needed in subsequent computations.
278+
function _loggamma_r(x::Float64)
279+
u = reinterpret(UInt64, x)
280+
hx = (u >>> 32) % Int32
281+
lx = u % Int32
282+
283+
#= purge off +-inf, NaN, +-0, tiny and negative arguments =#
284+
ix = signed(hx & 0x7fffffff)
285+
ix 0x7ff00000 && return x * x
286+
ix | lx == 0 && return 1.0 / 0.0
287+
if ix < 0x3b900000 #= |x|<2**-70, return -log(|x|) =#
288+
if hx < 0
289+
# return -log(-x)
290+
throw(DomainError(x, "`gamma(x)` must be non-negative"))
291+
else
292+
return -log(x)
293+
end
294+
end
295+
if hx < 0
296+
ix 0x43300000 && return 1.0 / 0.0 #= |x|>=2**52, must be -integer =#
297+
t = sinpi(x)
298+
t == 0.0 && return 1.0 / 0.0 #= -integer =#
299+
nadj = log/ abs(t * x))
300+
t < 0.0 && throw(DomainError(x, "`gamma(x)` must be non-negative"))
301+
x = -x
302+
end
303+
304+
#= purge off 1 and 2 =#
305+
if ((ix - 0x3ff00000) | lx) == 0 || ((ix - 0x40000000) | lx) == 0
306+
r = 0.0
307+
#= for x < 2.0 =#
308+
elseif ix < 0x40000000
309+
if ix 0x3feccccc #= lgamma(x) = lgamma(x+1)-log(x) =#
310+
r = -log(x)
311+
if ix 0x3FE76944
312+
y = 1.0 - x
313+
i = Int8(0)
314+
elseif ix 0x3FCDA661
315+
y = x - (tc - 1.0)
316+
i = Int8(1)
317+
else
318+
y = x
319+
i = Int8(2)
320+
end
321+
else
322+
r = 0.0
323+
if ix 0x3FFBB4C3 #= [1.7316,2] =#
324+
y = 2.0 - x
325+
i = Int8(0)
326+
elseif ix 0x3FF3B4C4 #= [1.23,1.73] =#
327+
y = x - tc
328+
i = Int8(1)
329+
else
330+
y = x - 1.0
331+
i = Int8(2)
332+
end
333+
end
334+
if i == Int8(0)
335+
z = y*y;
336+
p1 = a0+z*(a2+z*(a4+z*(a6+z*(a8+z*a10))));
337+
p2 = z*(a1+z*(a3+z*(a5+z*(a7+z*(a9+z*a11)))));
338+
p = y*p1+p2;
339+
r += (p-0.5*y);
340+
elseif i == Int8(1)
341+
z = y*y;
342+
w = z*y;
343+
p1 = t0+w*(t3+w*(t6+w*(t9 +w*t12))); #= parallel comp =#
344+
p2 = t1+w*(t4+w*(t7+w*(t10+w*t13)));
345+
p3 = t2+w*(t5+w*(t8+w*(t11+w*t14)));
346+
p = z*p1-(tt-w*(p2+y*p3));
347+
r += (tf + p)
348+
elseif i == Int8(2)
349+
p1 = y*(u0+y*(u1+y*(u2+y*(u3+y*(u4+y*u5)))));
350+
p2 = 1.0+y*(v1+y*(v2+y*(v3+y*(v4+y*v5))));
351+
r += (-0.5*y + p1/p2);
352+
end
353+
elseif ix < 0x40200000 #= x < 8.0 =#
354+
i = Base.unsafe_trunc(Int8, x)
355+
y = x - float(i)
356+
# If performed here, performance is 2x worse; hence, move it below.
357+
# p = y*(s0+y*(s1+y*(s2+y*(s3+y*(s4+y*(s5+y*s6))))));
358+
# q = 1.0+y*(r1+y*(r2+y*(r3+y*(r4+y*(r5+y*r6)))));
359+
# r = 0.5*y+p/q;
360+
z = 1.0; #= lgamma(1+s) = log(s) + lgamma(s) =#
361+
if i == Int8(7)
362+
z *= (y + 6.0)
363+
@goto case6
364+
elseif i == Int8(6)
365+
@label case6
366+
z *= (y + 5.0)
367+
@goto case5
368+
elseif i == Int8(5)
369+
@label case5
370+
z *= (y + 4.0)
371+
@goto case4
372+
elseif i == Int8(4)
373+
@label case4
374+
z *= (y + 3.0)
375+
@goto case3
376+
elseif i == Int8(3)
377+
@label case3
378+
z *= (y + 2.0)
379+
end
380+
# r += log(z)
381+
p = y*(s0+y*(s1+y*(s2+y*(s3+y*(s4+y*(s5+y*s6))))));
382+
q = 1.0+y*(r1+y*(r2+y*(r3+y*(r4+y*(r5+y*r6)))));
383+
r = log(z) + 0.5*y+p/q;
384+
#= 8.0 ≤ x < 2^58 =#
385+
elseif ix < 0x43900000
386+
t = log(x)
387+
z = 1.0 / x
388+
y = z * z
389+
w = w0+z*(w1+y*(w2+y*(w3+y*(w4+y*(w5+y*w6)))));
390+
r = (x-0.5)*(t-1.0)+w;
391+
else
392+
#= 2^58 ≤ x ≤ Inf =#
393+
r = x * (log(x) - 1.0)
394+
end
395+
if hx < 0
396+
r = nadj - r
397+
end
398+
return r
399+
end

0 commit comments

Comments
 (0)