Skip to content

Commit c0b8722

Browse files
authored
Merge pull request #795 from vks/fp-poisson
Poisson: Fix undefined behavior and support f64 output
2 parents 4ef40b6 + ec99801 commit c0b8722

File tree

2 files changed

+161
-43
lines changed

2 files changed

+161
-43
lines changed

rand_distr/src/poisson.rs

Lines changed: 109 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
//! The Poisson distribution.
1111
1212
use rand::Rng;
13-
use crate::{Distribution, Cauchy};
14-
use crate::utils::log_gamma;
13+
use crate::{Distribution, Cauchy, Standard};
14+
use crate::utils::Float;
1515

1616
/// The Poisson distribution `Poisson(lambda)`.
1717
///
@@ -24,17 +24,17 @@ use crate::utils::log_gamma;
2424
/// use rand_distr::{Poisson, Distribution};
2525
///
2626
/// let poi = Poisson::new(2.0).unwrap();
27-
/// let v = poi.sample(&mut rand::thread_rng());
27+
/// let v: u64 = poi.sample(&mut rand::thread_rng());
2828
/// println!("{} is from a Poisson(2) distribution", v);
2929
/// ```
3030
#[derive(Clone, Copy, Debug)]
31-
pub struct Poisson {
32-
lambda: f64,
31+
pub struct Poisson<N> {
32+
lambda: N,
3333
// precalculated values
34-
exp_lambda: f64,
35-
log_lambda: f64,
36-
sqrt_2lambda: f64,
37-
magic_val: f64,
34+
exp_lambda: N,
35+
log_lambda: N,
36+
sqrt_2lambda: N,
37+
magic_val: N,
3838
}
3939

4040
/// Error type returned from `Poisson::new`.
@@ -44,48 +44,51 @@ pub enum Error {
4444
ShapeTooSmall,
4545
}
4646

47-
impl Poisson {
47+
impl<N: Float> Poisson<N>
48+
where Standard: Distribution<N>
49+
{
4850
/// Construct a new `Poisson` with the given shape parameter
4951
/// `lambda`.
50-
pub fn new(lambda: f64) -> Result<Poisson, Error> {
51-
if !(lambda > 0.0) {
52+
pub fn new(lambda: N) -> Result<Poisson<N>, Error> {
53+
if !(lambda > N::from(0.0)) {
5254
return Err(Error::ShapeTooSmall);
5355
}
5456
let log_lambda = lambda.ln();
5557
Ok(Poisson {
5658
lambda,
5759
exp_lambda: (-lambda).exp(),
5860
log_lambda,
59-
sqrt_2lambda: (2.0 * lambda).sqrt(),
60-
magic_val: lambda * log_lambda - log_gamma(1.0 + lambda),
61+
sqrt_2lambda: (N::from(2.0) * lambda).sqrt(),
62+
magic_val: lambda * log_lambda - (N::from(1.0) + lambda).log_gamma(),
6163
})
6264
}
6365
}
6466

65-
impl Distribution<u64> for Poisson {
66-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
67+
impl<N: Float> Distribution<N> for Poisson<N>
68+
where Standard: Distribution<N>
69+
{
70+
#[inline]
71+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
6772
// using the algorithm from Numerical Recipes in C
6873

6974
// for low expected values use the Knuth method
70-
if self.lambda < 12.0 {
71-
let mut result = 0;
72-
let mut p = 1.0;
75+
if self.lambda < N::from(12.0) {
76+
let mut result = N::from(0.);
77+
let mut p = N::from(1.0);
7378
while p > self.exp_lambda {
74-
p *= rng.gen::<f64>();
75-
result += 1;
79+
p *= rng.gen::<N>();
80+
result += N::from(1.);
7681
}
77-
result - 1
82+
result - N::from(1.)
7883
}
7984
// high expected values - rejection method
8085
else {
81-
let mut int_result: u64;
82-
8386
// we use the Cauchy distribution as the comparison distribution
8487
// f(x) ~ 1/(1+x^2)
85-
let cauchy = Cauchy::new(0.0, 1.0).unwrap();
88+
let cauchy = Cauchy::new(N::from(0.0), N::from(1.0)).unwrap();
89+
let mut result;
8690

8791
loop {
88-
let mut result;
8992
let mut comp_dev;
9093

9194
loop {
@@ -94,32 +97,41 @@ impl Distribution<u64> for Poisson {
9497
// shift the peak of the comparison ditribution
9598
result = self.sqrt_2lambda * comp_dev + self.lambda;
9699
// repeat the drawing until we are in the range of possible values
97-
if result >= 0.0 {
100+
if result >= N::from(0.0) {
98101
break;
99102
}
100103
}
101104
// now the result is a random variable greater than 0 with Cauchy distribution
102105
// the result should be an integer value
103106
result = result.floor();
104-
int_result = result as u64;
105107

106108
// this is the ratio of the Poisson distribution to the comparison distribution
107109
// the magic value scales the distribution function to a range of approximately 0-1
108110
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
109111
// this doesn't change the resulting distribution, only increases the rate of failed drawings
110-
let check = 0.9 * (1.0 + comp_dev * comp_dev)
111-
* (result * self.log_lambda - log_gamma(1.0 + result) - self.magic_val).exp();
112+
let check = N::from(0.9) * (N::from(1.0) + comp_dev * comp_dev)
113+
* (result * self.log_lambda - (N::from(1.0) + result).log_gamma() - self.magic_val).exp();
112114

113115
// check with uniform random value - if below the threshold, we are within the target distribution
114-
if rng.gen::<f64>() <= check {
116+
if rng.gen::<N>() <= check {
115117
break;
116118
}
117119
}
118-
int_result
120+
result
119121
}
120122
}
121123
}
122124

125+
impl<N: Float> Distribution<u64> for Poisson<N>
126+
where Standard: Distribution<N>
127+
{
128+
#[inline]
129+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
130+
let result: N = self.sample(rng);
131+
result.to_u64().unwrap()
132+
}
133+
}
134+
123135
#[cfg(test)]
124136
mod test {
125137
use crate::Distribution;
@@ -129,27 +141,82 @@ mod test {
129141
fn test_poisson_10() {
130142
let poisson = Poisson::new(10.0).unwrap();
131143
let mut rng = crate::test::rng(123);
132-
let mut sum = 0;
144+
let mut sum_u64 = 0;
145+
let mut sum_f64 = 0.;
133146
for _ in 0..1000 {
134-
sum += poisson.sample(&mut rng);
147+
let s_u64: u64 = poisson.sample(&mut rng);
148+
let s_f64: f64 = poisson.sample(&mut rng);
149+
sum_u64 += s_u64;
150+
sum_f64 += s_f64;
151+
}
152+
let avg_u64 = (sum_u64 as f64) / 1000.0;
153+
let avg_f64 = sum_f64 / 1000.0;
154+
println!("Poisson averages: {} (u64) {} (f64)", avg_u64, avg_f64);
155+
for &avg in &[avg_u64, avg_f64] {
156+
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough
135157
}
136-
let avg = (sum as f64) / 1000.0;
137-
println!("Poisson average: {}", avg);
138-
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough
139158
}
140159

141160
#[test]
142161
fn test_poisson_15() {
143162
// Take the 'high expected values' path
144163
let poisson = Poisson::new(15.0).unwrap();
145164
let mut rng = crate::test::rng(123);
146-
let mut sum = 0;
165+
let mut sum_u64 = 0;
166+
let mut sum_f64 = 0.;
147167
for _ in 0..1000 {
148-
sum += poisson.sample(&mut rng);
168+
let s_u64: u64 = poisson.sample(&mut rng);
169+
let s_f64: f64 = poisson.sample(&mut rng);
170+
sum_u64 += s_u64;
171+
sum_f64 += s_f64;
172+
}
173+
let avg_u64 = (sum_u64 as f64) / 1000.0;
174+
let avg_f64 = sum_f64 / 1000.0;
175+
println!("Poisson average: {} (u64) {} (f64)", avg_u64, avg_f64);
176+
for &avg in &[avg_u64, avg_f64] {
177+
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
178+
}
179+
}
180+
181+
#[test]
182+
fn test_poisson_10_f32() {
183+
let poisson = Poisson::new(10.0f32).unwrap();
184+
let mut rng = crate::test::rng(123);
185+
let mut sum_u64 = 0;
186+
let mut sum_f32 = 0.;
187+
for _ in 0..1000 {
188+
let s_u64: u64 = poisson.sample(&mut rng);
189+
let s_f32: f32 = poisson.sample(&mut rng);
190+
sum_u64 += s_u64;
191+
sum_f32 += s_f32;
192+
}
193+
let avg_u64 = (sum_u64 as f32) / 1000.0;
194+
let avg_f32 = sum_f32 / 1000.0;
195+
println!("Poisson averages: {} (u64) {} (f32)", avg_u64, avg_f32);
196+
for &avg in &[avg_u64, avg_f32] {
197+
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough
198+
}
199+
}
200+
201+
#[test]
202+
fn test_poisson_15_f32() {
203+
// Take the 'high expected values' path
204+
let poisson = Poisson::new(15.0f32).unwrap();
205+
let mut rng = crate::test::rng(123);
206+
let mut sum_u64 = 0;
207+
let mut sum_f32 = 0.;
208+
for _ in 0..1000 {
209+
let s_u64: u64 = poisson.sample(&mut rng);
210+
let s_f32: f32 = poisson.sample(&mut rng);
211+
sum_u64 += s_u64;
212+
sum_f32 += s_f32;
213+
}
214+
let avg_u64 = (sum_u64 as f32) / 1000.0;
215+
let avg_f32 = sum_f32 / 1000.0;
216+
println!("Poisson average: {} (u64) {} (f32)", avg_u64, avg_f32);
217+
for &avg in &[avg_u64, avg_f32] {
218+
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
149219
}
150-
let avg = (sum as f64) / 1000.0;
151-
println!("Poisson average: {}", avg);
152-
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
153220
}
154221

155222
#[test]

rand_distr/src/utils.rs

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,13 @@ pub trait Float: Copy + Sized + cmp::PartialOrd
3333
fn pi() -> Self;
3434
/// Support approximate representation of a f64 value
3535
fn from(x: f64) -> Self;
36+
/// Support converting to an unsigned integer.
37+
fn to_u64(self) -> Option<u64>;
3638

3739
/// Take the absolute value of self
3840
fn abs(self) -> Self;
41+
/// Take the largest integer less than or equal to self
42+
fn floor(self) -> Self;
3943

4044
/// Take the exponential of self
4145
fn exp(self) -> Self;
@@ -48,34 +52,81 @@ pub trait Float: Copy + Sized + cmp::PartialOrd
4852

4953
/// Take the tangent of self
5054
fn tan(self) -> Self;
55+
/// Take the logarithm of the gamma function of self
56+
fn log_gamma(self) -> Self;
5157
}
5258

5359
impl Float for f32 {
60+
#[inline]
5461
fn pi() -> Self { core::f32::consts::PI }
62+
#[inline]
5563
fn from(x: f64) -> Self { x as f32 }
64+
#[inline]
65+
fn to_u64(self) -> Option<u64> {
66+
if self >= 0. && self <= ::core::u64::MAX as f32 {
67+
Some(self as u64)
68+
} else {
69+
None
70+
}
71+
}
5672

73+
#[inline]
5774
fn abs(self) -> Self { self.abs() }
75+
#[inline]
76+
fn floor(self) -> Self { self.floor() }
5877

78+
#[inline]
5979
fn exp(self) -> Self { self.exp() }
80+
#[inline]
6081
fn ln(self) -> Self { self.ln() }
82+
#[inline]
6183
fn sqrt(self) -> Self { self.sqrt() }
84+
#[inline]
6285
fn powf(self, power: Self) -> Self { self.powf(power) }
6386

87+
#[inline]
6488
fn tan(self) -> Self { self.tan() }
89+
#[inline]
90+
fn log_gamma(self) -> Self {
91+
let result = log_gamma(self as f64);
92+
assert!(result <= ::core::f32::MAX as f64);
93+
assert!(result >= ::core::f32::MIN as f64);
94+
result as f32
95+
}
6596
}
6697

6798
impl Float for f64 {
99+
#[inline]
68100
fn pi() -> Self { core::f64::consts::PI }
101+
#[inline]
69102
fn from(x: f64) -> Self { x }
103+
#[inline]
104+
fn to_u64(self) -> Option<u64> {
105+
if self >= 0. && self <= ::core::u64::MAX as f64 {
106+
Some(self as u64)
107+
} else {
108+
None
109+
}
110+
}
70111

112+
#[inline]
71113
fn abs(self) -> Self { self.abs() }
114+
#[inline]
115+
fn floor(self) -> Self { self.floor() }
72116

117+
#[inline]
73118
fn exp(self) -> Self { self.exp() }
119+
#[inline]
74120
fn ln(self) -> Self { self.ln() }
121+
#[inline]
75122
fn sqrt(self) -> Self { self.sqrt() }
123+
#[inline]
76124
fn powf(self, power: Self) -> Self { self.powf(power) }
77125

126+
#[inline]
78127
fn tan(self) -> Self { self.tan() }
128+
#[inline]
129+
fn log_gamma(self) -> Self { log_gamma(self) }
79130
}
80131

81132
/// Calculates ln(gamma(x)) (natural logarithm of the gamma
@@ -109,7 +160,7 @@ pub(crate) fn log_gamma(x: f64) -> f64 {
109160
// the first few terms of the series for Ag(x)
110161
let mut a = 1.000000000190015;
111162
let mut denom = x;
112-
for coeff in &coefficients {
163+
for &coeff in &coefficients {
113164
denom += 1.0;
114165
a += coeff / denom;
115166
}

0 commit comments

Comments
 (0)