diff options
Diffstat (limited to 'rand/src/distributions/binomial.rs')
-rw-r--r-- | rand/src/distributions/binomial.rs | 276 |
1 files changed, 206 insertions, 70 deletions
diff --git a/rand/src/distributions/binomial.rs b/rand/src/distributions/binomial.rs index 2df393e..8fc290a 100644 --- a/rand/src/distributions/binomial.rs +++ b/rand/src/distributions/binomial.rs @@ -8,25 +8,17 @@ // except according to those terms. //! The binomial distribution. +#![allow(deprecated)] +#![allow(clippy::all)] -use Rng; -use distributions::{Distribution, Bernoulli, Cauchy}; -use distributions::utils::log_gamma; +use crate::Rng; +use crate::distributions::{Distribution, Uniform}; /// The binomial distribution `Binomial(n, p)`. /// /// This distribution has density function: /// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Binomial, Distribution}; -/// -/// let bin = Binomial::new(20, 0.3); -/// let v = bin.sample(&mut rand::thread_rng()); -/// println!("{} is from a binomial distribution", v); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Binomial { /// Number of trials. @@ -47,6 +39,13 @@ impl Binomial { } } +/// Convert a `f64` to an `i64`, panicing on overflow. +// In the future (Rust 1.34), this might be replaced with `TryFrom`. +fn f64_to_i64(x: f64) -> i64 { + assert!(x < (::std::i64::MAX as f64)); + x as i64 +} + impl Distribution<u64> for Binomial { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { // Handle these values directly. @@ -55,83 +54,217 @@ impl Distribution<u64> for Binomial { } else if self.p == 1.0 { return self.n; } - - // For low n, it is faster to sample directly. For both methods, - // performance is independent of p. On Intel Haswell CPU this method - // appears to be faster for approx n < 300. - if self.n < 300 { - let mut result = 0; - let d = Bernoulli::new(self.p); - for _ in 0 .. self.n { - result += rng.sample(d) as u32; - } - return result as u64; - } - - // binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k - // switch p so that it is less than 0.5 - this allows for lower expected values - // we will just invert the result at the end + + // The binomial distribution is symmetrical with respect to p -> 1-p, + // k -> n-k switch p so that it is less than 0.5 - this allows for lower + // expected values we will just invert the result at the end let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p }; - // prepare some cached values - let float_n = self.n as f64; - let ln_fact_n = log_gamma(float_n + 1.0); - let pc = 1.0 - p; - let log_p = p.ln(); - let log_pc = pc.ln(); - let expected = self.n as f64 * p; - let sq = (expected * (2.0 * pc)).sqrt(); - - let mut lresult; - - // we use the Cauchy distribution as the comparison distribution - // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(0.0, 1.0); - loop { - let mut comp_dev: f64; + let result; + let q = 1. - p; + + // For small n * min(p, 1 - p), the BINV algorithm based on the inverse + // transformation of the binomial distribution is efficient. Otherwise, + // the BTPE algorithm is used. + // + // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial + // random variate generation. Commun. ACM 31, 2 (February 1988), + // 216-222. http://dx.doi.org/10.1145/42372.42381 + + // Threshold for prefering the BINV algorithm. The paper suggests 10, + // Ranlib uses 30, and GSL uses 14. + const BINV_THRESHOLD: f64 = 10.; + + if (self.n as f64) * p < BINV_THRESHOLD && + self.n <= (::std::i32::MAX as u64) { + // Use the BINV algorithm. + let s = p / q; + let a = ((self.n + 1) as f64) * s; + let mut r = q.powi(self.n as i32); + let mut u: f64 = rng.gen(); + let mut x = 0; + while u > r as f64 { + u -= r; + x += 1; + r *= a / (x as f64) - s; + } + result = x; + } else { + // Use the BTPE algorithm. + + // Threshold for using the squeeze algorithm. This can be freely + // chosen based on performance. Ranlib and GSL use 20. + const SQUEEZE_THRESHOLD: i64 = 20; + + // Step 0: Calculate constants as functions of `n` and `p`. + let n = self.n as f64; + let np = n * p; + let npq = np * q; + let f_m = np + p; + let m = f64_to_i64(f_m); + // radius of triangle region, since height=1 also area of region + let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; + // tip of triangle + let x_m = (m as f64) + 0.5; + // left edge of triangle + let x_l = x_m - p1; + // right edge of triangle + let x_r = x_m + p1; + let c = 0.134 + 20.5 / (15.3 + (m as f64)); + // p1 + area of parallelogram region + let p2 = p1 * (1. + 2. * c); + + fn lambda(a: f64) -> f64 { + a * (1. + 0.5 * a) + } + + let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p)); + let lambda_r = lambda((x_r - f_m) / (x_r * q)); + // p1 + area of left tail + let p3 = p2 + c / lambda_l; + // p1 + area of right tail + let p4 = p3 + c / lambda_r; + + // return value + let mut y: i64; + + let gen_u = Uniform::new(0., p4); + let gen_v = Uniform::new(0., 1.); + loop { - // draw from the Cauchy distribution - comp_dev = rng.sample(cauchy); - // shift the peak of the comparison ditribution - lresult = expected + sq * comp_dev; - // repeat the drawing until we are in the range of possible values - if lresult >= 0.0 && lresult < float_n + 1.0 { + // Step 1: Generate `u` for selecting the region. If region 1 is + // selected, generate a triangularly distributed variate. + let u = gen_u.sample(rng); + let mut v = gen_v.sample(rng); + if !(u > p1) { + y = f64_to_i64(x_m - p1 * v + u); break; } - } - // the result should be discrete - lresult = lresult.floor(); + if !(u > p2) { + // Step 2: Region 2, parallelograms. Check if region 2 is + // used. If so, generate `y`. + let x = x_l + (u - p1) / c; + v = v * c + 1.0 - (x - x_m).abs() / p1; + if v > 1. { + continue; + } else { + y = f64_to_i64(x); + } + } else if !(u > p3) { + // Step 3: Region 3, left exponential tail. + y = f64_to_i64(x_l + v.ln() / lambda_l); + if y < 0 { + continue; + } else { + v *= (u - p2) * lambda_l; + } + } else { + // Step 4: Region 4, right exponential tail. + y = f64_to_i64(x_r - v.ln() / lambda_r); + if y > 0 && (y as u64) > self.n { + continue; + } else { + v *= (u - p3) * lambda_r; + } + } + + // Step 5: Acceptance/rejection comparison. + + // Step 5.0: Test for appropriate method of evaluating f(y). + let k = (y - m).abs(); + if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { + // Step 5.1: Evaluate f(y) via the recursive relationship. Start the + // search from the mode. + let s = p / q; + let a = s * (n + 1.); + let mut f = 1.0; + if m < y { + let mut i = m; + loop { + i += 1; + f *= a / (i as f64) - s; + if i == y { + break; + } + } + } else if m > y { + let mut i = y; + loop { + i += 1; + f /= a / (i as f64) - s; + if i == m { + break; + } + } + } + if v > f { + continue; + } else { + break; + } + } - let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) - - log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc; - // this is the binomial probability divided by the comparison probability - // we will generate a uniform random value and if it is larger than this, - // we interpret it as a value falling out of the distribution and repeat - let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev)); + // Step 5.2: Squeezing. Check the value of ln(v) againts upper and + // lower bound of ln(f(y)). + let k = k as f64; + let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1./6.) / npq + 0.5); + let t = -0.5 * k*k / npq; + let alpha = v.ln(); + if alpha < t - rho { + break; + } + if alpha > t + rho { + continue; + } + + // Step 5.3: Final acceptance/rejection test. + let x1 = (y + 1) as f64; + let f1 = (m + 1) as f64; + let z = (f64_to_i64(n) + 1 - m) as f64; + let w = (f64_to_i64(n) - y + 1) as f64; + + fn stirling(a: f64) -> f64 { + let a2 = a * a; + (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. + } + + if alpha > x_m * (f1 / x1).ln() + + (n - (m as f64) + 0.5) * (z / w).ln() + + ((y - m) as f64) * (w * p / (x1 * q)).ln() + // We use the signs from the GSL implementation, which are + // different than the ones in the reference. According to + // the GSL authors, the new signs were verified to be + // correct by one of the original designers of the + // algorithm. + + stirling(f1) + stirling(z) - stirling(x1) - stirling(w) + { + continue; + } - if comparison_coeff >= rng.gen() { break; } + assert!(y >= 0); + result = y as u64; } - // invert the result for p < 0.5 + // Invert the result for p < 0.5. if p != self.p { - self.n - lresult as u64 + self.n - result } else { - lresult as u64 + result } } } #[cfg(test)] mod test { - use Rng; - use distributions::Distribution; + use crate::Rng; + use crate::distributions::Distribution; use super::Binomial; fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) { @@ -144,17 +277,20 @@ mod test { for i in results.iter_mut() { *i = binomial.sample(rng) as f64; } let mean = results.iter().sum::<f64>() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0); + assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0, + "mean: {}, expected_mean: {}", mean, expected_mean); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); + assert!((variance - expected_variance).abs() < expected_variance / 10.0, + "variance: {}, expected_variance: {}", variance, expected_variance); } #[test] + #[cfg(not(miri))] // Miri is too slow fn test_binomial() { - let mut rng = ::test::rng(351); + let mut rng = crate::test::rng(351); test_binomial_mean_and_variance(150, 0.1, &mut rng); test_binomial_mean_and_variance(70, 0.6, &mut rng); test_binomial_mean_and_variance(40, 0.5, &mut rng); @@ -164,7 +300,7 @@ mod test { #[test] fn test_binomial_end_points() { - let mut rng = ::test::rng(352); + let mut rng = crate::test::rng(352); assert_eq!(rng.sample(Binomial::new(20, 0.0)), 0); assert_eq!(rng.sample(Binomial::new(20, 1.0)), 20); } |