Skip to content

Commit 4ef40b6

Browse files
authored
Merge pull request #794 from vks/update-binomial
rand_distr: Update binomial distribution implementation
2 parents 0f1b1ff + 664dad2 commit 4ef40b6

File tree

1 file changed

+170
-43
lines changed

1 file changed

+170
-43
lines changed

rand_distr/src/binomial.rs

Lines changed: 170 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
//! The binomial distribution.
1111
1212
use rand::Rng;
13-
use crate::{Distribution, Cauchy};
14-
use crate::utils::log_gamma;
13+
use crate::{Distribution, Uniform};
1514

1615
/// The binomial distribution `Binomial(n, p)`.
1716
///
@@ -58,6 +57,13 @@ impl Binomial {
5857
}
5958
}
6059

60+
/// Convert a `f64` to an `i64`, panicking on overflow.
61+
// In the future (Rust 1.34), this might be replaced with `TryFrom`.
62+
fn f64_to_i64(x: f64) -> i64 {
63+
assert!(x < (::std::i64::MAX as f64));
64+
x as i64
65+
}
66+
6167
impl Distribution<u64> for Binomial {
6268
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
6369
// Handle these values directly.
@@ -67,25 +73,33 @@ impl Distribution<u64> for Binomial {
6773
return self.n;
6874
}
6975

70-
// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k
71-
// switch p so that it is less than 0.5 - this allows for lower expected values
72-
// we will just invert the result at the end
76+
// The binomial distribution is symmetrical with respect to p -> 1-p,
77+
// k -> n-k. Switch p so that it is less than 0.5 - this allows for lower
78+
// expected values. We will just invert the result at the end.
7379
let p = if self.p <= 0.5 {
7480
self.p
7581
} else {
7682
1.0 - self.p
7783
};
7884

7985
let result;
86+
let q = 1. - p;
8087

8188
// For small n * min(p, 1 - p), the BINV algorithm based on the inverse
82-
// transformation of the binomial distribution is more efficient:
89+
// transformation of the binomial distribution is efficient. Otherwise,
90+
// the BTPE algorithm is used.
8391
//
8492
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial
8593
// random variate generation. Commun. ACM 31, 2 (February 1988),
8694
// 216-222. http://dx.doi.org/10.1145/42372.42381
87-
if (self.n as f64) * p < 10. && self.n <= (::std::i32::MAX as u64) {
88-
let q = 1. - p;
95+
96+
// Threshold for preferring the BINV algorithm. The paper suggests 10,
97+
// Ranlib uses 30, and GSL uses 14.
98+
const BINV_THRESHOLD: f64 = 10.;
99+
100+
if (self.n as f64) * p < BINV_THRESHOLD &&
101+
self.n <= (::std::i32::MAX as u64) {
102+
// Use the BINV algorithm.
89103
let s = p / q;
90104
let a = ((self.n + 1) as f64) * s;
91105
let mut r = q.powi(self.n as i32);
@@ -98,52 +112,165 @@ impl Distribution<u64> for Binomial {
98112
}
99113
result = x;
100114
} else {
101-
// FIXME: Using the BTPE algorithm is probably faster.
102-
103-
// prepare some cached values
104-
let float_n = self.n as f64;
105-
let ln_fact_n = log_gamma(float_n + 1.0);
106-
let pc = 1.0 - p;
107-
let log_p = p.ln();
108-
let log_pc = pc.ln();
109-
let expected = self.n as f64 * p;
110-
let sq = (expected * (2.0 * pc)).sqrt();
111-
let mut lresult;
112-
113-
// we use the Cauchy distribution as the comparison distribution
114-
// f(x) ~ 1/(1+x^2)
115-
let cauchy = Cauchy::new(0.0, 1.0).unwrap();
115+
// Use the BTPE algorithm.
116+
117+
// Threshold for using the squeeze algorithm. This can be freely
118+
// chosen based on performance. Ranlib and GSL use 20.
119+
const SQUEEZE_THRESHOLD: i64 = 20;
120+
121+
// Step 0: Calculate constants as functions of `n` and `p`.
122+
let n = self.n as f64;
123+
let np = n * p;
124+
let npq = np * q;
125+
let f_m = np + p;
126+
let m = f64_to_i64(f_m);
127+
// radius of triangle region, since height=1 also area of region
128+
let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
129+
// tip of triangle
130+
let x_m = (m as f64) + 0.5;
131+
// left edge of triangle
132+
let x_l = x_m - p1;
133+
// right edge of triangle
134+
let x_r = x_m + p1;
135+
let c = 0.134 + 20.5 / (15.3 + (m as f64));
136+
// p1 + area of parallelogram region
137+
let p2 = p1 * (1. + 2. * c);
138+
139+
fn lambda(a: f64) -> f64 {
140+
a * (1. + 0.5 * a)
141+
}
142+
143+
let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p));
144+
let lambda_r = lambda((x_r - f_m) / (x_r * q));
145+
// p2 + area of left tail
146+
let p3 = p2 + c / lambda_l;
147+
// p3 + area of right tail
148+
let p4 = p3 + c / lambda_r;
149+
150+
// return value
151+
let mut y: i64;
152+
153+
let gen_u = Uniform::new(0., p4);
154+
let gen_v = Uniform::new(0., 1.);
155+
116156
loop {
117-
let mut comp_dev: f64;
118-
loop {
119-
// draw from the Cauchy distribution
120-
comp_dev = rng.sample(cauchy);
121-
// shift the peak of the comparison distribution
122-
lresult = expected + sq * comp_dev;
123-
// repeat the drawing until we are in the range of possible values
124-
if lresult >= 0.0 && lresult < float_n + 1.0 {
125-
break;
157+
// Step 1: Generate `u` for selecting the region. If region 1 is
158+
// selected, generate a triangularly distributed variate.
159+
let u = gen_u.sample(rng);
160+
let mut v = gen_v.sample(rng);
161+
if !(u > p1) {
162+
y = f64_to_i64(x_m - p1 * v + u);
163+
break;
164+
}
165+
166+
if !(u > p2) {
167+
// Step 2: Region 2, parallelograms. Check if region 2 is
168+
// used. If so, generate `y`.
169+
let x = x_l + (u - p1) / c;
170+
v = v * c + 1.0 - (x - x_m).abs() / p1;
171+
if v > 1. {
172+
continue;
173+
} else {
174+
y = f64_to_i64(x);
175+
}
176+
} else if !(u > p3) {
177+
// Step 3: Region 3, left exponential tail.
178+
y = f64_to_i64(x_l + v.ln() / lambda_l);
179+
if y < 0 {
180+
continue;
181+
} else {
182+
v *= (u - p2) * lambda_l;
183+
}
184+
} else {
185+
// Step 4: Region 4, right exponential tail.
186+
y = f64_to_i64(x_r - v.ln() / lambda_r);
187+
if y > 0 && (y as u64) > self.n {
188+
continue;
189+
} else {
190+
v *= (u - p3) * lambda_r;
126191
}
127192
}
128193

129-
// the result should be discrete
130-
lresult = lresult.floor();
194+
// Step 5: Acceptance/rejection comparison.
131195

132-
let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) -
133-
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc;
134-
// this is the binomial probability divided by the comparison probability
135-
// we will generate a uniform random value and if it is larger than this,
136-
// we interpret it as a value falling out of the distribution and repeat
137-
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev));
196+
// Step 5.0: Test for appropriate method of evaluating f(y).
197+
let k = (y - m).abs();
198+
if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
199+
// Step 5.1: Evaluate f(y) via the recursive relationship. Start the
200+
// search from the mode.
201+
let s = p / q;
202+
let a = s * (n + 1.);
203+
let mut f = 1.0;
204+
if m < y {
205+
let mut i = m;
206+
loop {
207+
i += 1;
208+
f *= a / (i as f64) - s;
209+
if i == y {
210+
break;
211+
}
212+
}
213+
} else if m > y {
214+
let mut i = y;
215+
loop {
216+
i += 1;
217+
f /= a / (i as f64) - s;
218+
if i == m {
219+
break;
220+
}
221+
}
222+
}
223+
if v > f {
224+
continue;
225+
} else {
226+
break;
227+
}
228+
}
138229

