// Copyright 2013 The Rust Project Developers. See the COPYRIGHT // file at the top-level directory of this distribution and at // http://rust-lang.org/COPYRIGHT. // // Licensed under the Apache License, Version 2.0 or the MIT license // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. //! Sampling from random distributions. //! //! This is a generalization of `Rand` to allow parameters to control the //! exact properties of the generated values, e.g. the mean and standard //! deviation of a normal distribution. The `Sample` trait is the most //! general, and allows for generating values that change some state //! internally. The `IndependentSample` trait is for generating values //! that do not need to record state. use core::marker; use {Rng, Rand}; pub use self::range::Range; #[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT}; #[cfg(feature="std")] pub use self::normal::{Normal, LogNormal}; #[cfg(feature="std")] pub use self::exponential::Exp; pub mod range; #[cfg(feature="std")] pub mod gamma; #[cfg(feature="std")] pub mod normal; #[cfg(feature="std")] pub mod exponential; #[cfg(feature="std")] mod ziggurat_tables; /// Types that can be used to create a random instance of `Support`. pub trait Sample { /// Generate a random value of `Support`, using `rng` as the /// source of randomness. fn sample(&mut self, rng: &mut R) -> Support; } /// `Sample`s that do not require keeping track of state. /// /// Since no state is recorded, each sample is (statistically) /// independent of all others, assuming the `Rng` used has this /// property. // FIXME maybe having this separate is overkill (the only reason is to // take &self rather than &mut self)? or maybe this should be the // trait called `Sample` and the other should be `DependentSample`. pub trait IndependentSample: Sample { /// Generate a random value. fn ind_sample(&self, &mut R) -> Support; } /// A wrapper for generating types that implement `Rand` via the /// `Sample` & `IndependentSample` traits. #[derive(Debug)] pub struct RandSample { _marker: marker::PhantomData Sup>, } impl Copy for RandSample {} impl Clone for RandSample { fn clone(&self) -> Self { *self } } impl Sample for RandSample { fn sample(&mut self, rng: &mut R) -> Sup { self.ind_sample(rng) } } impl IndependentSample for RandSample { fn ind_sample(&self, rng: &mut R) -> Sup { rng.gen() } } impl RandSample { pub fn new() -> RandSample { RandSample { _marker: marker::PhantomData } } } /// A value with a particular weight for use with `WeightedChoice`. #[derive(Copy, Clone, Debug)] pub struct Weighted { /// The numerical weight of this item pub weight: u32, /// The actual item which is being weighted pub item: T, } /// A distribution that selects from a finite collection of weighted items. /// /// Each item has an associated weight that influences how likely it /// is to be chosen: higher weight is more likely. /// /// The `Clone` restriction is a limitation of the `Sample` and /// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for /// all `T`, as is `u32`, so one can store references or indices into /// another vector. /// /// # Example /// /// ```rust /// use rand::distributions::{Weighted, WeightedChoice, IndependentSample}; /// /// let mut items = vec!(Weighted { weight: 2, item: 'a' }, /// Weighted { weight: 4, item: 'b' }, /// Weighted { weight: 1, item: 'c' }); /// let wc = WeightedChoice::new(&mut items); /// let mut rng = rand::thread_rng(); /// for _ in 0..16 { /// // on average prints 'a' 4 times, 'b' 8 and 'c' twice. /// println!("{}", wc.ind_sample(&mut rng)); /// } /// ``` #[derive(Debug)] pub struct WeightedChoice<'a, T:'a> { items: &'a mut [Weighted], weight_range: Range } impl<'a, T: Clone> WeightedChoice<'a, T> { /// Create a new `WeightedChoice`. /// /// Panics if: /// /// - `items` is empty /// - the total weight is 0 /// - the total weight is larger than a `u32` can contain. pub fn new(items: &'a mut [Weighted]) -> WeightedChoice<'a, T> { // strictly speaking, this is subsumed by the total weight == 0 case assert!(!items.is_empty(), "WeightedChoice::new called with no items"); let mut running_total: u32 = 0; // we convert the list from individual weights to cumulative // weights so we can binary search. This *could* drop elements // with weight == 0 as an optimisation. for item in items.iter_mut() { running_total = match running_total.checked_add(item.weight) { Some(n) => n, None => panic!("WeightedChoice::new called with a total weight \ larger than a u32 can contain") }; item.weight = running_total; } assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0"); WeightedChoice { items: items, // we're likely to be generating numbers in this range // relatively often, so might as well cache it weight_range: Range::new(0, running_total) } } } impl<'a, T: Clone> Sample for WeightedChoice<'a, T> { fn sample(&mut self, rng: &mut R) -> T { self.ind_sample(rng) } } impl<'a, T: Clone> IndependentSample for WeightedChoice<'a, T> { fn ind_sample(&self, rng: &mut R) -> T { // we want to find the first element that has cumulative // weight > sample_weight, which we do by binary since the // cumulative weights of self.items are sorted. // choose a weight in [0, total_weight) let sample_weight = self.weight_range.ind_sample(rng); // short circuit when it's the first item if sample_weight < self.items[0].weight { return self.items[0].item.clone(); } let mut idx = 0; let mut modifier = self.items.len(); // now we know that every possibility has an element to the // left, so we can just search for the last element that has // cumulative weight <= sample_weight, then the next one will // be "it". (Note that this greatest element will never be the // last element of the vector, since sample_weight is chosen // in [0, total_weight) and the cumulative weight of the last // one is exactly the total weight.) while modifier > 1 { let i = idx + modifier / 2; if self.items[i].weight <= sample_weight { // we're small, so look to the right, but allow this // exact element still. idx = i; // we need the `/ 2` to round up otherwise we'll drop // the trailing elements when `modifier` is odd. modifier += 1; } else { // otherwise we're too big, so go left. (i.e. do // nothing) } modifier /= 2; } return self.items[idx + 1].item.clone(); } } /// Sample a random number using the Ziggurat method (specifically the /// ZIGNOR variant from Doornik 2005). Most of the arguments are /// directly from the paper: /// /// * `rng`: source of randomness /// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0. /// * `X`: the $x_i$ abscissae. /// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$) /// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$ /// * `pdf`: the probability density function /// * `zero_case`: manual sampling from the tail when we chose the /// bottom box (i.e. i == 0) // the perf improvement (25-50%) is definitely worth the extra code // size from force-inlining. #[cfg(feature="std")] #[inline(always)] fn ziggurat( rng: &mut R, symmetric: bool, x_tab: ziggurat_tables::ZigTable, f_tab: ziggurat_tables::ZigTable, mut pdf: P, mut zero_case: Z) -> f64 where P: FnMut(f64) -> f64, Z: FnMut(&mut R, f64) -> f64 { const SCALE: f64 = (1u64 << 53) as f64; loop { // reimplement the f64 generation as an optimisation suggested // by the Doornik paper: we have a lot of precision-space // (i.e. there are 11 bits of the 64 of a u64 to use after // creating a f64), so we might as well reuse some to save // generating a whole extra random number. (Seems to be 15% // faster.) // // This unfortunately misses out on the benefits of direct // floating point generation if an RNG like dSMFT is // used. (That is, such RNGs create floats directly, highly // efficiently and overload next_f32/f64, so by not calling it // this may be slower than it would be otherwise.) // FIXME: investigate/optimise for the above. let bits: u64 = rng.gen(); let i = (bits & 0xff) as usize; let f = (bits >> 11) as f64 / SCALE; // u is either U(-1, 1) or U(0, 1) depending on if this is a // symmetric distribution or not. let u = if symmetric {2.0 * f - 1.0} else {f}; let x = u * x_tab[i]; let test_x = if symmetric { x.abs() } else {x}; // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i]) if test_x < x_tab[i + 1] { return x; } if i == 0 { return zero_case(rng, u); } // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::() < pdf(x) { return x; } } } #[cfg(test)] mod tests { use {Rng, Rand}; use super::{RandSample, WeightedChoice, Weighted, Sample, IndependentSample}; #[derive(PartialEq, Debug)] struct ConstRand(usize); impl Rand for ConstRand { fn rand(_: &mut R) -> ConstRand { ConstRand(0) } } // 0, 1, 2, 3, ... struct CountingRng { i: u32 } impl Rng for CountingRng { fn next_u32(&mut self) -> u32 { self.i += 1; self.i - 1 } fn next_u64(&mut self) -> u64 { self.next_u32() as u64 } } #[test] fn test_rand_sample() { let mut rand_sample = RandSample::::new(); assert_eq!(rand_sample.sample(&mut ::test::rng()), ConstRand(0)); assert_eq!(rand_sample.ind_sample(&mut ::test::rng()), ConstRand(0)); } #[test] fn test_weighted_choice() { // this makes assumptions about the internal implementation of // WeightedChoice, specifically: it doesn't reorder the items, // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to // 1, internally; modulo a modulo operation). macro_rules! t { ($items:expr, $expected:expr) => {{ let mut items = $items; let wc = WeightedChoice::new(&mut items); let expected = $expected; let mut rng = CountingRng { i: 0 }; for &val in expected.iter() { assert_eq!(wc.ind_sample(&mut rng), val) } }} } t!(vec!(Weighted { weight: 1, item: 10}), [10]); // skip some t!(vec!(Weighted { weight: 0, item: 20}, Weighted { weight: 2, item: 21}, Weighted { weight: 0, item: 22}, Weighted { weight: 1, item: 23}), [21,21, 23]); // different weights t!(vec!(Weighted { weight: 4, item: 30}, Weighted { weight: 3, item: 31}), [30,30,30,30, 31,31,31]); // check that we're binary searching // correctly with some vectors of odd // length. t!(vec!(Weighted { weight: 1, item: 40}, Weighted { weight: 1, item: 41}, Weighted { weight: 1, item: 42}, Weighted { weight: 1, item: 43}, Weighted { weight: 1, item: 44}), [40, 41, 42, 43, 44]); t!(vec!(Weighted { weight: 1, item: 50}, Weighted { weight: 1, item: 51}, Weighted { weight: 1, item: 52}, Weighted { weight: 1, item: 53}, Weighted { weight: 1, item: 54}, Weighted { weight: 1, item: 55}, Weighted { weight: 1, item: 56}), [50, 51, 52, 53, 54, 55, 56]); } #[test] fn test_weighted_clone_initialization() { let initial : Weighted = Weighted {weight: 1, item: 1}; let clone = initial.clone(); assert_eq!(initial.weight, clone.weight); assert_eq!(initial.item, clone.item); } #[test] #[should_panic] fn test_weighted_clone_change_weight() { let initial : Weighted = Weighted {weight: 1, item: 1}; let mut clone = initial.clone(); clone.weight = 5; assert_eq!(initial.weight, clone.weight); } #[test] #[should_panic] fn test_weighted_clone_change_item() { let initial : Weighted = Weighted {weight: 1, item: 1}; let mut clone = initial.clone(); clone.item = 5; assert_eq!(initial.item, clone.item); } #[test] #[should_panic] fn test_weighted_choice_no_items() { WeightedChoice::::new(&mut []); } #[test] #[should_panic] fn test_weighted_choice_zero_weight() { WeightedChoice::new(&mut [Weighted { weight: 0, item: 0}, Weighted { weight: 0, item: 1}]); } #[test] #[should_panic] fn test_weighted_choice_weight_overflows() { let x = ::std::u32::MAX / 2; // x + x + 2 is the overflow WeightedChoice::new(&mut [Weighted { weight: x, item: 0 }, Weighted { weight: 1, item: 1 }, Weighted { weight: x, item: 2 }, Weighted { weight: 1, item: 3 }]); } }