Skip to content

Commit 84d89a7

Browse files
committed
Always calculate log_gamma using f64
1 parent 15b9a39 commit 84d89a7

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

rand_distr/src/poisson.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
1212
use rand::Rng;
1313
use crate::{Distribution, Cauchy, Standard};
14-
use crate::utils::{log_gamma, Float};
14+
use crate::utils::Float;
1515

1616
/// The Poisson distribution `Poisson(lambda)`.
1717
///
@@ -59,7 +59,7 @@ where Standard: Distribution<N>
5959
exp_lambda: (-lambda).exp(),
6060
log_lambda,
6161
sqrt_2lambda: (N::from(2.0) * lambda).sqrt(),
62-
magic_val: lambda * log_lambda - log_gamma(N::from(1.0) + lambda),
62+
magic_val: lambda * log_lambda - (N::from(1.0) + lambda).log_gamma(),
6363
})
6464
}
6565
}
@@ -109,7 +109,7 @@ where Standard: Distribution<N>
109109
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
110110
// this doesn't change the resulting distribution, only increases the rate of failed drawings
111111
let check = N::from(0.9) * (N::from(1.0) + comp_dev * comp_dev)
112-
* (result * self.log_lambda - log_gamma(N::from(1.0) + result) - self.magic_val).exp();
112+
* (result * self.log_lambda - (N::from(1.0) + result).log_gamma() - self.magic_val).exp();
113113

114114
// check with uniform random value - if below the threshold, we are within the target distribution
115115
if rng.gen::<N>() <= check {

rand_distr/src/utils.rs

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ pub trait Float: Copy + Sized + cmp::PartialOrd
5252

5353
/// Take the tangent of self
5454
fn tan(self) -> Self;
55+
/// Take the logarithm of the gamma function of self
56+
fn log_gamma(self) -> Self;
5557
}
5658

5759
impl Float for f32 {
@@ -84,6 +86,13 @@ impl Float for f32 {
8486

8587
#[inline]
8688
fn tan(self) -> Self { self.tan() }
89+
#[inline]
90+
fn log_gamma(self) -> Self {
91+
let result = log_gamma(self as f64);
92+
assert!(result <= ::core::f32::MAX as f64);
93+
assert!(result >= ::core::f32::MIN as f64);
94+
result as f32
95+
}
8796
}
8897

8998
impl Float for f64 {
@@ -116,6 +125,8 @@ impl Float for f64 {
116125

117126
#[inline]
118127
fn tan(self) -> Self { self.tan() }
128+
#[inline]
129+
fn log_gamma(self) -> Self { log_gamma(self) }
119130
}
120131

121132
/// Calculates ln(gamma(x)) (natural logarithm of the gamma
@@ -131,33 +142,33 @@ impl Float for f64 {
131142
/// `Ag(z)` is an infinite series with coefficients that can be calculated
132143
/// ahead of time - we use just the first 6 terms, which is good enough
133144
/// for most purposes.
134-
pub(crate) fn log_gamma<N: Float>(x: N) -> N {
145+
pub(crate) fn log_gamma(x: f64) -> f64 {
135146
// precalculated 6 coefficients for the first 6 terms of the series
136-
let coefficients: [N; 6] = [
137-
N::from(76.18009172947146),
138-
N::from(-86.50532032941677),
139-
N::from(24.01409824083091),
140-
N::from(-1.231739572450155),
141-
N::from(0.1208650973866179e-2),
142-
N::from(-0.5395239384953e-5),
147+
let coefficients: [f64; 6] = [
148+
76.18009172947146,
149+
-86.50532032941677,
150+
24.01409824083091,
151+
-1.231739572450155,
152+
0.1208650973866179e-2,
153+
-0.5395239384953e-5,
143154
];
144155

145156
// (x+0.5)*ln(x+g+0.5)-(x+g+0.5)
146-
let tmp = x + N::from(5.5);
147-
let log = (x + N::from(0.5)) * tmp.ln() - tmp;
157+
let tmp = x + 5.5;
158+
let log = (x + 0.5) * tmp.ln() - tmp;
148159

149160
// the first few terms of the series for Ag(x)
150-
let mut a = N::from(1.000000000190015);
161+
let mut a = 1.000000000190015;
151162
let mut denom = x;
152163
for &coeff in &coefficients {
153-
denom += N::from(1.0);
164+
denom += 1.0;
154165
a += coeff / denom;
155166
}
156167

157168
// get everything together
158169
// a is Ag(x)
159170
// 2.5066... is sqrt(2pi)
160-
log + (N::from(2.5066282746310005) * a / x).ln()
171+
log + (2.5066282746310005 * a / x).ln()
161172
}
162173

163174
/// Sample a random number using the Ziggurat method (specifically the

0 commit comments

Comments
 (0)