139-
if comparison_coeff >= rng.gen() {
230+
// Step 5.2: Squeezing. Check the value of ln(v) against upper and
231+
// lower bound of ln(f(y)).
232+
let k = k as f64;
233+
let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1./6.) / npq + 0.5);
234+
let t = -0.5 * k*k / npq;
235+
let alpha = v.ln();
236+
if alpha < t - rho {
140237
break;
141238
}
239+
if alpha > t + rho {
240+
continue;
241+
}
242+
243+
// Step 5.3: Final acceptance/rejection test.
244+
let x1 = (y + 1) as f64;
245+
let f1 = (m + 1) as f64;
246+
let z = (f64_to_i64(n) + 1 - m) as f64;
247+
let w = (f64_to_i64(n) - y + 1) as f64;
248+
249+
fn stirling(a: f64) -> f64 {
250+
let a2 = a * a;
251+
(13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
252+
}
253+
254+
if alpha > x_m * (f1 / x1).ln()
255+
+ (n - (m as f64) + 0.5) * (z / w).ln()
256+
+ ((y - m) as f64) * (w * p / (x1 * q)).ln()
257+
// We use the signs from the GSL implementation, which are
258+
// different than the ones in the reference. According to
259+
// the GSL authors, the new signs were verified to be
260+
// correct by one of the original designers of the
261+
// algorithm.
262+
+ stirling(f1) + stirling(z) - stirling(x1) - stirling(w)
263+
{
264+
continue;
265+
}
266+
267+
break;
142268
}
143-
result = lresult as u64;
269+
assert!(y >= 0);
270+
result = y as u64;
144271
}
145272

146-
// invert the result for p < 0.5
273+
// Invert the result if p was replaced by 1 - p above (i.e. self.p > 0.5).
147274
if p != self.p {
148275
self.n - result
149276
} else {

0 commit comments

Comments
 (0)