diff options
Diffstat (limited to 'rand/src/distributions/dirichlet.rs')
-rw-r--r-- | rand/src/distributions/dirichlet.rs | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/rand/src/distributions/dirichlet.rs b/rand/src/distributions/dirichlet.rs new file mode 100644 index 0000000..19384b8 --- /dev/null +++ b/rand/src/distributions/dirichlet.rs @@ -0,0 +1,137 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The dirichlet distribution. + +use Rng; +use distributions::Distribution; +use distributions::gamma::Gamma; + +/// The dirichelet distribution `Dirichlet(alpha)`. +/// +/// The Dirichlet distribution is a family of continuous multivariate +/// probability distributions parameterized by a vector alpha of positive reals. +/// It is a multivariate generalization of the beta distribution. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::Dirichlet; +/// +/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]); +/// let samples = dirichlet.sample(&mut rand::thread_rng()); +/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); +/// ``` + +#[derive(Clone, Debug)] +pub struct Dirichlet { + /// Concentration parameters (alpha) + alpha: Vec<f64>, +} + +impl Dirichlet { + /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. + /// + /// # Panics + /// - if `alpha.len() < 2` + /// + #[inline] + pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet { + let a = alpha.into(); + assert!(a.len() > 1); + for i in 0..a.len() { + assert!(a[i] > 0.0); + } + + Dirichlet { alpha: a } + } + + /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. + /// + /// # Panics + /// - if `alpha <= 0.0` + /// - if `size < 2` + /// + #[inline] + pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet { + assert!(alpha > 0.0); + assert!(size > 1); + Dirichlet { + alpha: vec![alpha; size], + } + } +} + +impl Distribution<Vec<f64>> for Dirichlet { + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> { + let n = self.alpha.len(); + let mut samples = vec![0.0f64; n]; + let mut sum = 0.0f64; + + for i in 0..n { + let g = Gamma::new(self.alpha[i], 1.0); + samples[i] = g.sample(rng); + sum += samples[i]; + } + let invacc = 1.0 / sum; + for i in 0..n { + samples[i] *= invacc; + } + samples + } +} + +#[cfg(test)] +mod test { + use super::Dirichlet; + use distributions::Distribution; + + #[test] + fn test_dirichlet() { + let d = Dirichlet::new(vec![1.0, 2.0, 3.0]); + let mut rng = ::test::rng(221); + let samples = d.sample(&mut rng); + let _: Vec<f64> = samples + .into_iter() + .map(|x| { + assert!(x > 0.0); + x + }) + .collect(); + } + + #[test] + fn test_dirichlet_with_param() { + let alpha = 0.5f64; + let size = 2; + let d = Dirichlet::new_with_param(alpha, size); + let mut rng = ::test::rng(221); + let samples = d.sample(&mut rng); + let _: Vec<f64> = samples + .into_iter() + .map(|x| { + assert!(x > 0.0); + x + }) + .collect(); + } + + #[test] + #[should_panic] + fn test_dirichlet_invalid_length() { + Dirichlet::new_with_param(0.5f64, 1); + } + + #[test] + #[should_panic] + fn test_dirichlet_invalid_alpha() { + Dirichlet::new_with_param(0.0f64, 2); + } +} |