diff options
Diffstat (limited to 'rand/src/distributions')
21 files changed, 1333 insertions, 962 deletions
diff --git a/rand/src/distributions/bernoulli.rs b/rand/src/distributions/bernoulli.rs index f49618c..eadd056 100644 --- a/rand/src/distributions/bernoulli.rs +++ b/rand/src/distributions/bernoulli.rs @@ -8,8 +8,8 @@ //! The Bernoulli distribution. -use Rng; -use distributions::Distribution; +use crate::Rng; +use crate::distributions::Distribution; /// The Bernoulli distribution. /// @@ -20,7 +20,7 @@ use distributions::Distribution; /// ```rust /// use rand::distributions::{Bernoulli, Distribution}; /// -/// let d = Bernoulli::new(0.3); +/// let d = Bernoulli::new(0.3).unwrap(); /// let v = d.sample(&mut rand::thread_rng()); /// println!("{} is from a Bernoulli distribution", v); /// ``` @@ -61,13 +61,16 @@ const ALWAYS_TRUE: u64 = ::core::u64::MAX; // in `no_std` mode. const SCALE: f64 = 2.0 * (1u64 << 63) as f64; +/// Error type returned from `Bernoulli::new`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BernoulliError { + /// `p < 0` or `p > 1`. + InvalidProbability, +} + impl Bernoulli { /// Construct a new `Bernoulli` with the given probability of success `p`. /// - /// # Panics - /// - /// If `p < 0` or `p > 1`. - /// /// # Precision /// /// For `p = 1.0`, the resulting distribution will always generate true. @@ -77,12 +80,12 @@ impl Bernoulli { /// a multiple of 2<sup>-64</sup>. (Note that not all multiples of /// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.) #[inline] - pub fn new(p: f64) -> Bernoulli { + pub fn new(p: f64) -> Result<Bernoulli, BernoulliError> { if p < 0.0 || p >= 1.0 { - if p == 1.0 { return Bernoulli { p_int: ALWAYS_TRUE } } - panic!("Bernoulli::new not called with 0.0 <= p <= 1.0"); + if p == 1.0 { return Ok(Bernoulli { p_int: ALWAYS_TRUE }) } + return Err(BernoulliError::InvalidProbability); } - Bernoulli { p_int: (p * SCALE) as u64 } + Ok(Bernoulli { p_int: (p * SCALE) as u64 }) } /// Construct a new `Bernoulli` with the probability of success of @@ -91,19 +94,16 @@ impl Bernoulli { /// /// If `numerator == denominator` then the returned `Bernoulli` will always /// return `true`. If `numerator == 0` it will always return `false`. - /// - /// # Panics - /// - /// If `denominator == 0` or `numerator > denominator`. - /// #[inline] - pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli { - assert!(numerator <= denominator); + pub fn from_ratio(numerator: u32, denominator: u32) -> Result<Bernoulli, BernoulliError> { + if numerator > denominator { + return Err(BernoulliError::InvalidProbability); + } if numerator == denominator { - return Bernoulli { p_int: ::core::u64::MAX } + return Ok(Bernoulli { p_int: ALWAYS_TRUE }) } - let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64; - Bernoulli { p_int } + let p_int = ((f64::from(numerator) / f64::from(denominator)) * SCALE) as u64; + Ok(Bernoulli { p_int }) } } @@ -119,15 +119,15 @@ impl Distribution<bool> for Bernoulli { #[cfg(test)] mod test { - use Rng; - use distributions::Distribution; + use crate::Rng; + use crate::distributions::Distribution; use super::Bernoulli; #[test] fn test_trivial() { - let mut r = ::test::rng(1); - let always_false = Bernoulli::new(0.0); - let always_true = Bernoulli::new(1.0); + let mut r = crate::test::rng(1); + let always_false = Bernoulli::new(0.0).unwrap(); + let always_true = Bernoulli::new(1.0).unwrap(); for _ in 0..5 { assert_eq!(r.sample::<bool, _>(&always_false), false); assert_eq!(r.sample::<bool, _>(&always_true), true); @@ -137,17 +137,18 @@ mod test { } #[test] + #[cfg(not(miri))] // Miri is too slow fn test_average() { const P: f64 = 0.3; const NUM: u32 = 3; const DENOM: u32 = 10; - let d1 = Bernoulli::new(P); - let d2 = Bernoulli::from_ratio(NUM, DENOM); + let d1 = Bernoulli::new(P).unwrap(); + let d2 = Bernoulli::from_ratio(NUM, DENOM).unwrap(); const N: u32 = 100_000; let mut sum1: u32 = 0; let mut sum2: u32 = 0; - let mut rng = ::test::rng(2); + let mut rng = crate::test::rng(2); for _ in 0..N { if d1.sample(&mut rng) { sum1 += 1; diff --git a/rand/src/distributions/binomial.rs b/rand/src/distributions/binomial.rs index 2df393e..8fc290a 100644 --- a/rand/src/distributions/binomial.rs +++ b/rand/src/distributions/binomial.rs @@ -8,25 +8,17 @@ // except according to those terms. //! The binomial distribution. +#![allow(deprecated)] +#![allow(clippy::all)] -use Rng; -use distributions::{Distribution, Bernoulli, Cauchy}; -use distributions::utils::log_gamma; +use crate::Rng; +use crate::distributions::{Distribution, Uniform}; /// The binomial distribution `Binomial(n, p)`. /// /// This distribution has density function: /// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Binomial, Distribution}; -/// -/// let bin = Binomial::new(20, 0.3); -/// let v = bin.sample(&mut rand::thread_rng()); -/// println!("{} is from a binomial distribution", v); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Binomial { /// Number of trials. @@ -47,6 +39,13 @@ impl Binomial { } } +/// Convert a `f64` to an `i64`, panicing on overflow. +// In the future (Rust 1.34), this might be replaced with `TryFrom`. +fn f64_to_i64(x: f64) -> i64 { + assert!(x < (::std::i64::MAX as f64)); + x as i64 +} + impl Distribution<u64> for Binomial { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { // Handle these values directly. @@ -55,83 +54,217 @@ impl Distribution<u64> for Binomial { } else if self.p == 1.0 { return self.n; } - - // For low n, it is faster to sample directly. For both methods, - // performance is independent of p. On Intel Haswell CPU this method - // appears to be faster for approx n < 300. - if self.n < 300 { - let mut result = 0; - let d = Bernoulli::new(self.p); - for _ in 0 .. self.n { - result += rng.sample(d) as u32; - } - return result as u64; - } - - // binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k - // switch p so that it is less than 0.5 - this allows for lower expected values - // we will just invert the result at the end + + // The binomial distribution is symmetrical with respect to p -> 1-p, + // k -> n-k switch p so that it is less than 0.5 - this allows for lower + // expected values we will just invert the result at the end let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p }; - // prepare some cached values - let float_n = self.n as f64; - let ln_fact_n = log_gamma(float_n + 1.0); - let pc = 1.0 - p; - let log_p = p.ln(); - let log_pc = pc.ln(); - let expected = self.n as f64 * p; - let sq = (expected * (2.0 * pc)).sqrt(); - - let mut lresult; - - // we use the Cauchy distribution as the comparison distribution - // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(0.0, 1.0); - loop { - let mut comp_dev: f64; + let result; + let q = 1. - p; + + // For small n * min(p, 1 - p), the BINV algorithm based on the inverse + // transformation of the binomial distribution is efficient. Otherwise, + // the BTPE algorithm is used. + // + // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial + // random variate generation. Commun. ACM 31, 2 (February 1988), + // 216-222. http://dx.doi.org/10.1145/42372.42381 + + // Threshold for prefering the BINV algorithm. The paper suggests 10, + // Ranlib uses 30, and GSL uses 14. + const BINV_THRESHOLD: f64 = 10.; + + if (self.n as f64) * p < BINV_THRESHOLD && + self.n <= (::std::i32::MAX as u64) { + // Use the BINV algorithm. + let s = p / q; + let a = ((self.n + 1) as f64) * s; + let mut r = q.powi(self.n as i32); + let mut u: f64 = rng.gen(); + let mut x = 0; + while u > r as f64 { + u -= r; + x += 1; + r *= a / (x as f64) - s; + } + result = x; + } else { + // Use the BTPE algorithm. + + // Threshold for using the squeeze algorithm. This can be freely + // chosen based on performance. Ranlib and GSL use 20. + const SQUEEZE_THRESHOLD: i64 = 20; + + // Step 0: Calculate constants as functions of `n` and `p`. + let n = self.n as f64; + let np = n * p; + let npq = np * q; + let f_m = np + p; + let m = f64_to_i64(f_m); + // radius of triangle region, since height=1 also area of region + let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; + // tip of triangle + let x_m = (m as f64) + 0.5; + // left edge of triangle + let x_l = x_m - p1; + // right edge of triangle + let x_r = x_m + p1; + let c = 0.134 + 20.5 / (15.3 + (m as f64)); + // p1 + area of parallelogram region + let p2 = p1 * (1. + 2. * c); + + fn lambda(a: f64) -> f64 { + a * (1. + 0.5 * a) + } + + let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p)); + let lambda_r = lambda((x_r - f_m) / (x_r * q)); + // p1 + area of left tail + let p3 = p2 + c / lambda_l; + // p1 + area of right tail + let p4 = p3 + c / lambda_r; + + // return value + let mut y: i64; + + let gen_u = Uniform::new(0., p4); + let gen_v = Uniform::new(0., 1.); + loop { - // draw from the Cauchy distribution - comp_dev = rng.sample(cauchy); - // shift the peak of the comparison ditribution - lresult = expected + sq * comp_dev; - // repeat the drawing until we are in the range of possible values - if lresult >= 0.0 && lresult < float_n + 1.0 { + // Step 1: Generate `u` for selecting the region. If region 1 is + // selected, generate a triangularly distributed variate. + let u = gen_u.sample(rng); + let mut v = gen_v.sample(rng); + if !(u > p1) { + y = f64_to_i64(x_m - p1 * v + u); break; } - } - // the result should be discrete - lresult = lresult.floor(); + if !(u > p2) { + // Step 2: Region 2, parallelograms. Check if region 2 is + // used. If so, generate `y`. + let x = x_l + (u - p1) / c; + v = v * c + 1.0 - (x - x_m).abs() / p1; + if v > 1. { + continue; + } else { + y = f64_to_i64(x); + } + } else if !(u > p3) { + // Step 3: Region 3, left exponential tail. + y = f64_to_i64(x_l + v.ln() / lambda_l); + if y < 0 { + continue; + } else { + v *= (u - p2) * lambda_l; + } + } else { + // Step 4: Region 4, right exponential tail. + y = f64_to_i64(x_r - v.ln() / lambda_r); + if y > 0 && (y as u64) > self.n { + continue; + } else { + v *= (u - p3) * lambda_r; + } + } + + // Step 5: Acceptance/rejection comparison. + + // Step 5.0: Test for appropriate method of evaluating f(y). + let k = (y - m).abs(); + if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { + // Step 5.1: Evaluate f(y) via the recursive relationship. Start the + // search from the mode. + let s = p / q; + let a = s * (n + 1.); + let mut f = 1.0; + if m < y { + let mut i = m; + loop { + i += 1; + f *= a / (i as f64) - s; + if i == y { + break; + } + } + } else if m > y { + let mut i = y; + loop { + i += 1; + f /= a / (i as f64) - s; + if i == m { + break; + } + } + } + if v > f { + continue; + } else { + break; + } + } - let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) - - log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc; - // this is the binomial probability divided by the comparison probability - // we will generate a uniform random value and if it is larger than this, - // we interpret it as a value falling out of the distribution and repeat - let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev)); + // Step 5.2: Squeezing. Check the value of ln(v) againts upper and + // lower bound of ln(f(y)). + let k = k as f64; + let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1./6.) / npq + 0.5); + let t = -0.5 * k*k / npq; + let alpha = v.ln(); + if alpha < t - rho { + break; + } + if alpha > t + rho { + continue; + } + + // Step 5.3: Final acceptance/rejection test. + let x1 = (y + 1) as f64; + let f1 = (m + 1) as f64; + let z = (f64_to_i64(n) + 1 - m) as f64; + let w = (f64_to_i64(n) - y + 1) as f64; + + fn stirling(a: f64) -> f64 { + let a2 = a * a; + (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. + } + + if alpha > x_m * (f1 / x1).ln() + + (n - (m as f64) + 0.5) * (z / w).ln() + + ((y - m) as f64) * (w * p / (x1 * q)).ln() + // We use the signs from the GSL implementation, which are + // different than the ones in the reference. According to + // the GSL authors, the new signs were verified to be + // correct by one of the original designers of the + // algorithm. + + stirling(f1) + stirling(z) - stirling(x1) - stirling(w) + { + continue; + } - if comparison_coeff >= rng.gen() { break; } + assert!(y >= 0); + result = y as u64; } - // invert the result for p < 0.5 + // Invert the result for p < 0.5. if p != self.p { - self.n - lresult as u64 + self.n - result } else { - lresult as u64 + result } } } #[cfg(test)] mod test { - use Rng; - use distributions::Distribution; + use crate::Rng; + use crate::distributions::Distribution; use super::Binomial; fn test_binomial_mean_and_variance<R: Rng>(n: u64, p: f64, rng: &mut R) { @@ -144,17 +277,20 @@ mod test { for i in results.iter_mut() { *i = binomial.sample(rng) as f64; } let mean = results.iter().sum::<f64>() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0); + assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0, + "mean: {}, expected_mean: {}", mean, expected_mean); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); + assert!((variance - expected_variance).abs() < expected_variance / 10.0, + "variance: {}, expected_variance: {}", variance, expected_variance); } #[test] + #[cfg(not(miri))] // Miri is too slow fn test_binomial() { - let mut rng = ::test::rng(351); + let mut rng = crate::test::rng(351); test_binomial_mean_and_variance(150, 0.1, &mut rng); test_binomial_mean_and_variance(70, 0.6, &mut rng); test_binomial_mean_and_variance(40, 0.5, &mut rng); @@ -164,7 +300,7 @@ mod test { #[test] fn test_binomial_end_points() { - let mut rng = ::test::rng(352); + let mut rng = crate::test::rng(352); assert_eq!(rng.sample(Binomial::new(20, 0.0)), 0); assert_eq!(rng.sample(Binomial::new(20, 1.0)), 20); } diff --git a/rand/src/distributions/cauchy.rs b/rand/src/distributions/cauchy.rs index feef015..0a5d149 100644 --- a/rand/src/distributions/cauchy.rs +++ b/rand/src/distributions/cauchy.rs @@ -8,25 +8,18 @@ // except according to those terms. //! The Cauchy distribution. +#![allow(deprecated)] +#![allow(clippy::all)] -use Rng; -use distributions::Distribution; +use crate::Rng; +use crate::distributions::Distribution; use std::f64::consts::PI; /// The Cauchy distribution `Cauchy(median, scale)`. /// /// This distribution has a density function: /// `f(x) = 1 / (pi * scale * (1 + ((x - median) / scale)^2))` -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Cauchy, Distribution}; -/// -/// let cau = Cauchy::new(2.0, 5.0); -/// let v = cau.sample(&mut rand::thread_rng()); -/// println!("{} is from a Cauchy(2, 5) distribution", v); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Cauchy { median: f64, @@ -61,7 +54,7 @@ impl Distribution<f64> for Cauchy { #[cfg(test)] mod test { - use distributions::Distribution; + use crate::distributions::Distribution; use super::Cauchy; fn median(mut numbers: &mut [f64]) -> f64 { @@ -75,30 +68,25 @@ mod test { } #[test] - fn test_cauchy_median() { + #[cfg(not(miri))] // Miri doesn't support transcendental functions + fn test_cauchy_averages() { + // NOTE: given that the variance and mean are undefined, + // this test does not have any rigorous statistical meaning. let cauchy = Cauchy::new(10.0, 5.0); - let mut rng = ::test::rng(123); + let mut rng = crate::test::rng(123); let mut numbers: [f64; 1000] = [0.0; 1000]; + let mut sum = 0.0; for i in 0..1000 { numbers[i] = cauchy.sample(&mut rng); + sum += numbers[i]; } let median = median(&mut numbers); println!("Cauchy median: {}", median); - assert!((median - 10.0).abs() < 0.5); // not 100% certain, but probable enough - } - - #[test] - fn test_cauchy_mean() { - let cauchy = Cauchy::new(10.0, 5.0); - let mut rng = ::test::rng(123); - let mut sum = 0.0; - for _ in 0..1000 { - sum += cauchy.sample(&mut rng); - } + assert!((median - 10.0).abs() < 0.4); // not 100% certain, but probable enough let mean = sum / 1000.0; println!("Cauchy mean: {}", mean); // for a Cauchy distribution the mean should not converge - assert!((mean - 10.0).abs() > 0.5); // not 100% certain, but probable enough + assert!((mean - 10.0).abs() > 0.4); // not 100% certain, but probable enough } #[test] diff --git a/rand/src/distributions/dirichlet.rs b/rand/src/distributions/dirichlet.rs index 19384b8..1ce01fd 100644 --- a/rand/src/distributions/dirichlet.rs +++ b/rand/src/distributions/dirichlet.rs @@ -8,28 +8,19 @@ // except according to those terms. //! The dirichlet distribution. +#![allow(deprecated)] +#![allow(clippy::all)] -use Rng; -use distributions::Distribution; -use distributions::gamma::Gamma; +use crate::Rng; +use crate::distributions::Distribution; +use crate::distributions::gamma::Gamma; /// The dirichelet distribution `Dirichlet(alpha)`. /// /// The Dirichlet distribution is a family of continuous multivariate /// probability distributions parameterized by a vector alpha of positive reals. /// It is a multivariate generalization of the beta distribution. -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::Dirichlet; -/// -/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]); -/// let samples = dirichlet.sample(&mut rand::thread_rng()); -/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); -/// ``` - +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Debug)] pub struct Dirichlet { /// Concentration parameters (alpha) @@ -91,12 +82,12 @@ impl Distribution<Vec<f64>> for Dirichlet { #[cfg(test)] mod test { use super::Dirichlet; - use distributions::Distribution; + use crate::distributions::Distribution; #[test] fn test_dirichlet() { let d = Dirichlet::new(vec![1.0, 2.0, 3.0]); - let mut rng = ::test::rng(221); + let mut rng = crate::test::rng(221); let samples = d.sample(&mut rng); let _: Vec<f64> = samples .into_iter() @@ -112,7 +103,7 @@ mod test { let alpha = 0.5f64; let size = 2; let d = Dirichlet::new_with_param(alpha, size); - let mut rng = ::test::rng(221); + let mut rng = crate::test::rng(221); let samples = d.sample(&mut rng); let _: Vec<f64> = samples .into_iter() diff --git a/rand/src/distributions/exponential.rs b/rand/src/distributions/exponential.rs index a7d0500..0278248 100644 --- a/rand/src/distributions/exponential.rs +++ b/rand/src/distributions/exponential.rs @@ -8,10 +8,11 @@ // except according to those terms. //! The exponential distribution. +#![allow(deprecated)] -use {Rng}; -use distributions::{ziggurat_tables, Distribution}; -use distributions::utils::ziggurat; +use crate::{Rng}; +use crate::distributions::{ziggurat_tables, Distribution}; +use crate::distributions::utils::ziggurat; /// Samples floating-point numbers according to the exponential distribution, /// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or @@ -27,15 +28,7 @@ use distributions::utils::ziggurat; /// Generate Normal Random Samples*]( /// https://www.doornik.com/research/ziggurat.pdf). /// Nuffield College, Oxford -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::Exp1; -/// -/// let val: f64 = SmallRng::from_entropy().sample(Exp1); -/// println!("{}", val); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Exp1; @@ -64,17 +57,8 @@ impl Distribution<f64> for Exp1 { /// This distribution has density function: `f(x) = lambda * exp(-lambda * x)` /// for `x > 0`. /// -/// Note that [`Exp1`](struct.Exp1.html) is an optimised implementation for `lambda = 1`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Exp, Distribution}; -/// -/// let exp = Exp::new(2.0); -/// let v = exp.sample(&mut rand::thread_rng()); -/// println!("{} is from a Exp(2) distribution", v); -/// ``` +/// Note that [`Exp1`](crate::distributions::Exp1) is an optimised implementation for `lambda = 1`. +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Exp { /// `lambda` stored as `1/lambda`, since this is what we scale by. @@ -100,13 +84,13 @@ impl Distribution<f64> for Exp { #[cfg(test)] mod test { - use distributions::Distribution; + use crate::distributions::Distribution; use super::Exp; #[test] fn test_exp() { let exp = Exp::new(10.0); - let mut rng = ::test::rng(221); + let mut rng = crate::test::rng(221); for _ in 0..1000 { assert!(exp.sample(&mut rng) >= 0.0); } diff --git a/rand/src/distributions/float.rs b/rand/src/distributions/float.rs index ece12f5..bda523a 100644 --- a/rand/src/distributions/float.rs +++ b/rand/src/distributions/float.rs @@ -9,9 +9,9 @@ //! Basic floating-point number distributions use core::mem; -use Rng; -use distributions::{Distribution, Standard}; -use distributions::utils::FloatSIMDUtils; +use crate::Rng; +use crate::distributions::{Distribution, Standard}; +use crate::distributions::utils::FloatSIMDUtils; #[cfg(feature="simd_support")] use packed_simd::*; @@ -36,9 +36,9 @@ use packed_simd::*; /// println!("f32 from (0, 1): {}", val); /// ``` /// -/// [`Standard`]: struct.Standard.html -/// [`Open01`]: struct.Open01.html -/// [`Uniform`]: uniform/struct.Uniform.html +/// [`Standard`]: crate::distributions::Standard +/// [`Open01`]: crate::distributions::Open01 +/// [`Uniform`]: crate::distributions::uniform::Uniform #[derive(Clone, Copy, Debug)] pub struct OpenClosed01; @@ -62,14 +62,16 @@ pub struct OpenClosed01; /// println!("f32 from (0, 1): {}", val); /// ``` /// -/// [`Standard`]: struct.Standard.html -/// [`OpenClosed01`]: struct.OpenClosed01.html -/// [`Uniform`]: uniform/struct.Uniform.html +/// [`Standard`]: crate::distributions::Standard +/// [`OpenClosed01`]: crate::distributions::OpenClosed01 +/// [`Uniform`]: crate::distributions::uniform::Uniform #[derive(Clone, Copy, Debug)] pub struct Open01; -pub(crate) trait IntoFloat { +// This trait is needed by both this lib and rand_distr hence is a hidden export +#[doc(hidden)] +pub trait IntoFloat { type F; /// Helper method to combine the fraction and a contant exponent into a @@ -93,9 +95,7 @@ macro_rules! float_impls { // The exponent is encoded using an offset-binary representation let exponent_bits: $u_scalar = (($exponent_bias + exponent) as $u_scalar) << $fraction_bits; - // TODO: use from_bits when min compiler > 1.25 (see #545) - // $ty::from_bits(self | exponent_bits) - unsafe{ mem::transmute(self | exponent_bits) } + $ty::from_bits(self | exponent_bits) } } @@ -168,9 +168,9 @@ float_impls! { f64x8, u64x8, f64, u64, 52, 1023 } #[cfg(test)] mod tests { - use Rng; - use distributions::{Open01, OpenClosed01}; - use rngs::mock::StepRng; + use crate::Rng; + use crate::distributions::{Open01, OpenClosed01}; + use crate::rngs::mock::StepRng; #[cfg(feature="simd_support")] use packed_simd::*; diff --git a/rand/src/distributions/gamma.rs b/rand/src/distributions/gamma.rs index 43ac2bc..b5a97f5 100644 --- a/rand/src/distributions/gamma.rs +++ b/rand/src/distributions/gamma.rs @@ -8,13 +8,14 @@ // except according to those terms. //! The Gamma and derived distributions. +#![allow(deprecated)] use self::GammaRepr::*; use self::ChiSquaredRepr::*; -use Rng; -use distributions::normal::StandardNormal; -use distributions::{Distribution, Exp, Open01}; +use crate::Rng; +use crate::distributions::normal::StandardNormal; +use crate::distributions::{Distribution, Exp, Open01}; /// The Gamma distribution `Gamma(shape, scale)` distribution. /// @@ -32,20 +33,11 @@ use distributions::{Distribution, Exp, Open01}; /// == 1`, and using the boosting technique described in that paper for /// `shape < 1`. /// -/// # Example -/// -/// ``` -/// use rand::distributions::{Distribution, Gamma}; -/// -/// let gamma = Gamma::new(2.0, 5.0); -/// let v = gamma.sample(&mut rand::thread_rng()); -/// println!("{} is from a Gamma(2, 5) distribution", v); -/// ``` -/// /// [^1]: George Marsaglia and Wai Wan Tsang. 2000. "A Simple Method for /// Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3 /// (September 2000), 363-372. /// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414) +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Gamma { repr: GammaRepr, @@ -174,16 +166,7 @@ impl Distribution<f64> for GammaLargeShape { /// of `k` independent standard normal random variables. For other /// `k`, this uses the equivalent characterisation /// `χ²(k) = Gamma(k/2, 2)`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{ChiSquared, Distribution}; -/// -/// let chi = ChiSquared::new(11.0); -/// let v = chi.sample(&mut rand::thread_rng()); -/// println!("{} is from a χ²(11) distribution", v) -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct ChiSquared { repr: ChiSquaredRepr, @@ -229,16 +212,7 @@ impl Distribution<f64> for ChiSquared { /// This distribution is equivalent to the ratio of two normalised /// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / /// (χ²(n)/n)`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{FisherF, Distribution}; -/// -/// let f = FisherF::new(2.0, 32.0); -/// let v = f.sample(&mut rand::thread_rng()); -/// println!("{} is from an F(2, 32) distribution", v) -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct FisherF { numer: ChiSquared, @@ -270,16 +244,7 @@ impl Distribution<f64> for FisherF { /// The Student t distribution, `t(nu)`, where `nu` is the degrees of /// freedom. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{StudentT, Distribution}; -/// -/// let t = StudentT::new(11.0); -/// let v = t.sample(&mut rand::thread_rng()); -/// println!("{} is from a t(11) distribution", v) -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct StudentT { chi: ChiSquared, @@ -305,16 +270,7 @@ impl Distribution<f64> for StudentT { } /// The Beta distribution with shape parameters `alpha` and `beta`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Distribution, Beta}; -/// -/// let beta = Beta::new(2.0, 5.0); -/// let v = beta.sample(&mut rand::thread_rng()); -/// println!("{} is from a Beta(2, 5) distribution", v); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Beta { gamma_a: Gamma, @@ -345,30 +301,32 @@ impl Distribution<f64> for Beta { #[cfg(test)] mod test { - use distributions::Distribution; + use crate::distributions::Distribution; use super::{Beta, ChiSquared, StudentT, FisherF}; + const N: u32 = 100; + #[test] fn test_chi_squared_one() { let chi = ChiSquared::new(1.0); - let mut rng = ::test::rng(201); - for _ in 0..1000 { + let mut rng = crate::test::rng(201); + for _ in 0..N { chi.sample(&mut rng); } } #[test] fn test_chi_squared_small() { let chi = ChiSquared::new(0.5); - let mut rng = ::test::rng(202); - for _ in 0..1000 { + let mut rng = crate::test::rng(202); + for _ in 0..N { chi.sample(&mut rng); } } #[test] fn test_chi_squared_large() { let chi = ChiSquared::new(30.0); - let mut rng = ::test::rng(203); - for _ in 0..1000 { + let mut rng = crate::test::rng(203); + for _ in 0..N { chi.sample(&mut rng); } } @@ -381,8 +339,8 @@ mod test { #[test] fn test_f() { let f = FisherF::new(2.0, 32.0); - let mut rng = ::test::rng(204); - for _ in 0..1000 { + let mut rng = crate::test::rng(204); + for _ in 0..N { f.sample(&mut rng); } } @@ -390,8 +348,8 @@ mod test { #[test] fn test_t() { let t = StudentT::new(11.0); - let mut rng = ::test::rng(205); - for _ in 0..1000 { + let mut rng = crate::test::rng(205); + for _ in 0..N { t.sample(&mut rng); } } @@ -399,8 +357,8 @@ mod test { #[test] fn test_beta() { let beta = Beta::new(1.0, 2.0); - let mut rng = ::test::rng(201); - for _ in 0..1000 { + let mut rng = crate::test::rng(201); + for _ in 0..N { beta.sample(&mut rng); } } diff --git a/rand/src/distributions/integer.rs b/rand/src/distributions/integer.rs index 7e408db..5238339 100644 --- a/rand/src/distributions/integer.rs +++ b/rand/src/distributions/integer.rs @@ -8,8 +8,10 @@ //! The implementations of the `Standard` distribution for integer types. -use {Rng}; -use distributions::{Distribution, Standard}; +use crate::{Rng}; +use crate::distributions::{Distribution, Standard}; +use core::num::{NonZeroU8, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroUsize}; +#[cfg(not(target_os = "emscripten"))] use core::num::NonZeroU128; #[cfg(feature="simd_support")] use packed_simd::*; #[cfg(all(target_arch = "x86", feature="nightly"))] @@ -45,13 +47,13 @@ impl Distribution<u64> for Standard { } } -#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +#[cfg(not(target_os = "emscripten"))] impl Distribution<u128> for Standard { #[inline] fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 { // Use LE; we explicitly generate one value before the next. - let x = rng.next_u64() as u128; - let y = rng.next_u64() as u128; + let x = u128::from(rng.next_u64()); + let y = u128::from(rng.next_u64()); (y << 64) | x } } @@ -85,9 +87,30 @@ impl_int_from_uint! { i8, u8 } impl_int_from_uint! { i16, u16 } impl_int_from_uint! { i32, u32 } impl_int_from_uint! { i64, u64 } -#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] impl_int_from_uint! { i128, u128 } +#[cfg(not(target_os = "emscripten"))] impl_int_from_uint! { i128, u128 } impl_int_from_uint! { isize, usize } +macro_rules! impl_nzint { + ($ty:ty, $new:path) => { + impl Distribution<$ty> for Standard { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty { + loop { + if let Some(nz) = $new(rng.gen()) { + break nz; + } + } + } + } + } +} + +impl_nzint!(NonZeroU8, NonZeroU8::new); +impl_nzint!(NonZeroU16, NonZeroU16::new); +impl_nzint!(NonZeroU32, NonZeroU32::new); +impl_nzint!(NonZeroU64, NonZeroU64::new); +#[cfg(not(target_os = "emscripten"))] impl_nzint!(NonZeroU128, NonZeroU128::new); +impl_nzint!(NonZeroUsize, NonZeroUsize::new); + #[cfg(feature="simd_support")] macro_rules! simd_impl { ($(($intrinsic:ident, $vec:ty),)+) => {$( @@ -135,19 +158,19 @@ simd_impl!((__m64, u8x8), (__m128i, u8x16), (__m256i, u8x32),); #[cfg(test)] mod tests { - use Rng; - use distributions::{Standard}; + use crate::Rng; + use crate::distributions::{Standard}; #[test] fn test_integers() { - let mut rng = ::test::rng(806); + let mut rng = crate::test::rng(806); rng.sample::<isize, _>(Standard); rng.sample::<i8, _>(Standard); rng.sample::<i16, _>(Standard); rng.sample::<i32, _>(Standard); rng.sample::<i64, _>(Standard); - #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[cfg(not(target_os = "emscripten"))] rng.sample::<i128, _>(Standard); rng.sample::<usize, _>(Standard); @@ -155,7 +178,7 @@ mod tests { rng.sample::<u16, _>(Standard); rng.sample::<u32, _>(Standard); rng.sample::<u64, _>(Standard); - #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[cfg(not(target_os = "emscripten"))] rng.sample::<u128, _>(Standard); } } diff --git a/rand/src/distributions/mod.rs b/rand/src/distributions/mod.rs index 5e879cb..02ece6f 100644 --- a/rand/src/distributions/mod.rs +++ b/rand/src/distributions/mod.rs @@ -7,12 +7,12 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! Generating random samples from probability distributions. +//! Generating random samples from probability distributions //! //! This module is the home of the [`Distribution`] trait and several of its //! implementations. It is the workhorse behind some of the convenient -//! functionality of the [`Rng`] trait, including [`gen`], [`gen_range`] and -//! of course [`sample`]. +//! functionality of the [`Rng`] trait, e.g. [`Rng::gen`], [`Rng::gen_range`] and +//! of course [`Rng::sample`]. //! //! Abstractly, a [probability distribution] describes the probability of //! occurance of each value in its sample space. @@ -40,8 +40,14 @@ //! possible to generate type `T` with [`Rng::gen()`], and by extension also //! with the [`random()`] function. //! +//! ## Random characters +//! +//! [`Alphanumeric`] is a simple distribution to sample random letters and +//! numbers of the `char` type; in contrast [`Standard`] may sample any valid +//! `char`. +//! //! -//! # Distribution to sample from a `Uniform` range +//! # Uniform numeric ranges //! //! The [`Uniform`] distribution is more flexible than [`Standard`], but also //! more specialised: it supports fewer target types, but allows the sample @@ -56,158 +62,84 @@ //! //! User types `T` may also implement `Distribution<T>` for [`Uniform`], //! although this is less straightforward than for [`Standard`] (see the -//! documentation in the [`uniform` module]. Doing so enables generation of +//! documentation in the [`uniform`] module. Doing so enables generation of //! values of type `T` with [`Rng::gen_range`]. //! -//! -//! # Other distributions +//! ## Open and half-open ranges //! //! There are surprisingly many ways to uniformly generate random floats. A //! range between 0 and 1 is standard, but the exact bounds (open vs closed) //! and accuracy differ. In addition to the [`Standard`] distribution Rand offers -//! [`Open01`] and [`OpenClosed01`]. See [Floating point implementation] for -//! more details. -//! -//! [`Alphanumeric`] is a simple distribution to sample random letters and -//! numbers of the `char` type; in contrast [`Standard`] may sample any valid -//! `char`. -//! -//! [`WeightedIndex`] can be used to do weighted sampling from a set of items, -//! such as from an array. -//! -//! # Non-uniform probability distributions -//! -//! Rand currently provides the following probability distributions: -//! -//! - Related to real-valued quantities that grow linearly -//! (e.g. errors, offsets): -//! - [`Normal`] distribution, and [`StandardNormal`] as a primitive -//! - [`Cauchy`] distribution -//! - Related to Bernoulli trials (yes/no events, with a given probability): -//! - [`Binomial`] distribution -//! - [`Bernoulli`] distribution, similar to [`Rng::gen_bool`]. -//! - Related to positive real-valued quantities that grow exponentially -//! (e.g. prices, incomes, populations): -//! - [`LogNormal`] distribution -//! - Related to the occurrence of independent events at a given rate: -//! - [`Pareto`] distribution -//! - [`Poisson`] distribution -//! - [`Exp`]onential distribution, and [`Exp1`] as a primitive -//! - [`Weibull`] distribution -//! - Gamma and derived distributions: -//! - [`Gamma`] distribution -//! - [`ChiSquared`] distribution -//! - [`StudentT`] distribution -//! - [`FisherF`] distribution -//! - Triangular distribution: -//! - [`Beta`] distribution -//! - [`Triangular`] distribution -//! - Multivariate probability distributions -//! - [`Dirichlet`] distribution -//! - [`UnitSphereSurface`] distribution -//! - [`UnitCircle`] distribution +//! [`Open01`] and [`OpenClosed01`]. See "Floating point implementation" section of +//! [`Standard`] documentation for more details. //! -//! # Examples +//! # Non-uniform sampling //! -//! Sampling from a distribution: +//! Sampling a simple true/false outcome with a given probability has a name: +//! the [`Bernoulli`] distribution (this is used by [`Rng::gen_bool`]). //! -//! ``` -//! use rand::{thread_rng, Rng}; -//! use rand::distributions::Exp; +//! For weighted sampling from a sequence of discrete values, use the +//! [`weighted`] module. //! -//! let exp = Exp::new(2.0); -//! let v = thread_rng().sample(exp); -//! println!("{} is from an Exp(2) distribution", v); -//! ``` -//! -//! Implementing the [`Standard`] distribution for a user type: -//! -//! ``` -//! # #![allow(dead_code)] -//! use rand::Rng; -//! use rand::distributions::{Distribution, Standard}; -//! -//! struct MyF32 { -//! x: f32, -//! } -//! -//! impl Distribution<MyF32> for Standard { -//! fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> MyF32 { -//! MyF32 { x: rng.gen() } -//! } -//! } -//! ``` +//! This crate no longer includes other non-uniform distributions; instead +//! it is recommended that you use either [`rand_distr`] or [`statrs`]. //! //! //! [probability distribution]: https://en.wikipedia.org/wiki/Probability_distribution -//! [`Distribution`]: trait.Distribution.html -//! [`gen_range`]: ../trait.Rng.html#method.gen_range -//! [`gen`]: ../trait.Rng.html#method.gen -//! [`sample`]: ../trait.Rng.html#method.sample -//! [`new_inclusive`]: struct.Uniform.html#method.new_inclusive -//! [`random()`]: ../fn.random.html -//! [`Rng::gen_bool`]: ../trait.Rng.html#method.gen_bool -//! [`Rng::gen_range`]: ../trait.Rng.html#method.gen_range -//! [`Rng::gen()`]: ../trait.Rng.html#method.gen -//! [`Rng`]: ../trait.Rng.html -//! [`uniform` module]: uniform/index.html -//! [Floating point implementation]: struct.Standard.html#floating-point-implementation -// distributions -//! [`Alphanumeric`]: struct.Alphanumeric.html -//! [`Bernoulli`]: struct.Bernoulli.html -//! [`Beta`]: struct.Beta.html -//! [`Binomial`]: struct.Binomial.html -//! [`Cauchy`]: struct.Cauchy.html -//! [`ChiSquared`]: struct.ChiSquared.html -//! [`Dirichlet`]: struct.Dirichlet.html -//! [`Exp`]: struct.Exp.html -//! [`Exp1`]: struct.Exp1.html -//! [`FisherF`]: struct.FisherF.html -//! [`Gamma`]: struct.Gamma.html -//! [`LogNormal`]: struct.LogNormal.html -//! [`Normal`]: struct.Normal.html -//! [`Open01`]: struct.Open01.html -//! [`OpenClosed01`]: struct.OpenClosed01.html -//! [`Pareto`]: struct.Pareto.html -//! [`Poisson`]: struct.Poisson.html -//! [`Standard`]: struct.Standard.html -//! [`StandardNormal`]: struct.StandardNormal.html -//! [`StudentT`]: struct.StudentT.html -//! [`Triangular`]: struct.Triangular.html -//! [`Uniform`]: struct.Uniform.html -//! [`Uniform::new`]: struct.Uniform.html#method.new -//! [`Uniform::new_inclusive`]: struct.Uniform.html#method.new_inclusive -//! [`UnitSphereSurface`]: struct.UnitSphereSurface.html -//! [`UnitCircle`]: struct.UnitCircle.html -//! [`Weibull`]: struct.Weibull.html -//! [`WeightedIndex`]: struct.WeightedIndex.html +//! [`rand_distr`]: https://crates.io/crates/rand_distr +//! [`statrs`]: https://crates.io/crates/statrs + +//! [`Alphanumeric`]: distributions::Alphanumeric +//! [`Bernoulli`]: distributions::Bernoulli +//! [`Open01`]: distributions::Open01 +//! [`OpenClosed01`]: distributions::OpenClosed01 +//! [`Standard`]: distributions::Standard +//! [`Uniform`]: distributions::Uniform +//! [`Uniform::new`]: distributions::Uniform::new +//! [`Uniform::new_inclusive`]: distributions::Uniform::new_inclusive +//! [`weighted`]: distributions::weighted +//! [`rand_distr`]: https://crates.io/crates/rand_distr +//! [`statrs`]: https://crates.io/crates/statrs -#[cfg(any(rustc_1_26, features="nightly"))] use core::iter; -use Rng; +use crate::Rng; pub use self::other::Alphanumeric; #[doc(inline)] pub use self::uniform::Uniform; pub use self::float::{OpenClosed01, Open01}; -pub use self::bernoulli::Bernoulli; +pub use self::bernoulli::{Bernoulli, BernoulliError}; #[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError}; + +// The following are all deprecated after being moved to rand_distr +#[allow(deprecated)] #[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::unit_circle::UnitCircle; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT, Beta}; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::normal::{Normal, LogNormal, StandardNormal}; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::exponential::{Exp, Exp1}; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::pareto::Pareto; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::poisson::Poisson; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::binomial::Binomial; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::cauchy::Cauchy; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::dirichlet::Dirichlet; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::triangular::Triangular; +#[allow(deprecated)] #[cfg(feature="std")] pub use self::weibull::Weibull; pub mod uniform; mod bernoulli; -#[cfg(feature="alloc")] mod weighted; +#[cfg(feature="alloc")] pub mod weighted; #[cfg(feature="std")] mod unit_sphere; #[cfg(feature="std")] mod unit_circle; #[cfg(feature="std")] mod gamma; @@ -222,6 +154,9 @@ mod bernoulli; #[cfg(feature="std")] mod weibull; mod float; +#[doc(hidden)] pub mod hidden_export { + pub use super::float::IntoFloat; // used by rand_distr +} mod integer; mod other; mod utils; @@ -238,8 +173,7 @@ mod utils; /// advantage of not needing to consider thread safety, and for most /// distributions efficient state-less sampling algorithms are available. /// -/// [`Rng`]: ../trait.Rng.html -/// [`sample_iter`]: trait.Distribution.html#method.sample_iter +/// [`sample_iter`]: Distribution::method.sample_iter pub trait Distribution<T> { /// Generate a random value of `T`, using `rng` as the source of randomness. fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T; @@ -247,33 +181,39 @@ pub trait Distribution<T> { /// Create an iterator that generates random values of `T`, using `rng` as /// the source of randomness. /// + /// Note that this function takes `self` by value. This works since + /// `Distribution<T>` is impl'd for `&D` where `D: Distribution<T>`, + /// however borrowing is not automatic hence `distr.sample_iter(...)` may + /// need to be replaced with `(&distr).sample_iter(...)` to borrow or + /// `(&*distr).sample_iter(...)` to reborrow an existing reference. + /// /// # Example /// /// ``` /// use rand::thread_rng; /// use rand::distributions::{Distribution, Alphanumeric, Uniform, Standard}; /// - /// let mut rng = thread_rng(); + /// let rng = thread_rng(); /// /// // Vec of 16 x f32: - /// let v: Vec<f32> = Standard.sample_iter(&mut rng).take(16).collect(); + /// let v: Vec<f32> = Standard.sample_iter(rng).take(16).collect(); /// /// // String: - /// let s: String = Alphanumeric.sample_iter(&mut rng).take(7).collect(); + /// let s: String = Alphanumeric.sample_iter(rng).take(7).collect(); /// /// // Dice-rolling: /// let die_range = Uniform::new_inclusive(1, 6); - /// let mut roll_die = die_range.sample_iter(&mut rng); + /// let mut roll_die = die_range.sample_iter(rng); /// while roll_die.next().unwrap() != 6 { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T> - where Self: Sized, R: Rng + fn sample_iter<R>(self, rng: R) -> DistIter<Self, R, T> + where R: Rng, Self: Sized { DistIter { distr: self, - rng: rng, + rng, phantom: ::core::marker::PhantomData, } } @@ -292,23 +232,25 @@ impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D { /// This `struct` is created by the [`sample_iter`] method on [`Distribution`]. /// See its documentation for more. /// -/// [`Distribution`]: trait.Distribution.html -/// [`sample_iter`]: trait.Distribution.html#method.sample_iter +/// [`sample_iter`]: Distribution::sample_iter #[derive(Debug)] -pub struct DistIter<'a, D: 'a, R: 'a, T> { - distr: &'a D, - rng: &'a mut R, +pub struct DistIter<D, R, T> { + distr: D, + rng: R, phantom: ::core::marker::PhantomData<T>, } -impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T> - where D: Distribution<T>, R: Rng + 'a +impl<D, R, T> Iterator for DistIter<D, R, T> + where D: Distribution<T>, R: Rng { type Item = T; #[inline(always)] fn next(&mut self) -> Option<T> { - Some(self.distr.sample(self.rng)) + // Here, self.rng may be a reference, but we must take &mut anyway. + // Even if sample could take an R: Rng by value, we would need to do this + // since Rng is not copyable and we cannot enforce that this is "reborrowable". + Some(self.distr.sample(&mut self.rng)) } fn size_hint(&self) -> (usize, Option<usize>) { @@ -316,20 +258,19 @@ impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T> } } -#[cfg(rustc_1_26)] -impl<'a, D, R, T> iter::FusedIterator for DistIter<'a, D, R, T> - where D: Distribution<T>, R: Rng + 'a {} +impl<D, R, T> iter::FusedIterator for DistIter<D, R, T> + where D: Distribution<T>, R: Rng {} #[cfg(features = "nightly")] -impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T> - where D: Distribution<T>, R: Rng + 'a {} +impl<D, R, T> iter::TrustedLen for DistIter<D, R, T> + where D: Distribution<T>, R: Rng {} /// A generic random value distribution, implemented for many primitive types. /// Usually generates values with a numerically uniform distribution, and with a /// range appropriate to the type. /// -/// ## Built-in Implementations +/// ## Provided implementations /// /// Assuming the provided `Rng` is well-behaved, these implementations /// generate values with the following ranges and distributions: @@ -346,20 +287,42 @@ impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T> /// * Wrapping integers (`Wrapping<T>`), besides the type identical to their /// normal integer variants. /// -/// The following aggregate types also implement the distribution `Standard` as -/// long as their component types implement it: +/// The `Standard` distribution also supports generation of the following +/// compound types where all component types are supported: /// -/// * Tuples and arrays: Each element of the tuple or array is generated -/// independently, using the `Standard` distribution recursively. -/// * `Option<T>` where `Standard` is implemented for `T`: Returns `None` with -/// probability 0.5; otherwise generates a random `x: T` and returns `Some(x)`. +/// * Tuples (up to 12 elements): each element is generated sequentially. +/// * Arrays (up to 32 elements): each element is generated sequentially; +/// see also [`Rng::fill`] which supports arbitrary array length for integer +/// types and tends to be faster for `u32` and smaller types. +/// * `Option<T>` first generates a `bool`, and if true generates and returns +/// `Some(value)` where `value: T`, otherwise returning `None`. /// -/// # Example +/// ## Custom implementations +/// +/// The [`Standard`] distribution may be implemented for user types as follows: +/// +/// ``` +/// # #![allow(dead_code)] +/// use rand::Rng; +/// use rand::distributions::{Distribution, Standard}; +/// +/// struct MyF32 { +/// x: f32, +/// } +/// +/// impl Distribution<MyF32> for Standard { +/// fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> MyF32 { +/// MyF32 { x: rng.gen() } +/// } +/// } +/// ``` +/// +/// ## Example usage /// ``` /// use rand::prelude::*; /// use rand::distributions::Standard; /// -/// let val: f32 = SmallRng::from_entropy().sample(Standard); +/// let val: f32 = StdRng::from_entropy().sample(Standard); /// println!("f32 from [0, 1): {}", val); /// ``` /// @@ -379,243 +342,40 @@ impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T> /// faster on some architectures (on modern Intel CPUs all methods have /// approximately equal performance). /// -/// [`Open01`]: struct.Open01.html -/// [`OpenClosed01`]: struct.OpenClosed01.html -/// [`Uniform`]: uniform/struct.Uniform.html +/// [`Uniform`]: uniform::Uniform #[derive(Clone, Copy, Debug)] pub struct Standard; -/// A value with a particular weight for use with `WeightedChoice`. -#[deprecated(since="0.6.0", note="use WeightedIndex instead")] -#[allow(deprecated)] -#[derive(Copy, Clone, Debug)] -pub struct Weighted<T> { - /// The numerical weight of this item - pub weight: u32, - /// The actual item which is being weighted - pub item: T, -} - -/// A distribution that selects from a finite collection of weighted items. -/// -/// Deprecated: use [`WeightedIndex`] instead. -/// -/// [`WeightedIndex`]: struct.WeightedIndex.html -#[deprecated(since="0.6.0", note="use WeightedIndex instead")] -#[allow(deprecated)] -#[derive(Debug)] -pub struct WeightedChoice<'a, T:'a> { - items: &'a mut [Weighted<T>], - weight_range: Uniform<u32>, -} - -#[deprecated(since="0.6.0", note="use WeightedIndex instead")] -#[allow(deprecated)] -impl<'a, T: Clone> WeightedChoice<'a, T> { - /// Create a new `WeightedChoice`. - /// - /// Panics if: - /// - /// - `items` is empty - /// - the total weight is 0 - /// - the total weight is larger than a `u32` can contain. - pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> { - // strictly speaking, this is subsumed by the total weight == 0 case - assert!(!items.is_empty(), "WeightedChoice::new called with no items"); - - let mut running_total: u32 = 0; - - // we convert the list from individual weights to cumulative - // weights so we can binary search. This *could* drop elements - // with weight == 0 as an optimisation. - for item in items.iter_mut() { - running_total = match running_total.checked_add(item.weight) { - Some(n) => n, - None => panic!("WeightedChoice::new called with a total weight \ - larger than a u32 can contain") - }; - - item.weight = running_total; - } - assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0"); - - WeightedChoice { - items, - // we're likely to be generating numbers in this range - // relatively often, so might as well cache it - weight_range: Uniform::new(0, running_total) - } - } -} - -#[deprecated(since="0.6.0", note="use WeightedIndex instead")] -#[allow(deprecated)] -impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> { - fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T { - // we want to find the first element that has cumulative - // weight > sample_weight, which we do by binary since the - // cumulative weights of self.items are sorted. - - // choose a weight in [0, total_weight) - let sample_weight = self.weight_range.sample(rng); - - // short circuit when it's the first item - if sample_weight < self.items[0].weight { - return self.items[0].item.clone(); - } - - let mut idx = 0; - let mut modifier = self.items.len(); - - // now we know that every possibility has an element to the - // left, so we can just search for the last element that has - // cumulative weight <= sample_weight, then the next one will - // be "it". (Note that this greatest element will never be the - // last element of the vector, since sample_weight is chosen - // in [0, total_weight) and the cumulative weight of the last - // one is exactly the total weight.) - while modifier > 1 { - let i = idx + modifier / 2; - if self.items[i].weight <= sample_weight { - // we're small, so look to the right, but allow this - // exact element still. - idx = i; - // we need the `/ 2` to round up otherwise we'll drop - // the trailing elements when `modifier` is odd. - modifier += 1; - } else { - // otherwise we're too big, so go left. (i.e. do - // nothing) - } - modifier /= 2; - } - self.items[idx + 1].item.clone() - } -} - -#[cfg(test)] +#[cfg(all(test, feature = "std"))] mod tests { - use rngs::mock::StepRng; - #[allow(deprecated)] - use super::{WeightedChoice, Weighted, Distribution}; + use crate::Rng; + use super::{Distribution, Uniform}; #[test] - #[allow(deprecated)] - fn test_weighted_choice() { - // this makes assumptions about the internal implementation of - // WeightedChoice. It may fail when the implementation in - // `distributions::uniform::UniformInt` changes. - - macro_rules! t { - ($items:expr, $expected:expr) => {{ - let mut items = $items; - let mut total_weight = 0; - for item in &items { total_weight += item.weight; } - - let wc = WeightedChoice::new(&mut items); - let expected = $expected; - - // Use extremely large steps between the random numbers, because - // we test with small ranges and `UniformInt` is designed to prefer - // the most significant bits. - let mut rng = StepRng::new(0, !0 / (total_weight as u64)); - - for &val in expected.iter() { - assert_eq!(wc.sample(&mut rng), val) - } - }} - } - - t!([Weighted { weight: 1, item: 10}], [10]); - - // skip some - t!([Weighted { weight: 0, item: 20}, - Weighted { weight: 2, item: 21}, - Weighted { weight: 0, item: 22}, - Weighted { weight: 1, item: 23}], - [21, 21, 23]); - - // different weights - t!([Weighted { weight: 4, item: 30}, - Weighted { weight: 3, item: 31}], - [30, 31, 30, 31, 30, 31, 30]); - - // check that we're binary searching - // correctly with some vectors of odd - // length. - t!([Weighted { weight: 1, item: 40}, - Weighted { weight: 1, item: 41}, - Weighted { weight: 1, item: 42}, - Weighted { weight: 1, item: 43}, - Weighted { weight: 1, item: 44}], - [40, 41, 42, 43, 44]); - t!([Weighted { weight: 1, item: 50}, - Weighted { weight: 1, item: 51}, - Weighted { weight: 1, item: 52}, - Weighted { weight: 1, item: 53}, - Weighted { weight: 1, item: 54}, - Weighted { weight: 1, item: 55}, - Weighted { weight: 1, item: 56}], - [50, 54, 51, 55, 52, 56, 53]); - } - - #[test] - #[allow(deprecated)] - fn test_weighted_clone_initialization() { - let initial : Weighted<u32> = Weighted {weight: 1, item: 1}; - let clone = initial.clone(); - assert_eq!(initial.weight, clone.weight); - assert_eq!(initial.item, clone.item); - } - - #[test] #[should_panic] - #[allow(deprecated)] - fn test_weighted_clone_change_weight() { - let initial : Weighted<u32> = Weighted {weight: 1, item: 1}; - let mut clone = initial.clone(); - clone.weight = 5; - assert_eq!(initial.weight, clone.weight); - } - - #[test] #[should_panic] - #[allow(deprecated)] - fn test_weighted_clone_change_item() { - let initial : Weighted<u32> = Weighted {weight: 1, item: 1}; - let mut clone = initial.clone(); - clone.item = 5; - assert_eq!(initial.item, clone.item); - - } - - #[test] #[should_panic] - #[allow(deprecated)] - fn test_weighted_choice_no_items() { - WeightedChoice::<isize>::new(&mut []); - } - #[test] #[should_panic] - #[allow(deprecated)] - fn test_weighted_choice_zero_weight() { - WeightedChoice::new(&mut [Weighted { weight: 0, item: 0}, - Weighted { weight: 0, item: 1}]); - } - #[test] #[should_panic] - #[allow(deprecated)] - fn test_weighted_choice_weight_overflows() { - let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow - WeightedChoice::new(&mut [Weighted { weight: x, item: 0 }, - Weighted { weight: 1, item: 1 }, - Weighted { weight: x, item: 2 }, - Weighted { weight: 1, item: 3 }]); - } - - #[cfg(feature="std")] - #[test] fn test_distributions_iter() { - use distributions::Normal; - let mut rng = ::test::rng(210); - let distr = Normal::new(10.0, 10.0); - let results: Vec<_> = distr.sample_iter(&mut rng).take(100).collect(); + use crate::distributions::Open01; + let mut rng = crate::test::rng(210); + let distr = Open01; + let results: Vec<f32> = distr.sample_iter(&mut rng).take(100).collect(); println!("{:?}", results); } + + #[test] + fn test_make_an_iter() { + fn ten_dice_rolls_other_than_five<'a, R: Rng>(rng: &'a mut R) -> impl Iterator<Item = i32> + 'a { + Uniform::new_inclusive(1, 6) + .sample_iter(rng) + .filter(|x| *x != 5) + .take(10) + } + + let mut rng = crate::test::rng(211); + let mut count = 0; + for val in ten_dice_rolls_other_than_five(&mut rng) { + assert!(val >= 1 && val <= 6 && val != 5); + count += 1; + } + assert_eq!(count, 10); + } } diff --git a/rand/src/distributions/normal.rs b/rand/src/distributions/normal.rs index b8d632e..7808baf 100644 --- a/rand/src/distributions/normal.rs +++ b/rand/src/distributions/normal.rs @@ -8,10 +8,11 @@ // except according to those terms. //! The normal and derived distributions. +#![allow(deprecated)] -use Rng; -use distributions::{ziggurat_tables, Distribution, Open01}; -use distributions::utils::ziggurat; +use crate::Rng; +use crate::distributions::{ziggurat_tables, Distribution, Open01}; +use crate::distributions::utils::ziggurat; /// Samples floating-point numbers according to the normal distribution /// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to @@ -25,15 +26,7 @@ use distributions::utils::ziggurat; /// Generate Normal Random Samples*]( /// https://www.doornik.com/research/ziggurat.pdf). /// Nuffield College, Oxford -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::StandardNormal; -/// -/// let val: f64 = SmallRng::from_entropy().sample(StandardNormal); -/// println!("{}", val); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct StandardNormal; @@ -80,18 +73,8 @@ impl Distribution<f64> for StandardNormal { /// Note that [`StandardNormal`] is an optimised implementation for mean 0, and /// standard deviation 1. /// -/// # Example -/// -/// ``` -/// use rand::distributions::{Normal, Distribution}; -/// -/// // mean 2, standard deviation 3 -/// let normal = Normal::new(2.0, 3.0); -/// let v = normal.sample(&mut rand::thread_rng()); -/// println!("{} is from a N(2, 9) distribution", v) -/// ``` -/// -/// [`StandardNormal`]: struct.StandardNormal.html +/// [`StandardNormal`]: crate::distributions::StandardNormal +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Normal { mean: f64, @@ -126,17 +109,7 @@ impl Distribution<f64> for Normal { /// /// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)` /// distributed. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{LogNormal, Distribution}; -/// -/// // mean 2, standard deviation 3 -/// let log_normal = LogNormal::new(2.0, 3.0); -/// let v = log_normal.sample(&mut rand::thread_rng()); -/// println!("{} is from an ln N(2, 9) distribution", v) -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct LogNormal { norm: Normal @@ -163,13 +136,13 @@ impl Distribution<f64> for LogNormal { #[cfg(test)] mod tests { - use distributions::Distribution; + use crate::distributions::Distribution; use super::{Normal, LogNormal}; #[test] fn test_normal() { let norm = Normal::new(10.0, 10.0); - let mut rng = ::test::rng(210); + let mut rng = crate::test::rng(210); for _ in 0..1000 { norm.sample(&mut rng); } @@ -184,7 +157,7 @@ mod tests { #[test] fn test_log_normal() { let lnorm = LogNormal::new(10.0, 10.0); - let mut rng = ::test::rng(211); + let mut rng = crate::test::rng(211); for _ in 0..1000 { lnorm.sample(&mut rng); } diff --git a/rand/src/distributions/other.rs b/rand/src/distributions/other.rs index 2295f79..6ec0473 100644 --- a/rand/src/distributions/other.rs +++ b/rand/src/distributions/other.rs @@ -11,8 +11,8 @@ use core::char; use core::num::Wrapping; -use {Rng}; -use distributions::{Distribution, Standard, Uniform}; +use crate::Rng; +use crate::distributions::{Distribution, Standard, Uniform}; // ----- Sampling distributions ----- @@ -116,6 +116,7 @@ macro_rules! tuple_impl { } impl Distribution<()> for Standard { + #[allow(clippy::unused_unit)] #[inline] fn sample<R: Rng + ?Sized>(&self, _: &mut R) -> () { () } } @@ -176,13 +177,13 @@ impl<T> Distribution<Wrapping<T>> for Standard where Standard: Distribution<T> { #[cfg(test)] mod tests { - use {Rng, RngCore, Standard}; - use distributions::Alphanumeric; + use crate::{Rng, RngCore, Standard}; + use crate::distributions::Alphanumeric; #[cfg(all(not(feature="std"), feature="alloc"))] use alloc::string::String; #[test] fn test_misc() { - let rng: &mut RngCore = &mut ::test::rng(820); + let rng: &mut dyn RngCore = &mut crate::test::rng(820); rng.sample::<char, _>(Standard); rng.sample::<bool, _>(Standard); @@ -192,7 +193,7 @@ mod tests { #[test] fn test_chars() { use core::iter; - let mut rng = ::test::rng(805); + let mut rng = crate::test::rng(805); // Test by generating a relatively large number of chars, so we also // take the rejection sampling path. @@ -203,7 +204,7 @@ mod tests { #[test] fn test_alphanumeric() { - let mut rng = ::test::rng(806); + let mut rng = crate::test::rng(806); // Test by generating a relatively large number of chars, so we also // take the rejection sampling path. diff --git a/rand/src/distributions/pareto.rs b/rand/src/distributions/pareto.rs index 744a157..edc9122 100644 --- a/rand/src/distributions/pareto.rs +++ b/rand/src/distributions/pareto.rs @@ -7,20 +7,13 @@ // except according to those terms. //! The Pareto distribution. +#![allow(deprecated)] -use Rng; -use distributions::{Distribution, OpenClosed01}; +use crate::Rng; +use crate::distributions::{Distribution, OpenClosed01}; /// Samples floating-point numbers according to the Pareto distribution -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::Pareto; -/// -/// let val: f64 = SmallRng::from_entropy().sample(Pareto::new(1., 2.)); -/// println!("{}", val); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Pareto { scale: f64, @@ -51,7 +44,7 @@ impl Distribution<f64> for Pareto { #[cfg(test)] mod tests { - use distributions::Distribution; + use crate::distributions::Distribution; use super::Pareto; #[test] @@ -65,7 +58,7 @@ mod tests { let scale = 1.0; let shape = 2.0; let d = Pareto::new(scale, shape); - let mut rng = ::test::rng(1); + let mut rng = crate::test::rng(1); for _ in 0..1000 { let r = d.sample(&mut rng); assert!(r >= scale); diff --git a/rand/src/distributions/poisson.rs b/rand/src/distributions/poisson.rs index 1244caa..9fd6e99 100644 --- a/rand/src/distributions/poisson.rs +++ b/rand/src/distributions/poisson.rs @@ -8,25 +8,17 @@ // except according to those terms. //! The Poisson distribution. +#![allow(deprecated)] -use Rng; -use distributions::{Distribution, Cauchy}; -use distributions::utils::log_gamma; +use crate::Rng; +use crate::distributions::{Distribution, Cauchy}; +use crate::distributions::utils::log_gamma; /// The Poisson distribution `Poisson(lambda)`. /// /// This distribution has a density function: /// `f(k) = lambda^k * exp(-lambda) / k!` for `k >= 0`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Poisson, Distribution}; -/// -/// let poi = Poisson::new(2.0); -/// let v = poi.sample(&mut rand::thread_rng()); -/// println!("{} is from a Poisson(2) distribution", v); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Poisson { lambda: f64, @@ -113,13 +105,14 @@ impl Distribution<u64> for Poisson { #[cfg(test)] mod test { - use distributions::Distribution; + use crate::distributions::Distribution; use super::Poisson; #[test] + #[cfg(not(miri))] // Miri is too slow fn test_poisson_10() { let poisson = Poisson::new(10.0); - let mut rng = ::test::rng(123); + let mut rng = crate::test::rng(123); let mut sum = 0; for _ in 0..1000 { sum += poisson.sample(&mut rng); @@ -130,10 +123,11 @@ mod test { } #[test] + #[cfg(not(miri))] // Miri doesn't support transcendental functions fn test_poisson_15() { // Take the 'high expected values' path let poisson = Poisson::new(15.0); - let mut rng = ::test::rng(123); + let mut rng = crate::test::rng(123); let mut sum = 0; for _ in 0..1000 { sum += poisson.sample(&mut rng); diff --git a/rand/src/distributions/triangular.rs b/rand/src/distributions/triangular.rs index a6eef5c..3e8f8b0 100644 --- a/rand/src/distributions/triangular.rs +++ b/rand/src/distributions/triangular.rs @@ -5,22 +5,15 @@ // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your // option. This file may not be copied, modified, or distributed // except according to those terms. + //! The triangular distribution. +#![allow(deprecated)] -use Rng; -use distributions::{Distribution, Standard}; +use crate::Rng; +use crate::distributions::{Distribution, Standard}; /// The triangular distribution. -/// -/// # Example -/// -/// ```rust -/// use rand::distributions::{Triangular, Distribution}; -/// -/// let d = Triangular::new(0., 5., 2.5); -/// let v = d.sample(&mut rand::thread_rng()); -/// println!("{} is from a triangular distribution", v); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Triangular { min: f64, @@ -61,7 +54,7 @@ impl Distribution<f64> for Triangular { #[cfg(test)] mod test { - use distributions::Distribution; + use crate::distributions::Distribution; use super::Triangular; #[test] @@ -78,7 +71,7 @@ mod test { #[test] fn test_sample() { let norm = Triangular::new(0., 1., 0.5); - let mut rng = ::test::rng(1); + let mut rng = crate::test::rng(1); for _ in 0..1000 { norm.sample(&mut rng); } diff --git a/rand/src/distributions/uniform.rs b/rand/src/distributions/uniform.rs index ceed77d..8c90f4e 100644 --- a/rand/src/distributions/uniform.rs +++ b/rand/src/distributions/uniform.rs @@ -15,13 +15,13 @@ //! [`Uniform`]. //! //! This distribution is provided with support for several primitive types -//! (all integer and floating-point types) as well as `std::time::Duration`, +//! (all integer and floating-point types) as well as [`std::time::Duration`], //! and supports extension to user-defined types via a type-specific *back-end* //! implementation. //! //! The types [`UniformInt`], [`UniformFloat`] and [`UniformDuration`] are the //! back-ends supporting sampling from primitive integer and floating-point -//! ranges as well as from `std::time::Duration`; these types do not normally +//! ranges as well as from [`std::time::Duration`]; these types do not normally //! need to be used directly (unless implementing a derived back-end). //! //! # Example usage @@ -100,28 +100,26 @@ //! let x = uniform.sample(&mut thread_rng()); //! ``` //! -//! [`Uniform`]: struct.Uniform.html -//! [`Rng::gen_range`]: ../../trait.Rng.html#method.gen_range -//! [`SampleUniform`]: trait.SampleUniform.html -//! [`UniformSampler`]: trait.UniformSampler.html -//! [`UniformInt`]: struct.UniformInt.html -//! [`UniformFloat`]: struct.UniformFloat.html -//! [`UniformDuration`]: struct.UniformDuration.html -//! [`SampleBorrow::borrow`]: trait.SampleBorrow.html#method.borrow +//! [`SampleUniform`]: crate::distributions::uniform::SampleUniform +//! [`UniformSampler`]: crate::distributions::uniform::UniformSampler +//! [`UniformInt`]: crate::distributions::uniform::UniformInt +//! [`UniformFloat`]: crate::distributions::uniform::UniformFloat +//! [`UniformDuration`]: crate::distributions::uniform::UniformDuration +//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow #[cfg(feature = "std")] use std::time::Duration; -#[cfg(all(not(feature = "std"), rustc_1_25))] +#[cfg(not(feature = "std"))] use core::time::Duration; -use Rng; -use distributions::Distribution; -use distributions::float::IntoFloat; -use distributions::utils::{WideningMultiply, FloatSIMDUtils, FloatAsSIMD, BoolAsSIMD}; +use crate::Rng; +use crate::distributions::Distribution; +use crate::distributions::float::IntoFloat; +use crate::distributions::utils::{WideningMultiply, FloatSIMDUtils, FloatAsSIMD, BoolAsSIMD}; #[cfg(not(feature = "std"))] #[allow(unused_imports)] // rustc doesn't detect that this is actually used -use distributions::utils::Float; +use crate::distributions::utils::Float; #[cfg(feature="simd_support")] @@ -165,10 +163,8 @@ use packed_simd::*; /// } /// ``` /// -/// [`Uniform::new`]: struct.Uniform.html#method.new -/// [`Uniform::new_inclusive`]: struct.Uniform.html#method.new_inclusive -/// [`new`]: struct.Uniform.html#method.new -/// [`new_inclusive`]: struct.Uniform.html#method.new_inclusive +/// [`new`]: Uniform::new +/// [`new_inclusive`]: Uniform::new_inclusive #[derive(Clone, Copy, Debug)] pub struct Uniform<X: SampleUniform> { inner: X::Sampler, @@ -206,9 +202,7 @@ impl<X: SampleUniform> Distribution<X> for Uniform<X> { /// See the [module documentation] on how to implement [`Uniform`] range /// sampling for a custom type. /// -/// [`UniformSampler`]: trait.UniformSampler.html -/// [module documentation]: index.html -/// [`Uniform`]: struct.Uniform.html +/// [module documentation]: crate::distributions::uniform pub trait SampleUniform: Sized { /// The `UniformSampler` implementation supporting type `X`. type Sampler: UniformSampler<X = Self>; @@ -222,9 +216,8 @@ pub trait SampleUniform: Sized { /// Implementation of [`sample_single`] is optional, and is only useful when /// the implementation can be faster than `Self::new(low, high).sample(rng)`. /// -/// [module documentation]: index.html -/// [`Uniform`]: struct.Uniform.html -/// [`sample_single`]: trait.UniformSampler.html#method.sample_single +/// [module documentation]: crate::distributions::uniform +/// [`sample_single`]: UniformSampler::sample_single pub trait UniformSampler: Sized { /// The type sampled by this implementation. type X; @@ -253,14 +246,11 @@ pub trait UniformSampler: Sized { /// Sample a single value uniformly from a range with inclusive lower bound /// and exclusive upper bound `[low, high)`. /// - /// Usually users should not call this directly but instead use - /// `Uniform::sample_single`, which asserts that `low < high` before calling - /// this. - /// - /// Via this method, implementations can provide a method optimized for - /// sampling only a single value from the specified range. The default - /// implementation simply calls `UniformSampler::new` then `sample` on the - /// result. + /// By default this is implemented using + /// `UniformSampler::new(low, high).sample(rng)`. However, for some types + /// more optimal implementations for single usage may be provided via this + /// method (which is the case for integers and floats). + /// Results may not be identical. fn sample_single<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R) -> Self::X where B1: SampleBorrow<Self::X> + Sized, @@ -277,7 +267,6 @@ impl<X: SampleUniform> From<::core::ops::Range<X>> for Uniform<X> { } } -#[cfg(rustc_1_27)] impl<X: SampleUniform> From<::core::ops::RangeInclusive<X>> for Uniform<X> { fn from(r: ::core::ops::RangeInclusive<X>) -> Uniform<X> { Uniform::new_inclusive(r.start(), r.end()) @@ -288,11 +277,11 @@ impl<X: SampleUniform> From<::core::ops::RangeInclusive<X>> for Uniform<X> { /// only for SampleUniform and references to SampleUniform in /// order to resolve ambiguity issues. /// -/// [`Borrow`]: https://doc.rust-lang.org/std/borrow/trait.Borrow.html +/// [`Borrow`]: std::borrow::Borrow pub trait SampleBorrow<Borrowed> { /// Immutably borrows from an owned value. See [`Borrow::borrow`] /// - /// [`Borrow::borrow`]: https://doc.rust-lang.org/std/borrow/trait.Borrow.html#tymethod.borrow + /// [`Borrow::borrow`]: std::borrow::Borrow::borrow fn borrow(&self) -> &Borrowed; } impl<Borrowed> SampleBorrow<Borrowed> for Borrowed where Borrowed: SampleUniform { @@ -316,48 +305,42 @@ impl<'a, Borrowed> SampleBorrow<Borrowed> for &'a Borrowed where Borrowed: Sampl /// /// # Implementation notes /// +/// For simplicity, we use the same generic struct `UniformInt<X>` for all +/// integer types `X`. This gives us only one field type, `X`; to store unsigned +/// values of this size, we take use the fact that these conversions are no-ops. +/// /// For a closed range, the number of possible numbers we should generate is -/// `range = (high - low + 1)`. It is not possible to end up with a uniform -/// distribution if we map *all* the random integers that can be generated to -/// this range. We have to map integers from a `zone` that is a multiple of the -/// range. The rest of the integers, that cause a bias, are rejected. +/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of +/// our sample space, `zone`, is a multiple of `range`; other values must be +/// rejected (by replacing with a new random sample). /// -/// The problem with `range` is that to cover the full range of the type, it has -/// to store `unsigned_max + 1`, which can't be represented. But if the range -/// covers the full range of the type, no modulus is needed. A range of size 0 -/// can't exist, so we use that to represent this special case. Wrapping -/// arithmetic even makes representing `unsigned_max + 1` as 0 simple. +/// As a special case, we use `range = 0` to represent the full range of the +/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`). /// -/// We don't calculate `zone` directly, but first calculate the number of -/// integers to reject. To handle `unsigned_max + 1` not fitting in the type, -/// we use: -/// `ints_to_reject = (unsigned_max + 1) % range;` -/// `ints_to_reject = (unsigned_max - range + 1) % range;` +/// The optimum `zone` is the largest product of `range` which fits in our +/// (unsigned) target type. We calculate this by calculating how many numbers we +/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large) +/// product of `range` will suffice, thus in `sample_single` we multiply by a +/// power of 2 via bit-shifting (faster but may cause more rejections). /// -/// The smallest integer PRNGs generate is `u32`. That is why for small integer -/// sizes (`i8`/`u8` and `i16`/`u16`) there is an optimization: don't pick the -/// largest zone that can fit in the small type, but pick the largest zone that -/// can fit in an `u32`. `ints_to_reject` is always less than half the size of -/// the small integer. This means the first bit of `zone` is always 1, and so -/// are all the other preceding bits of a larger integer. The easiest way to -/// grow the `zone` for the larger type is to simply sign extend it. +/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we +/// use `u32` for our `zone` and samples (because it's not slower and because +/// it reduces the chance of having to reject a sample). In this case we cannot +/// store `zone` in the target type since it is too large, however we know +/// `ints_to_reject < range <= $unsigned::MAX`. /// /// An alternative to using a modulus is widening multiply: After a widening /// multiply by `range`, the result is in the high word. Then comparing the low /// word against `zone` makes sure our distribution is uniform. -/// -/// [`UniformSampler`]: trait.UniformSampler.html -/// [`Uniform`]: struct.Uniform.html #[derive(Clone, Copy, Debug)] pub struct UniformInt<X> { low: X, range: X, - zone: X, + z: X, // either ints_to_reject or zone depending on implementation } macro_rules! uniform_int_impl { - ($ty:ty, $signed:ty, $unsigned:ident, - $i_large:ident, $u_large:ident) => { + ($ty:ty, $unsigned:ident, $u_large:ident) => { impl SampleUniform for $ty { type Sampler = UniformInt<$ty>; } @@ -392,34 +375,30 @@ macro_rules! uniform_int_impl { let high = *high_b.borrow(); assert!(low <= high, "Uniform::new_inclusive called with `low > high`"); - let unsigned_max = ::core::$unsigned::MAX; + let unsigned_max = ::core::$u_large::MAX; let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned; let ints_to_reject = if range > 0 { + let range = $u_large::from(range); (unsigned_max - range + 1) % range } else { 0 }; - let zone = unsigned_max - ints_to_reject; UniformInt { low: low, // These are really $unsigned values, but store as $ty: range: range as $ty, - zone: zone as $ty + z: ints_to_reject as $unsigned as $ty } } fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { let range = self.range as $unsigned as $u_large; if range > 0 { - // Grow `zone` to fit a type of at least 32 bits, by - // sign-extending it (the first bit is always 1, so are all - // the preceding bits of the larger type). - // For types that already have the right size, all the - // casting is a no-op. - let zone = self.zone as $signed as $i_large as $u_large; + let unsigned_max = ::core::$u_large::MAX; + let zone = unsigned_max - (self.z as $unsigned as $u_large); loop { let v: $u_large = rng.gen(); let (hi, lo) = v.wmul(range); @@ -441,7 +420,7 @@ macro_rules! uniform_int_impl { let low = *low_b.borrow(); let high = *high_b.borrow(); assert!(low < high, - "Uniform::sample_single called with low >= high"); + "UniformSampler::sample_single: low >= high"); let range = high.wrapping_sub(low) as $unsigned as $u_large; let zone = if ::core::$unsigned::MAX <= ::core::u16::MAX as $unsigned { @@ -469,20 +448,20 @@ macro_rules! uniform_int_impl { } } -uniform_int_impl! { i8, i8, u8, i32, u32 } -uniform_int_impl! { i16, i16, u16, i32, u32 } -uniform_int_impl! { i32, i32, u32, i32, u32 } -uniform_int_impl! { i64, i64, u64, i64, u64 } -#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] -uniform_int_impl! { i128, i128, u128, u128, u128 } -uniform_int_impl! { isize, isize, usize, isize, usize } -uniform_int_impl! { u8, i8, u8, i32, u32 } -uniform_int_impl! { u16, i16, u16, i32, u32 } -uniform_int_impl! { u32, i32, u32, i32, u32 } -uniform_int_impl! { u64, i64, u64, i64, u64 } -uniform_int_impl! { usize, isize, usize, isize, usize } -#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] -uniform_int_impl! { u128, u128, u128, i128, u128 } +uniform_int_impl! { i8, u8, u32 } +uniform_int_impl! { i16, u16, u32 } +uniform_int_impl! { i32, u32, u32 } +uniform_int_impl! { i64, u64, u64 } +#[cfg(not(target_os = "emscripten"))] +uniform_int_impl! { i128, u128, u128 } +uniform_int_impl! { isize, usize, usize } +uniform_int_impl! { u8, u8, u32 } +uniform_int_impl! { u16, u16, u32 } +uniform_int_impl! { u32, u32, u32 } +uniform_int_impl! { u64, u64, u64 } +uniform_int_impl! { usize, usize, usize } +#[cfg(not(target_os = "emscripten"))] +uniform_int_impl! { u128, u128, u128 } #[cfg(all(feature = "simd_support", feature = "nightly"))] macro_rules! uniform_simd_int_impl { @@ -544,13 +523,13 @@ macro_rules! uniform_simd_int_impl { low: low, // These are really $unsigned values, but store as $ty: range: range.cast(), - zone: zone.cast(), + z: zone.cast(), } } fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X { let range: $unsigned = self.range.cast(); - let zone: $unsigned = self.zone.cast(); + let zone: $unsigned = self.z.cast(); // This might seem very slow, generating a whole new // SIMD vector for every sample rejection. For most uses @@ -646,11 +625,9 @@ uniform_simd_int_impl! { /// multiply and addition. Values produced this way have what equals 22 bits of /// random digits for an `f32`, and 52 for an `f64`. /// -/// [`UniformSampler`]: trait.UniformSampler.html -/// [`new`]: trait.UniformSampler.html#tymethod.new -/// [`new_inclusive`]: trait.UniformSampler.html#tymethod.new_inclusive -/// [`Uniform`]: struct.Uniform.html -/// [`Standard`]: ../struct.Standard.html +/// [`new`]: UniformSampler::new +/// [`new_inclusive`]: UniformSampler::new_inclusive +/// [`Standard`]: crate::distributions::Standard #[derive(Clone, Copy, Debug)] pub struct UniformFloat<X> { low: X, @@ -748,7 +725,7 @@ macro_rules! uniform_float_impl { let low = *low_b.borrow(); let high = *high_b.borrow(); assert!(low.all_lt(high), - "Uniform::sample_single called with low >= high"); + "UniformSampler::sample_single: low >= high"); let mut scale = high - low; loop { @@ -799,7 +776,7 @@ macro_rules! uniform_float_impl { let mask = !scale.finite_mask(); if mask.any() { assert!(low.all_finite() && high.all_finite(), - "Uniform::sample_single called with non-finite boundaries"); + "Uniform::sample_single: low and high must be finite"); scale = scale.decrease_masked(mask); } } @@ -833,17 +810,12 @@ uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 } /// /// Unless you are implementing [`UniformSampler`] for your own types, this type /// should not be used directly, use [`Uniform`] instead. -/// -/// [`UniformSampler`]: trait.UniformSampler.html -/// [`Uniform`]: struct.Uniform.html -#[cfg(any(feature = "std", rustc_1_25))] #[derive(Clone, Copy, Debug)] pub struct UniformDuration { mode: UniformDurationMode, offset: u32, } -#[cfg(any(feature = "std", rustc_1_25))] #[derive(Debug, Copy, Clone)] enum UniformDurationMode { Small { @@ -860,12 +832,10 @@ enum UniformDurationMode { } } -#[cfg(any(feature = "std", rustc_1_25))] impl SampleUniform for Duration { type Sampler = UniformDuration; } -#[cfg(any(feature = "std", rustc_1_25))] impl UniformSampler for UniformDuration { type X = Duration; @@ -895,8 +865,8 @@ impl UniformSampler for UniformDuration { let mut high_n = high.subsec_nanos(); if high_n < low_n { - high_s = high_s - 1; - high_n = high_n + 1_000_000_000; + high_s -= 1; + high_n += 1_000_000_000; } let mode = if low_s == high_s { @@ -907,10 +877,10 @@ impl UniformSampler for UniformDuration { } else { let max = high_s .checked_mul(1_000_000_000) - .and_then(|n| n.checked_add(high_n as u64)); + .and_then(|n| n.checked_add(u64::from(high_n))); if let Some(higher_bound) = max { - let lower_bound = low_s * 1_000_000_000 + low_n as u64; + let lower_bound = low_s * 1_000_000_000 + u64::from(low_n); UniformDurationMode::Medium { nanos: Uniform::new_inclusive(lower_bound, higher_bound), } @@ -959,10 +929,10 @@ impl UniformSampler for UniformDuration { #[cfg(test)] mod tests { - use Rng; - use rngs::mock::StepRng; - use distributions::uniform::Uniform; - use distributions::utils::FloatAsSIMD; + use crate::Rng; + use crate::rngs::mock::StepRng; + use crate::distributions::uniform::Uniform; + use crate::distributions::utils::FloatAsSIMD; #[cfg(feature="simd_support")] use packed_simd::*; #[should_panic] @@ -973,7 +943,7 @@ mod tests { #[test] fn test_uniform_good_limits_equal_int() { - let mut rng = ::test::rng(804); + let mut rng = crate::test::rng(804); let dist = Uniform::new_inclusive(10, 10); for _ in 0..20 { assert_eq!(rng.sample(dist), 10); @@ -987,13 +957,14 @@ mod tests { } #[test] + #[cfg(not(miri))] // Miri is too slow fn test_integers() { use core::{i8, i16, i32, i64, isize}; use core::{u8, u16, u32, u64, usize}; - #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[cfg(not(target_os = "emscripten"))] use core::{i128, u128}; - let mut rng = ::test::rng(251); + let mut rng = crate::test::rng(251); macro_rules! t { ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ for &(low, high) in $v.iter() { @@ -1054,7 +1025,7 @@ mod tests { } t!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize); - #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[cfg(not(target_os = "emscripten"))] t!(i128, u128); #[cfg(all(feature = "simd_support", feature = "nightly"))] @@ -1071,8 +1042,9 @@ mod tests { } #[test] + #[cfg(not(miri))] // Miri is too slow fn test_floats() { - let mut rng = ::test::rng(252); + let mut rng = crate::test::rng(252); let mut zero_rng = StepRng::new(0, 0); let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0); macro_rules! t { @@ -1155,11 +1127,12 @@ mod tests { #[cfg(all(feature="std", not(target_arch = "wasm32"), not(target_arch = "asmjs")))] + #[cfg(not(miri))] // Miri does not support catching panics fn test_float_assertions() { use std::panic::catch_unwind; use super::SampleUniform; fn range<T: SampleUniform>(low: T, high: T) { - let mut rng = ::test::rng(253); + let mut rng = crate::test::rng(253); rng.gen_range(low, high); } @@ -1209,14 +1182,14 @@ mod tests { #[test] - #[cfg(any(feature = "std", rustc_1_25))] + #[cfg(not(miri))] // Miri is too slow fn test_durations() { #[cfg(feature = "std")] use std::time::Duration; - #[cfg(all(not(feature = "std"), rustc_1_25))] + #[cfg(not(feature = "std"))] use core::time::Duration; - let mut rng = ::test::rng(253); + let mut rng = crate::test::rng(253); let v = &[(Duration::new(10, 50000), Duration::new(100, 1234)), (Duration::new(0, 100), Duration::new(1, 50)), @@ -1232,7 +1205,7 @@ mod tests { #[test] fn test_custom_uniform() { - use distributions::uniform::{UniformSampler, UniformFloat, SampleUniform, SampleBorrow}; + use crate::distributions::uniform::{UniformSampler, UniformFloat, SampleUniform, SampleBorrow}; #[derive(Clone, Copy, PartialEq, PartialOrd)] struct MyF32 { x: f32, @@ -1267,7 +1240,7 @@ mod tests { let (low, high) = (MyF32{ x: 17.0f32 }, MyF32{ x: 22.0f32 }); let uniform = Uniform::new(low, high); - let mut rng = ::test::rng(804); + let mut rng = crate::test::rng(804); for _ in 0..100 { let x: MyF32 = rng.sample(uniform); assert!(low <= x && x < high); @@ -1284,7 +1257,6 @@ mod tests { assert_eq!(r.inner.scale, 5.0); } - #[cfg(rustc_1_27)] #[test] fn test_uniform_from_std_range_inclusive() { let r = Uniform::from(2u32..=6); diff --git a/rand/src/distributions/unit_circle.rs b/rand/src/distributions/unit_circle.rs index 01ab76a..56e75b6 100644 --- a/rand/src/distributions/unit_circle.rs +++ b/rand/src/distributions/unit_circle.rs @@ -6,28 +6,21 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use Rng; -use distributions::{Distribution, Uniform}; +#![allow(deprecated)] +#![allow(clippy::all)] + +use crate::Rng; +use crate::distributions::{Distribution, Uniform}; /// Samples uniformly from the edge of the unit circle in two dimensions. /// /// Implemented via a method by von Neumann[^1]. /// -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{UnitCircle, Distribution}; -/// -/// let circle = UnitCircle::new(); -/// let v = circle.sample(&mut rand::thread_rng()); -/// println!("{:?} is from the unit circle.", v) -/// ``` -/// /// [^1]: von Neumann, J. (1951) [*Various Techniques Used in Connection with /// Random Digits.*](https://mcnp.lanl.gov/pdf_files/nbs_vonneumann.pdf) /// NBS Appl. Math. Ser., No. 12. Washington, DC: U.S. Government Printing /// Office, pp. 36-38. +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct UnitCircle; @@ -61,7 +54,7 @@ impl Distribution<[f64; 2]> for UnitCircle { #[cfg(test)] mod tests { - use distributions::Distribution; + use crate::distributions::Distribution; use super::UnitCircle; /// Assert that two numbers are almost equal to each other. @@ -82,7 +75,7 @@ mod tests { #[test] fn norm() { - let mut rng = ::test::rng(1); + let mut rng = crate::test::rng(1); let dist = UnitCircle::new(); for _ in 0..1000 { let x = dist.sample(&mut rng); @@ -92,10 +85,17 @@ mod tests { #[test] fn value_stability() { - let mut rng = ::test::rng(2); - let dist = UnitCircle::new(); - assert_eq!(dist.sample(&mut rng), [-0.8032118336637037, 0.5956935036263119]); - assert_eq!(dist.sample(&mut rng), [-0.4742919588505423, -0.880367615130018]); - assert_eq!(dist.sample(&mut rng), [0.9297328981467168, 0.368234623716601]); + let mut rng = crate::test::rng(2); + let expected = [ + [-0.9965658683520504, -0.08280380447614634], + [-0.9790853270389644, -0.20345004884984505], + [-0.8449189758898707, 0.5348943112253227], + ]; + let samples = [ + UnitCircle.sample(&mut rng), + UnitCircle.sample(&mut rng), + UnitCircle.sample(&mut rng), + ]; + assert_eq!(samples, expected); } } diff --git a/rand/src/distributions/unit_sphere.rs b/rand/src/distributions/unit_sphere.rs index 37de88b..188f48c 100644 --- a/rand/src/distributions/unit_sphere.rs +++ b/rand/src/distributions/unit_sphere.rs @@ -6,27 +6,20 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use Rng; -use distributions::{Distribution, Uniform}; +#![allow(deprecated)] +#![allow(clippy::all)] + +use crate::Rng; +use crate::distributions::{Distribution, Uniform}; /// Samples uniformly from the surface of the unit sphere in three dimensions. /// /// Implemented via a method by Marsaglia[^1]. /// -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{UnitSphereSurface, Distribution}; -/// -/// let sphere = UnitSphereSurface::new(); -/// let v = sphere.sample(&mut rand::thread_rng()); -/// println!("{:?} is from the unit sphere surface.", v) -/// ``` -/// /// [^1]: Marsaglia, George (1972). [*Choosing a Point from the Surface of a /// Sphere.*](https://doi.org/10.1214/aoms/1177692644) /// Ann. Math. Statist. 43, no. 2, 645--646. +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct UnitSphereSurface; @@ -56,7 +49,7 @@ impl Distribution<[f64; 3]> for UnitSphereSurface { #[cfg(test)] mod tests { - use distributions::Distribution; + use crate::distributions::Distribution; use super::UnitSphereSurface; /// Assert that two numbers are almost equal to each other. @@ -77,7 +70,7 @@ mod tests { #[test] fn norm() { - let mut rng = ::test::rng(1); + let mut rng = crate::test::rng(1); let dist = UnitSphereSurface::new(); for _ in 0..1000 { let x = dist.sample(&mut rng); @@ -87,13 +80,17 @@ mod tests { #[test] fn value_stability() { - let mut rng = ::test::rng(2); - let dist = UnitSphereSurface::new(); - assert_eq!(dist.sample(&mut rng), - [-0.24950027180862533, -0.7552572587896719, 0.6060825747478084]); - assert_eq!(dist.sample(&mut rng), - [0.47604534507233487, -0.797200864987207, -0.3712837328763685]); - assert_eq!(dist.sample(&mut rng), - [0.9795722330927367, 0.18692349236651176, 0.07414747571708524]); + let mut rng = crate::test::rng(2); + let expected = [ + [0.03247542860231647, -0.7830477442152738, 0.6211131755296027], + [-0.09978440840914075, 0.9706650829833128, -0.21875184231323952], + [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], + ]; + let samples = [ + UnitSphereSurface.sample(&mut rng), + UnitSphereSurface.sample(&mut rng), + UnitSphereSurface.sample(&mut rng), + ]; + assert_eq!(samples, expected); } } diff --git a/rand/src/distributions/utils.rs b/rand/src/distributions/utils.rs index d4d3642..3af4e86 100644 --- a/rand/src/distributions/utils.rs +++ b/rand/src/distributions/utils.rs @@ -11,9 +11,9 @@ #[cfg(feature="simd_support")] use packed_simd::*; #[cfg(feature="std")] -use distributions::ziggurat_tables; +use crate::distributions::ziggurat_tables; #[cfg(feature="std")] -use Rng; +use crate::Rng; pub trait WideningMultiply<RHS = Self> { @@ -61,7 +61,7 @@ macro_rules! wmul_impl { wmul_impl! { u8, u16, 8 } wmul_impl! { u16, u32, 16 } wmul_impl! { u32, u64, 32 } -#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +#[cfg(not(target_os = "emscripten"))] wmul_impl! { u64, u128, 64 } // This code is a translation of the __mulddi3 function in LLVM's @@ -125,9 +125,9 @@ macro_rules! wmul_impl_large { )+ }; } -#[cfg(not(all(rustc_1_26, not(target_os = "emscripten"))))] +#[cfg(target_os = "emscripten")] wmul_impl_large! { u64, 32 } -#[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +#[cfg(not(target_os = "emscripten"))] wmul_impl_large! { u128, 64 } macro_rules! wmul_impl_usize { @@ -249,13 +249,9 @@ pub(crate) trait FloatSIMDUtils { /// Implement functions available in std builds but missing from core primitives #[cfg(not(std))] pub(crate) trait Float : Sized { - type Bits; - fn is_nan(self) -> bool; fn is_infinite(self) -> bool; fn is_finite(self) -> bool; - fn to_bits(self) -> Self::Bits; - fn from_bits(v: Self::Bits) -> Self; } /// Implement functions on f32/f64 to give them APIs similar to SIMD types @@ -289,8 +285,6 @@ macro_rules! scalar_float_impl { ($ty:ident, $uty:ident) => { #[cfg(not(std))] impl Float for $ty { - type Bits = $uty; - #[inline] fn is_nan(self) -> bool { self != self @@ -305,17 +299,6 @@ macro_rules! scalar_float_impl { fn is_finite(self) -> bool { !(self.is_nan() || self.is_infinite()) } - - #[inline] - fn to_bits(self) -> Self::Bits { - unsafe { ::core::mem::transmute(self) } - } - - #[inline] - fn from_bits(v: Self::Bits) -> Self { - // It turns out the safety issues with sNaN were overblown! Hooray! - unsafe { ::core::mem::transmute(v) } - } } impl FloatSIMDUtils for $ty { @@ -383,6 +366,7 @@ macro_rules! simd_impl { <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask)) } type UInt = $uty; + #[inline] fn cast_from_int(i: Self::UInt) -> Self { i.cast() } } } @@ -464,7 +448,7 @@ pub fn ziggurat<R: Rng + ?Sized, P, Z>( mut pdf: P, mut zero_case: Z) -> f64 where P: FnMut(f64) -> f64, Z: FnMut(&mut R, f64) -> f64 { - use distributions::float::IntoFloat; + use crate::distributions::float::IntoFloat; loop { // As an optimisation we re-implement the conversion to a f64. // From the remaining 12 most significant bits we use 8 to construct `i`. diff --git a/rand/src/distributions/weibull.rs b/rand/src/distributions/weibull.rs index 5fbe10a..483714f 100644 --- a/rand/src/distributions/weibull.rs +++ b/rand/src/distributions/weibull.rs @@ -7,20 +7,13 @@ // except according to those terms. //! The Weibull distribution. +#![allow(deprecated)] -use Rng; -use distributions::{Distribution, OpenClosed01}; +use crate::Rng; +use crate::distributions::{Distribution, OpenClosed01}; /// Samples floating-point numbers according to the Weibull distribution -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::Weibull; -/// -/// let val: f64 = SmallRng::from_entropy().sample(Weibull::new(1., 10.)); -/// println!("{}", val); -/// ``` +#[deprecated(since="0.7.0", note="moved to rand_distr crate")] #[derive(Clone, Copy, Debug)] pub struct Weibull { inv_shape: f64, @@ -48,7 +41,7 @@ impl Distribution<f64> for Weibull { #[cfg(test)] mod tests { - use distributions::Distribution; + use crate::distributions::Distribution; use super::Weibull; #[test] @@ -62,7 +55,7 @@ mod tests { let scale = 1.0; let shape = 2.0; let d = Weibull::new(scale, shape); - let mut rng = ::test::rng(1); + let mut rng = crate::test::rng(1); for _ in 0..1000 { let r = d.sample(&mut rng); assert!(r >= 0.); diff --git a/rand/src/distributions/weighted/alias_method.rs b/rand/src/distributions/weighted/alias_method.rs new file mode 100644 index 0000000..bdd4ba0 --- /dev/null +++ b/rand/src/distributions/weighted/alias_method.rs @@ -0,0 +1,499 @@ +//! This module contains an implementation of alias method for sampling random +//! indices with probabilities proportional to a collection of weights. + +use super::WeightedError; +#[cfg(not(feature = "std"))] +use crate::alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use crate::alloc::vec; +use core::fmt; +use core::iter::Sum; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; +use crate::distributions::uniform::SampleUniform; +use crate::distributions::Distribution; +use crate::distributions::Uniform; +use crate::Rng; + +/// A distribution using weighted sampling to pick a discretely selected item. +/// +/// Sampling a [`WeightedIndex<W>`] distribution returns the index of a randomly +/// selected element from the vector used to create the [`WeightedIndex<W>`]. +/// The chance of a given element being picked is proportional to the value of +/// the element. The weights can have any type `W` for which a implementation of +/// [`Weight`] exists. +/// +/// # Performance +/// +/// Given that `n` is the number of items in the vector used to create an +/// [`WeightedIndex<W>`], [`WeightedIndex<W>`] will require `O(n)` amount of +/// memory. More specifically it takes up some constant amount of memory plus +/// the vector used to create it and a [`Vec<u32>`] with capacity `n`. +/// +/// Time complexity for the creation of a [`WeightedIndex<W>`] is `O(n)`. +/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call +/// to [`Uniform<W>::sample`]. +/// +/// # Example +/// +/// ``` +/// use rand::distributions::weighted::alias_method::WeightedIndex; +/// use rand::prelude::*; +/// +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 1, 1]; +/// let dist = WeightedIndex::new(weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`WeightedIndex<W>`]: crate::distributions::weighted::alias_method::WeightedIndex +/// [`Weight`]: crate::distributions::weighted::alias_method::Weight +/// [`Vec<u32>`]: Vec +/// [`Uniform<u32>::sample`]: Distribution::sample +/// [`Uniform<W>::sample`]: Distribution::sample +pub struct WeightedIndex<W: Weight> { + aliases: Vec<u32>, + no_alias_odds: Vec<W>, + uniform_index: Uniform<u32>, + uniform_within_weight_sum: Uniform<W>, +} + +impl<W: Weight> WeightedIndex<W> { + /// Creates a new [`WeightedIndex`]. + /// + /// Returns an error if: + /// - The vector is empty. + /// - The vector is longer than `u32::MAX`. + /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / + /// weights.len()`. + /// - The sum of weights is zero. + pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> { + let n = weights.len(); + if n == 0 { + return Err(WeightedError::NoItem); + } else if n > ::core::u32::MAX as usize { + return Err(WeightedError::TooMany); + } + let n = n as u32; + + let max_weight_size = W::try_from_u32_lossy(n) + .map(|n| W::MAX / n) + .unwrap_or(W::ZERO); + if !weights + .iter() + .all(|&w| W::ZERO <= w && w <= max_weight_size) + { + return Err(WeightedError::InvalidWeight); + } + + // The sum of weights will represent 100% of no alias odds. + let weight_sum = Weight::sum(weights.as_slice()); + // Prevent floating point overflow due to rounding errors. + let weight_sum = if weight_sum > W::MAX { + W::MAX + } else { + weight_sum + }; + if weight_sum == W::ZERO { + return Err(WeightedError::AllWeightsZero); + } + + // `weight_sum` would have been zero if `try_from_lossy` causes an error here. + let n_converted = W::try_from_u32_lossy(n).unwrap(); + + let mut no_alias_odds = weights; + for odds in no_alias_odds.iter_mut() { + *odds *= n_converted; + // Prevent floating point overflow due to rounding errors. + *odds = if *odds > W::MAX { W::MAX } else { *odds }; + } + + /// This struct is designed to contain three data structures at once, + /// sharing the same memory. More precisely it contains two linked lists + /// and an alias map, which will be the output of this method. To keep + /// the three data structures from getting in each other's way, it must + /// be ensured that a single index is only ever in one of them at the + /// same time. + struct Aliases { + aliases: Vec<u32>, + smalls_head: u32, + bigs_head: u32, + } + + impl Aliases { + fn new(size: u32) -> Self { + Aliases { + aliases: vec![0; size as usize], + smalls_head: ::core::u32::MAX, + bigs_head: ::core::u32::MAX, + } + } + + fn push_small(&mut self, idx: u32) { + self.aliases[idx as usize] = self.smalls_head; + self.smalls_head = idx; + } + + fn push_big(&mut self, idx: u32) { + self.aliases[idx as usize] = self.bigs_head; + self.bigs_head = idx; + } + + fn pop_small(&mut self) -> u32 { + let popped = self.smalls_head; + self.smalls_head = self.aliases[popped as usize]; + popped + } + + fn pop_big(&mut self) -> u32 { + let popped = self.bigs_head; + self.bigs_head = self.aliases[popped as usize]; + popped + } + + fn smalls_is_empty(&self) -> bool { + self.smalls_head == ::core::u32::MAX + } + + fn bigs_is_empty(&self) -> bool { + self.bigs_head == ::core::u32::MAX + } + + fn set_alias(&mut self, idx: u32, alias: u32) { + self.aliases[idx as usize] = alias; + } + } + + let mut aliases = Aliases::new(n); + + // Split indices into those with small weights and those with big weights. + for (index, &odds) in no_alias_odds.iter().enumerate() { + if odds < weight_sum { + aliases.push_small(index as u32); + } else { + aliases.push_big(index as u32); + } + } + + // Build the alias map by finding an alias with big weight for each index with + // small weight. + while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() { + let s = aliases.pop_small(); + let b = aliases.pop_big(); + + aliases.set_alias(s, b); + no_alias_odds[b as usize] = no_alias_odds[b as usize] + - weight_sum + + no_alias_odds[s as usize]; + + if no_alias_odds[b as usize] < weight_sum { + aliases.push_small(b); + } else { + aliases.push_big(b); + } + } + + // The remaining indices should have no alias odds of about 100%. This is due to + // numeric accuracy. Otherwise they would be exactly 100%. + while !aliases.smalls_is_empty() { + no_alias_odds[aliases.pop_small() as usize] = weight_sum; + } + while !aliases.bigs_is_empty() { + no_alias_odds[aliases.pop_big() as usize] = weight_sum; + } + + // Prepare distributions for sampling. Creating them beforehand improves + // sampling performance. + let uniform_index = Uniform::new(0, n); + let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum); + + Ok(Self { + aliases: aliases.aliases, + no_alias_odds, + uniform_index, + uniform_within_weight_sum, + }) + } +} + +impl<W: Weight> Distribution<usize> for WeightedIndex<W> { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { + let candidate = rng.sample(self.uniform_index); + if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] { + candidate as usize + } else { + self.aliases[candidate as usize] as usize + } + } +} + +impl<W: Weight> fmt::Debug for WeightedIndex<W> +where + W: fmt::Debug, + Uniform<W>: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("WeightedIndex") + .field("aliases", &self.aliases) + .field("no_alias_odds", &self.no_alias_odds) + .field("uniform_index", &self.uniform_index) + .field("uniform_within_weight_sum", &self.uniform_within_weight_sum) + .finish() + } +} + +impl<W: Weight> Clone for WeightedIndex<W> +where + Uniform<W>: Clone, +{ + fn clone(&self) -> Self { + Self { + aliases: self.aliases.clone(), + no_alias_odds: self.no_alias_odds.clone(), + uniform_index: self.uniform_index.clone(), + uniform_within_weight_sum: self.uniform_within_weight_sum.clone(), + } + } +} + +/// Trait that must be implemented for weights, that are used with +/// [`WeightedIndex`]. Currently no guarantees on the correctness of +/// [`WeightedIndex`] are given for custom implementations of this trait. +pub trait Weight: + Sized + + Copy + + SampleUniform + + PartialOrd + + Add<Output = Self> + + AddAssign + + Sub<Output = Self> + + SubAssign + + Mul<Output = Self> + + MulAssign + + Div<Output = Self> + + DivAssign + + Sum +{ + /// Maximum number representable by `Self`. + const MAX: Self; + + /// Element of `Self` equivalent to 0. + const ZERO: Self; + + /// Produce an instance of `Self` from a `u32` value, or return `None` if + /// out of range. Loss of precision (where `Self` is a floating point type) + /// is acceptable. + fn try_from_u32_lossy(n: u32) -> Option<Self>; + + /// Sums all values in slice `values`. + fn sum(values: &[Self]) -> Self { + values.iter().map(|x| *x).sum() + } +} + +macro_rules! impl_weight_for_float { + ($T: ident) => { + impl Weight for $T { + const MAX: Self = ::core::$T::MAX; + const ZERO: Self = 0.0; + + fn try_from_u32_lossy(n: u32) -> Option<Self> { + Some(n as $T) + } + + fn sum(values: &[Self]) -> Self { + pairwise_sum(values) + } + } + }; +} + +/// In comparison to naive accumulation, the pairwise sum algorithm reduces +/// rounding errors when there are many floating point values. +fn pairwise_sum<T: Weight>(values: &[T]) -> T { + if values.len() <= 32 { + values.iter().map(|x| *x).sum() + } else { + let mid = values.len() / 2; + let (a, b) = values.split_at(mid); + pairwise_sum(a) + pairwise_sum(b) + } +} + +macro_rules! impl_weight_for_int { + ($T: ident) => { + impl Weight for $T { + const MAX: Self = ::core::$T::MAX; + const ZERO: Self = 0; + + fn try_from_u32_lossy(n: u32) -> Option<Self> { + let n_converted = n as Self; + if n_converted >= Self::ZERO && n_converted as u32 == n { + Some(n_converted) + } else { + None + } + } + } + }; +} + +impl_weight_for_float!(f64); +impl_weight_for_float!(f32); +impl_weight_for_int!(usize); +#[cfg(not(target_os = "emscripten"))] +impl_weight_for_int!(u128); +impl_weight_for_int!(u64); +impl_weight_for_int!(u32); +impl_weight_for_int!(u16); +impl_weight_for_int!(u8); +impl_weight_for_int!(isize); +#[cfg(not(target_os = "emscripten"))] +impl_weight_for_int!(i128); +impl_weight_for_int!(i64); +impl_weight_for_int!(i32); +impl_weight_for_int!(i16); +impl_weight_for_int!(i8); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[cfg(not(miri))] // Miri is too slow + fn test_weighted_index_f32() { + test_weighted_index(f32::into); + + // Floating point special cases + assert_eq!( + WeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![-0_f32]).unwrap_err(), + WeightedError::AllWeightsZero + ); + assert_eq!( + WeightedIndex::new(vec![-1_f32]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + #[cfg(not(target_os = "emscripten"))] + #[test] + #[cfg(not(miri))] // Miri is too slow + fn test_weighted_index_u128() { + test_weighted_index(|x: u128| x as f64); + } + + #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] + #[test] + #[cfg(not(miri))] // Miri is too slow + fn test_weighted_index_i128() { + test_weighted_index(|x: i128| x as f64); + + // Signed integer special cases + assert_eq!( + WeightedIndex::new(vec![-1_i128]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + #[test] + #[cfg(not(miri))] // Miri is too slow + fn test_weighted_index_u8() { + test_weighted_index(u8::into); + } + + #[test] + #[cfg(not(miri))] // Miri is too slow + fn test_weighted_index_i8() { + test_weighted_index(i8::into); + + // Signed integer special cases + assert_eq!( + WeightedIndex::new(vec![-1_i8]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + fn test_weighted_index<W: Weight, F: Fn(W) -> f64>(w_to_f64: F) + where + WeightedIndex<W>: fmt::Debug, + { + const NUM_WEIGHTS: u32 = 10; + const ZERO_WEIGHT_INDEX: u32 = 3; + const NUM_SAMPLES: u32 = 15000; + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + + let weights = { + let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize); + let random_weight_distribution = crate::distributions::Uniform::new_inclusive( + W::ZERO, + W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), + ); + for _ in 0..NUM_WEIGHTS { + weights.push(rng.sample(&random_weight_distribution)); + } + weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO; + weights + }; + let weight_sum = weights.iter().map(|w| *w).sum::<W>(); + let expected_counts = weights + .iter() + .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) + .collect::<Vec<f64>>(); + let weight_distribution = WeightedIndex::new(weights).unwrap(); + + let mut counts = vec![0; NUM_WEIGHTS as usize]; + for _ in 0..NUM_SAMPLES { + counts[rng.sample(&weight_distribution)] += 1; + } + + assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0); + for (count, expected_count) in counts.into_iter().zip(expected_counts) { + let difference = (count as f64 - expected_count).abs(); + let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; + assert!(difference <= max_allowed_difference); + } + + assert_eq!( + WeightedIndex::<W>::new(vec![]).unwrap_err(), + WeightedError::NoItem + ); + assert_eq!( + WeightedIndex::new(vec![W::ZERO]).unwrap_err(), + WeightedError::AllWeightsZero + ); + assert_eq!( + WeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), + WeightedError::InvalidWeight + ); + } +} diff --git a/rand/src/distributions/weighted.rs b/rand/src/distributions/weighted/mod.rs index 01c8fe6..2711637 100644 --- a/rand/src/distributions/weighted.rs +++ b/rand/src/distributions/weighted/mod.rs @@ -6,14 +6,26 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use Rng; -use distributions::Distribution; -use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; -use ::core::cmp::PartialOrd; +//! Weighted index sampling +//! +//! This module provides two implementations for sampling indices: +//! +//! * [`WeightedIndex`] allows `O(log N)` sampling +//! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with +//! much greater set-up cost +//! +//! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html + +pub mod alias_method; + +use crate::Rng; +use crate::distributions::Distribution; +use crate::distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; +use core::cmp::PartialOrd; use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. -#[cfg(not(feature="std"))] use alloc::vec::Vec; +#[cfg(not(feature="std"))] use crate::alloc::vec::Vec; /// A distribution using weighted sampling to pick a discretely selected /// item. @@ -40,9 +52,9 @@ use core::fmt; /// `N` is the number of weights. /// /// Sampling from `WeightedIndex` will result in a single call to -/// [`Uniform<X>::sample`], which typically will request a single value from -/// the underlying [`RngCore`], though the exact number depends on the -/// implementaiton of [`Uniform<X>::sample`]. +/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically +/// will request a single value from the underlying [`RngCore`], though the +/// exact number depends on the implementaiton of `Uniform<X>::sample`. /// /// # Example /// @@ -67,12 +79,12 @@ use core::fmt; /// } /// ``` /// -/// [`Uniform<X>`]: struct.Uniform.html -/// [`Uniform<X>::sample`]: struct.Uniform.html#method.sample -/// [`RngCore`]: ../trait.RngCore.html +/// [`Uniform<X>`]: crate::distributions::uniform::Uniform +/// [`RngCore`]: crate::RngCore #[derive(Debug, Clone)] pub struct WeightedIndex<X: SampleUniform + PartialOrd> { cumulative_weights: Vec<X>, + total_weight: X, weight_distribution: X::Sampler, } @@ -84,8 +96,7 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { /// Returns an error if the iterator is empty, if any weight is `< 0`, or /// if its total value is 0. /// - /// [`Distribution`]: trait.Distribution.html - /// [`Uniform<X>`]: struct.Uniform.html + /// [`Uniform<X>`]: crate::distributions::uniform::Uniform pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> where I: IntoIterator, I::Item: SampleBorrow<X>, @@ -100,13 +111,13 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { let zero = <X as Default>::default(); if total_weight < zero { - return Err(WeightedError::NegativeWeight); + return Err(WeightedError::InvalidWeight); } let mut weights = Vec::<X>::with_capacity(iter.size_hint().0); for w in iter { if *w.borrow() < zero { - return Err(WeightedError::NegativeWeight); + return Err(WeightedError::InvalidWeight); } weights.push(total_weight.clone()); total_weight += w.borrow(); @@ -115,9 +126,98 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { if total_weight == zero { return Err(WeightedError::AllWeightsZero); } - let distr = X::Sampler::new(zero, total_weight); + let distr = X::Sampler::new(zero, total_weight.clone()); - Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }) + Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr }) + } + + /// Update a subset of weights, without changing the number of weights. + /// + /// `new_weights` must be sorted by the index. + /// + /// Using this method instead of `new` might be more efficient if only a small number of + /// weights is modified. No allocations are performed, unless the weight type `X` uses + /// allocation internally. + /// + /// In case of error, `self` is not modified. + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> + where X: for<'a> ::core::ops::AddAssign<&'a X> + + for<'a> ::core::ops::SubAssign<&'a X> + + Clone + + Default { + if new_weights.is_empty() { + return Ok(()); + } + + let zero = <X as Default>::default(); + + let mut total_weight = self.total_weight.clone(); + + // Check for errors first, so we don't modify `self` in case something + // goes wrong. + let mut prev_i = None; + for &(i, w) in new_weights { + if let Some(old_i) = prev_i { + if old_i >= i { + return Err(WeightedError::InvalidWeight); + } + } + if *w < zero { + return Err(WeightedError::InvalidWeight); + } + if i >= self.cumulative_weights.len() + 1 { + return Err(WeightedError::TooMany); + } + + let mut old_w = if i < self.cumulative_weights.len() { + self.cumulative_weights[i].clone() + } else { + self.total_weight.clone() + }; + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } + + total_weight -= &old_w; + total_weight += w; + prev_i = Some(i); + } + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + } + + // Update the weights. Because we checked all the preconditions in the + // previous loop, this should never panic. + let mut iter = new_weights.iter(); + + let mut prev_weight = zero.clone(); + let mut next_new_weight = iter.next(); + let &(first_new_index, _) = next_new_weight.unwrap(); + let mut cumulative_weight = if first_new_index > 0 { + self.cumulative_weights[first_new_index - 1].clone() + } else { + zero.clone() + }; + for i in first_new_index..self.cumulative_weights.len() { + match next_new_weight { + Some(&(j, w)) if i == j => { + cumulative_weight += w; + next_new_weight = iter.next(); + }, + _ => { + let mut tmp = self.cumulative_weights[i].clone(); + tmp -= &prev_weight; // We know this is positive. + cumulative_weight += &tmp; + } + } + prev_weight = cumulative_weight.clone(); + core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); + } + + self.total_weight = total_weight; + self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); + + Ok(()) } } @@ -137,8 +237,9 @@ mod test { use super::*; #[test] + #[cfg(not(miri))] // Miri is too slow fn test_weightedindex() { - let mut r = ::test::rng(700); + let mut r = crate::test::rng(700); const N_REPS: u32 = 5000; let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; let total_weight = weights.iter().sum::<u32>() as f32; @@ -186,31 +287,61 @@ mod test { assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem); assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero); - assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight); - assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight); - assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight); + assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::InvalidWeight); + assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight); + assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight); + } + + #[test] + fn test_update_weights() { + let data = [ + (&[10u32, 2, 3, 4][..], + &[(1, &100), (2, &4)][..], // positive change + &[10, 100, 4, 4][..]), + (&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]), + ]; + + for (weights, update, expected_weights) in data.into_iter() { + let total_weight = weights.iter().sum::<u32>(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(update).unwrap(); + let expected_total_weight = expected_weights.iter().sum::<u32>(); + let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } } } /// Error type returned from `WeightedIndex::new`. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum WeightedError { - /// The provided iterator contained no items. + /// The provided weight collection contains no items. NoItem, - /// A weight lower than zero was used. - NegativeWeight, + /// A weight is either less than zero, greater than the supported maximum or + /// otherwise invalid. + InvalidWeight, - /// All items in the provided iterator had a weight of zero. + /// All items in the provided weight collection are zero. AllWeightsZero, + + /// Too many weights are provided (length greater than `u32::MAX`) + TooMany, } impl WeightedError { fn msg(&self) -> &str { match *self { - WeightedError::NoItem => "No items found", - WeightedError::NegativeWeight => "Item has negative weight", - WeightedError::AllWeightsZero => "All items had weight zero", + WeightedError::NoItem => "No weights provided.", + WeightedError::InvalidWeight => "A weight is invalid.", + WeightedError::AllWeightsZero => "All weights are zero.", + WeightedError::TooMany => "Too many weights (hit u32::MAX)", } } } @@ -220,7 +351,7 @@ impl ::std::error::Error for WeightedError { fn description(&self) -> &str { self.msg() } - fn cause(&self) -> Option<&::std::error::Error> { + fn cause(&self) -> Option<&dyn (::std::error::Error)> { None } } |