diff options
| author | Daniel Mueller <deso@posteo.net> | 2020-01-02 08:32:06 -0800 | 
|---|---|---|
| committer | Daniel Mueller <deso@posteo.net> | 2020-01-02 08:32:06 -0800 | 
| commit | fd091b04316db9dc5fafadbd6bdbe60b127408a9 (patch) | |
| tree | f202270f7ae5cedc513be03833a26148d9b5e219 /rand/src/distributions/weighted | |
| parent | 8161cdb26f98e65b39c603ddf7a614cc87c77a1c (diff) | |
| download | nitrocli-fd091b04316db9dc5fafadbd6bdbe60b127408a9.tar.gz nitrocli-fd091b04316db9dc5fafadbd6bdbe60b127408a9.tar.bz2  | |
Update nitrokey crate to 0.4.0
This change finally updates the version of the nitrokey crate that we
consume to 0.4.0. Along with that we update rand_core, one of its
dependencies, to 0.5.1. Furthermore, we add cfg-if in version 0.1.10 and
getrandom in version 0.1.13, both of which are now new (non-development)
dependencies.
Import subrepo nitrokey/:nitrokey at e81057037e9b4f370b64c0a030a725bc6bdfb870
Import subrepo cfg-if/:cfg-if at 4484a6faf816ff8058088ad857b0c6bb2f4b02b2
Import subrepo getrandom/:getrandom at d661aa7e1b8cc80b47dabe3d2135b3b47d2858af
Import subrepo rand/:rand at d877ed528248b52d947e0484364a4e1ae59ca502
Diffstat (limited to 'rand/src/distributions/weighted')
| -rw-r--r-- | rand/src/distributions/weighted/alias_method.rs | 499 | ||||
| -rw-r--r-- | rand/src/distributions/weighted/mod.rs | 363 | 
2 files changed, 862 insertions, 0 deletions
diff --git a/rand/src/distributions/weighted/alias_method.rs b/rand/src/distributions/weighted/alias_method.rs new file mode 100644 index 0000000..bdd4ba0 --- /dev/null +++ b/rand/src/distributions/weighted/alias_method.rs @@ -0,0 +1,499 @@ +//! This module contains an implementation of alias method for sampling random +//! indices with probabilities proportional to a collection of weights. + +use super::WeightedError; +#[cfg(not(feature = "std"))] +use crate::alloc::vec::Vec; +#[cfg(not(feature = "std"))] +use crate::alloc::vec; +use core::fmt; +use core::iter::Sum; +use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; +use crate::distributions::uniform::SampleUniform; +use crate::distributions::Distribution; +use crate::distributions::Uniform; +use crate::Rng; + +/// A distribution using weighted sampling to pick a discretely selected item. +/// +/// Sampling a [`WeightedIndex<W>`] distribution returns the index of a randomly +/// selected element from the vector used to create the [`WeightedIndex<W>`]. +/// The chance of a given element being picked is proportional to the value of +/// the element. The weights can have any type `W` for which a implementation of +/// [`Weight`] exists. +/// +/// # Performance +/// +/// Given that `n` is the number of items in the vector used to create an +/// [`WeightedIndex<W>`], [`WeightedIndex<W>`] will require `O(n)` amount of +/// memory. More specifically it takes up some constant amount of memory plus +/// the vector used to create it and a [`Vec<u32>`] with capacity `n`. +/// +/// Time complexity for the creation of a [`WeightedIndex<W>`] is `O(n)`. +/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call +/// to [`Uniform<W>::sample`]. 
+/// +/// # Example +/// +/// ``` +/// use rand::distributions::weighted::alias_method::WeightedIndex; +/// use rand::prelude::*; +/// +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 1, 1]; +/// let dist = WeightedIndex::new(weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +///     println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap(); +/// for _ in 0..100 { +///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +///     println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`WeightedIndex<W>`]: crate::distributions::weighted::alias_method::WeightedIndex +/// [`Weight`]: crate::distributions::weighted::alias_method::Weight +/// [`Vec<u32>`]: Vec +/// [`Uniform<u32>::sample`]: Distribution::sample +/// [`Uniform<W>::sample`]: Distribution::sample +pub struct WeightedIndex<W: Weight> { +    aliases: Vec<u32>, +    no_alias_odds: Vec<W>, +    uniform_index: Uniform<u32>, +    uniform_within_weight_sum: Uniform<W>, +} + +impl<W: Weight> WeightedIndex<W> { +    /// Creates a new [`WeightedIndex`]. +    /// +    /// Returns an error if: +    /// - The vector is empty. +    /// - The vector is longer than `u32::MAX`. +    /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / +    ///   weights.len()`. +    /// - The sum of weights is zero. 
+    pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> { +        let n = weights.len(); +        if n == 0 { +            return Err(WeightedError::NoItem); +        } else if n > ::core::u32::MAX as usize { +            return Err(WeightedError::TooMany); +        } +        let n = n as u32; + +        let max_weight_size = W::try_from_u32_lossy(n) +            .map(|n| W::MAX / n) +            .unwrap_or(W::ZERO); +        if !weights +            .iter() +            .all(|&w| W::ZERO <= w && w <= max_weight_size) +        { +            return Err(WeightedError::InvalidWeight); +        } + +        // The sum of weights will represent 100% of no alias odds. +        let weight_sum = Weight::sum(weights.as_slice()); +        // Prevent floating point overflow due to rounding errors. +        let weight_sum = if weight_sum > W::MAX { +            W::MAX +        } else { +            weight_sum +        }; +        if weight_sum == W::ZERO { +            return Err(WeightedError::AllWeightsZero); +        } + +        // `weight_sum` would have been zero if `try_from_lossy` causes an error here. +        let n_converted = W::try_from_u32_lossy(n).unwrap(); + +        let mut no_alias_odds = weights; +        for odds in no_alias_odds.iter_mut() { +            *odds *= n_converted; +            // Prevent floating point overflow due to rounding errors. +            *odds = if *odds > W::MAX { W::MAX } else { *odds }; +        } + +        /// This struct is designed to contain three data structures at once, +        /// sharing the same memory. More precisely it contains two linked lists +        /// and an alias map, which will be the output of this method. To keep +        /// the three data structures from getting in each other's way, it must +        /// be ensured that a single index is only ever in one of them at the +        /// same time. 
+        struct Aliases { +            aliases: Vec<u32>, +            smalls_head: u32, +            bigs_head: u32, +        } + +        impl Aliases { +            fn new(size: u32) -> Self { +                Aliases { +                    aliases: vec![0; size as usize], +                    smalls_head: ::core::u32::MAX, +                    bigs_head: ::core::u32::MAX, +                } +            } + +            fn push_small(&mut self, idx: u32) { +                self.aliases[idx as usize] = self.smalls_head; +                self.smalls_head = idx; +            } + +            fn push_big(&mut self, idx: u32) { +                self.aliases[idx as usize] = self.bigs_head; +                self.bigs_head = idx; +            } + +            fn pop_small(&mut self) -> u32 { +                let popped = self.smalls_head; +                self.smalls_head = self.aliases[popped as usize]; +                popped +            } + +            fn pop_big(&mut self) -> u32 { +                let popped = self.bigs_head; +                self.bigs_head = self.aliases[popped as usize]; +                popped +            } + +            fn smalls_is_empty(&self) -> bool { +                self.smalls_head == ::core::u32::MAX +            } + +            fn bigs_is_empty(&self) -> bool { +                self.bigs_head == ::core::u32::MAX +            } + +            fn set_alias(&mut self, idx: u32, alias: u32) { +                self.aliases[idx as usize] = alias; +            } +        } + +        let mut aliases = Aliases::new(n); + +        // Split indices into those with small weights and those with big weights. 
+        for (index, &odds) in no_alias_odds.iter().enumerate() { +            if odds < weight_sum { +                aliases.push_small(index as u32); +            } else { +                aliases.push_big(index as u32); +            } +        } + +        // Build the alias map by finding an alias with big weight for each index with +        // small weight. +        while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() { +            let s = aliases.pop_small(); +            let b = aliases.pop_big(); + +            aliases.set_alias(s, b); +            no_alias_odds[b as usize] = no_alias_odds[b as usize] +                    - weight_sum +                    + no_alias_odds[s as usize]; + +            if no_alias_odds[b as usize] < weight_sum { +                aliases.push_small(b); +            } else { +                aliases.push_big(b); +            } +        } + +        // The remaining indices should have no alias odds of about 100%. This is due to +        // numeric accuracy. Otherwise they would be exactly 100%. +        while !aliases.smalls_is_empty() { +            no_alias_odds[aliases.pop_small() as usize] = weight_sum; +        } +        while !aliases.bigs_is_empty() { +            no_alias_odds[aliases.pop_big() as usize] = weight_sum; +        } + +        // Prepare distributions for sampling. Creating them beforehand improves +        // sampling performance. 
+        let uniform_index = Uniform::new(0, n); +        let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum); + +        Ok(Self { +            aliases: aliases.aliases, +            no_alias_odds, +            uniform_index, +            uniform_within_weight_sum, +        }) +    } +} + +impl<W: Weight> Distribution<usize> for WeightedIndex<W> { +    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { +        let candidate = rng.sample(self.uniform_index); +        if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] { +            candidate as usize +        } else { +            self.aliases[candidate as usize] as usize +        } +    } +} + +impl<W: Weight> fmt::Debug for WeightedIndex<W> +where +    W: fmt::Debug, +    Uniform<W>: fmt::Debug, +{ +    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +        f.debug_struct("WeightedIndex") +            .field("aliases", &self.aliases) +            .field("no_alias_odds", &self.no_alias_odds) +            .field("uniform_index", &self.uniform_index) +            .field("uniform_within_weight_sum", &self.uniform_within_weight_sum) +            .finish() +    } +} + +impl<W: Weight> Clone for WeightedIndex<W> +where +    Uniform<W>: Clone, +{ +    fn clone(&self) -> Self { +        Self { +            aliases: self.aliases.clone(), +            no_alias_odds: self.no_alias_odds.clone(), +            uniform_index: self.uniform_index.clone(), +            uniform_within_weight_sum: self.uniform_within_weight_sum.clone(), +        } +    } +} + +/// Trait that must be implemented for weights, that are used with +/// [`WeightedIndex`]. Currently no guarantees on the correctness of +/// [`WeightedIndex`] are given for custom implementations of this trait. 
+pub trait Weight: +    Sized +    + Copy +    + SampleUniform +    + PartialOrd +    + Add<Output = Self> +    + AddAssign +    + Sub<Output = Self> +    + SubAssign +    + Mul<Output = Self> +    + MulAssign +    + Div<Output = Self> +    + DivAssign +    + Sum +{ +    /// Maximum number representable by `Self`. +    const MAX: Self; + +    /// Element of `Self` equivalent to 0. +    const ZERO: Self; + +    /// Produce an instance of `Self` from a `u32` value, or return `None` if +    /// out of range. Loss of precision (where `Self` is a floating point type) +    /// is acceptable. +    fn try_from_u32_lossy(n: u32) -> Option<Self>; + +    /// Sums all values in slice `values`. +    fn sum(values: &[Self]) -> Self { +        values.iter().map(|x| *x).sum() +    } +} + +macro_rules! impl_weight_for_float { +    ($T: ident) => { +        impl Weight for $T { +            const MAX: Self = ::core::$T::MAX; +            const ZERO: Self = 0.0; + +            fn try_from_u32_lossy(n: u32) -> Option<Self> { +                Some(n as $T) +            } + +            fn sum(values: &[Self]) -> Self { +                pairwise_sum(values) +            } +        } +    }; +} + +/// In comparison to naive accumulation, the pairwise sum algorithm reduces +/// rounding errors when there are many floating point values. +fn pairwise_sum<T: Weight>(values: &[T]) -> T { +    if values.len() <= 32 { +        values.iter().map(|x| *x).sum() +    } else { +        let mid = values.len() / 2; +        let (a, b) = values.split_at(mid); +        pairwise_sum(a) + pairwise_sum(b) +    } +} + +macro_rules! 
impl_weight_for_int { +    ($T: ident) => { +        impl Weight for $T { +            const MAX: Self = ::core::$T::MAX; +            const ZERO: Self = 0; + +            fn try_from_u32_lossy(n: u32) -> Option<Self> { +                let n_converted = n as Self; +                if n_converted >= Self::ZERO && n_converted as u32 == n { +                    Some(n_converted) +                } else { +                    None +                } +            } +        } +    }; +} + +impl_weight_for_float!(f64); +impl_weight_for_float!(f32); +impl_weight_for_int!(usize); +#[cfg(not(target_os = "emscripten"))] +impl_weight_for_int!(u128); +impl_weight_for_int!(u64); +impl_weight_for_int!(u32); +impl_weight_for_int!(u16); +impl_weight_for_int!(u8); +impl_weight_for_int!(isize); +#[cfg(not(target_os = "emscripten"))] +impl_weight_for_int!(i128); +impl_weight_for_int!(i64); +impl_weight_for_int!(i32); +impl_weight_for_int!(i16); +impl_weight_for_int!(i8); + +#[cfg(test)] +mod test { +    use super::*; + +    #[test] +    #[cfg(not(miri))] // Miri is too slow +    fn test_weighted_index_f32() { +        test_weighted_index(f32::into); + +        // Floating point special cases +        assert_eq!( +            WeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +        assert_eq!( +            WeightedIndex::new(vec![-0_f32]).unwrap_err(), +            WeightedError::AllWeightsZero +        ); +        assert_eq!( +            WeightedIndex::new(vec![-1_f32]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +        assert_eq!( +            WeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +        assert_eq!( +            WeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +    } + +    #[cfg(not(target_os = "emscripten"))] +    #[test] +    #[cfg(not(miri))] 
// Miri is too slow +    fn test_weighted_index_u128() { +        test_weighted_index(|x: u128| x as f64); +    } + +    #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] +    #[test] +    #[cfg(not(miri))] // Miri is too slow +    fn test_weighted_index_i128() { +        test_weighted_index(|x: i128| x as f64); + +        // Signed integer special cases +        assert_eq!( +            WeightedIndex::new(vec![-1_i128]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +        assert_eq!( +            WeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +    } + +    #[test] +    #[cfg(not(miri))] // Miri is too slow +    fn test_weighted_index_u8() { +        test_weighted_index(u8::into); +    } + +    #[test] +    #[cfg(not(miri))] // Miri is too slow +    fn test_weighted_index_i8() { +        test_weighted_index(i8::into); + +        // Signed integer special cases +        assert_eq!( +            WeightedIndex::new(vec![-1_i8]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +        assert_eq!( +            WeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +    } + +    fn test_weighted_index<W: Weight, F: Fn(W) -> f64>(w_to_f64: F) +    where +        WeightedIndex<W>: fmt::Debug, +    { +        const NUM_WEIGHTS: u32 = 10; +        const ZERO_WEIGHT_INDEX: u32 = 3; +        const NUM_SAMPLES: u32 = 15000; +        let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + +        let weights = { +            let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize); +            let random_weight_distribution = crate::distributions::Uniform::new_inclusive( +                W::ZERO, +                W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), +            ); +            for _ in 0..NUM_WEIGHTS { +                weights.push(rng.sample(&random_weight_distribution)); +            } +     
       weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO; +            weights +        }; +        let weight_sum = weights.iter().map(|w| *w).sum::<W>(); +        let expected_counts = weights +            .iter() +            .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) +            .collect::<Vec<f64>>(); +        let weight_distribution = WeightedIndex::new(weights).unwrap(); + +        let mut counts = vec![0; NUM_WEIGHTS as usize]; +        for _ in 0..NUM_SAMPLES { +            counts[rng.sample(&weight_distribution)] += 1; +        } + +        assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0); +        for (count, expected_count) in counts.into_iter().zip(expected_counts) { +            let difference = (count as f64 - expected_count).abs(); +            let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; +            assert!(difference <= max_allowed_difference); +        } + +        assert_eq!( +            WeightedIndex::<W>::new(vec![]).unwrap_err(), +            WeightedError::NoItem +        ); +        assert_eq!( +            WeightedIndex::new(vec![W::ZERO]).unwrap_err(), +            WeightedError::AllWeightsZero +        ); +        assert_eq!( +            WeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), +            WeightedError::InvalidWeight +        ); +    } +} diff --git a/rand/src/distributions/weighted/mod.rs b/rand/src/distributions/weighted/mod.rs new file mode 100644 index 0000000..2711637 --- /dev/null +++ b/rand/src/distributions/weighted/mod.rs @@ -0,0 +1,363 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or +// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted index sampling +//!  +//! 
This module provides two implementations for sampling indices: +//!  +//! *   [`WeightedIndex`] allows `O(log N)` sampling +//! *   [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with +//!      much greater set-up cost +//!       +//! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html + +pub mod alias_method; + +use crate::Rng; +use crate::distributions::Distribution; +use crate::distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; +use core::cmp::PartialOrd; +use core::fmt; + +// Note that this whole module is only imported if feature="alloc" is enabled. +#[cfg(not(feature="std"))] use crate::alloc::vec::Vec; + +/// A distribution using weighted sampling to pick a discretely selected +/// item. +/// +/// Sampling a `WeightedIndex` distribution returns the index of a randomly +/// selected element from the iterator used when the `WeightedIndex` was +/// created. The chance of a given element being picked is proportional to the +/// value of the element. The weights can use any type `X` for which an +/// implementation of [`Uniform<X>`] exists. +/// +/// # Performance +/// +/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its +/// size is the sum of the size of those objects, possibly plus some alignment. +/// +/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1` +/// weights of type `X`, where `N` is the number of weights. However, since +/// `Vec` doesn't guarantee a particular growth strategy, additional memory +/// might be allocated but not used. Since the `WeightedIndex` object also +/// contains a [`Uniform<X>`], this might cause additional allocations, though +/// for primitive types, [`Uniform<X>`] doesn't allocate any memory. +/// +/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where +/// `N` is the number of weights. 
+/// +/// Sampling from `WeightedIndex` will result in a single call to +/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically +/// will request a single value from the underlying [`RngCore`], though the +/// exact number depends on the implementation of `Uniform<X>::sample`. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::WeightedIndex; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2,   1,   1]; +/// let dist = WeightedIndex::new(&weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +///     println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0), ('b', 3), ('c', 7)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); +/// for _ in 0..100 { +///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +///     println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`Uniform<X>`]: crate::distributions::uniform::Uniform +/// [`RngCore`]: crate::RngCore +#[derive(Debug, Clone)] +pub struct WeightedIndex<X: SampleUniform + PartialOrd> { +    cumulative_weights: Vec<X>, +    total_weight: X, +    weight_distribution: X::Sampler, +} + +impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { +    /// Creates a new `WeightedIndex` [`Distribution`] using the values +    /// in `weights`. The weights can use any type `X` for which an +    /// implementation of [`Uniform<X>`] exists. +    /// +    /// Returns an error if the iterator is empty, if any weight is `< 0`, or +    /// if its total value is 0. 
+    /// +    /// [`Uniform<X>`]: crate::distributions::uniform::Uniform +    pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> +        where I: IntoIterator, +              I::Item: SampleBorrow<X>, +              X: for<'a> ::core::ops::AddAssign<&'a X> + +                 Clone + +                 Default { +        let mut iter = weights.into_iter(); +        let mut total_weight: X = iter.next() +                                      .ok_or(WeightedError::NoItem)? +                                      .borrow() +                                      .clone(); + +        let zero = <X as Default>::default(); +        if total_weight < zero { +            return Err(WeightedError::InvalidWeight); +        } + +        let mut weights = Vec::<X>::with_capacity(iter.size_hint().0); +        for w in iter { +            if *w.borrow() < zero { +                return Err(WeightedError::InvalidWeight); +            } +            weights.push(total_weight.clone()); +            total_weight += w.borrow(); +        } + +        if total_weight == zero { +            return Err(WeightedError::AllWeightsZero); +        } +        let distr = X::Sampler::new(zero, total_weight.clone()); + +        Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr }) +    } + +    /// Update a subset of weights, without changing the number of weights. +    /// +    /// `new_weights` must be sorted by the index. +    /// +    /// Using this method instead of `new` might be more efficient if only a small number of +    /// weights is modified. No allocations are performed, unless the weight type `X` uses +    /// allocation internally. +    /// +    /// In case of error, `self` is not modified. 
+    pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> +        where X: for<'a> ::core::ops::AddAssign<&'a X> + +                 for<'a> ::core::ops::SubAssign<&'a X> + +                 Clone + +                 Default { +        if new_weights.is_empty() { +            return Ok(()); +        } + +        let zero = <X as Default>::default(); + +        let mut total_weight = self.total_weight.clone(); + +        // Check for errors first, so we don't modify `self` in case something +        // goes wrong. +        let mut prev_i = None; +        for &(i, w) in new_weights { +            if let Some(old_i) = prev_i { +                if old_i >= i { +                    return Err(WeightedError::InvalidWeight); +                } +            } +            if *w < zero { +                return Err(WeightedError::InvalidWeight); +            } +            if i >= self.cumulative_weights.len() + 1 { +                return Err(WeightedError::TooMany); +            } + +            let mut old_w = if i < self.cumulative_weights.len() { +                self.cumulative_weights[i].clone() +            } else { +                self.total_weight.clone() +            }; +            if i > 0 { +                old_w -= &self.cumulative_weights[i - 1]; +            } + +            total_weight -= &old_w; +            total_weight += w; +            prev_i = Some(i); +        } +        if total_weight == zero { +            return Err(WeightedError::AllWeightsZero); +        } + +        // Update the weights. Because we checked all the preconditions in the +        // previous loop, this should never panic. 
+        let mut iter = new_weights.iter(); + +        let mut prev_weight = zero.clone(); +        let mut next_new_weight = iter.next(); +        let &(first_new_index, _) = next_new_weight.unwrap(); +        let mut cumulative_weight = if first_new_index > 0 { +            self.cumulative_weights[first_new_index - 1].clone() +        } else { +            zero.clone()  +        }; +        for i in first_new_index..self.cumulative_weights.len() { +            match next_new_weight { +                Some(&(j, w)) if i == j => { +                    cumulative_weight += w; +                    next_new_weight = iter.next(); +                }, +                _ => { +                    let mut tmp = self.cumulative_weights[i].clone(); +                    tmp -= &prev_weight;  // We know this is positive. +                    cumulative_weight += &tmp; +                } +            } +            prev_weight = cumulative_weight.clone(); +            core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); +        } + +        self.total_weight = total_weight; +        self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); + +        Ok(()) +    } +} + +impl<X> Distribution<usize> for WeightedIndex<X> where +    X: SampleUniform + PartialOrd { +    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { +        use ::core::cmp::Ordering; +        let chosen_weight = self.weight_distribution.sample(rng); +        // Find the first item which has a weight *higher* than the chosen weight. 
+        self.cumulative_weights.binary_search_by( +            |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err() +    } +} + +#[cfg(test)] +mod test { +    use super::*; + +    #[test] +    #[cfg(not(miri))] // Miri is too slow +    fn test_weightedindex() { +        let mut r = crate::test::rng(700); +        const N_REPS: u32 = 5000; +        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; +        let total_weight = weights.iter().sum::<u32>() as f32; + +        let verify = |result: [i32; 14]| { +            for (i, count) in result.iter().enumerate() { +                let exp = (weights[i] * N_REPS) as f32 / total_weight; +                let mut err = (*count as f32 - exp).abs(); +                if err != 0.0 { +                    err /= exp; +                } +                assert!(err <= 0.25); +            } +        }; + +        // WeightedIndex from vec +        let mut chosen = [0i32; 14]; +        let distr = WeightedIndex::new(weights.to_vec()).unwrap(); +        for _ in 0..N_REPS { +            chosen[distr.sample(&mut r)] += 1; +        } +        verify(chosen); + +        // WeightedIndex from slice +        chosen = [0i32; 14]; +        let distr = WeightedIndex::new(&weights[..]).unwrap(); +        for _ in 0..N_REPS { +            chosen[distr.sample(&mut r)] += 1; +        } +        verify(chosen); + +        // WeightedIndex from iterator +        chosen = [0i32; 14]; +        let distr = WeightedIndex::new(weights.iter()).unwrap(); +        for _ in 0..N_REPS { +            chosen[distr.sample(&mut r)] += 1; +        } +        verify(chosen); + +        for _ in 0..5 { +            assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); +            assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); +            assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4); +        } + +        
assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem); +        assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero); +        assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::InvalidWeight); +        assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight); +        assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight); +    } + +    #[test] +    fn test_update_weights() { +        let data = [ +            (&[10u32, 2, 3, 4][..], +             &[(1, &100), (2, &4)][..],  // positive change +             &[10, 100, 4, 4][..]), +            (&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], +             &[(2, &1), (5, &1), (13, &100)][..],  // negative change and last element +             &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]), +        ]; + +        for (weights, update, expected_weights) in data.into_iter() { +            let total_weight = weights.iter().sum::<u32>(); +            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); +            assert_eq!(distr.total_weight, total_weight); + +            distr.update_weights(update).unwrap(); +            let expected_total_weight = expected_weights.iter().sum::<u32>(); +            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); +            assert_eq!(distr.total_weight, expected_total_weight); +            assert_eq!(distr.total_weight, expected_distr.total_weight); +            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); +        } +    } +} + +/// Error type returned from `WeightedIndex::new`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WeightedError { +    /// The provided weight collection contains no items. +    NoItem, + +    /// A weight is either less than zero, greater than the supported maximum or +    /// otherwise invalid. 
+    InvalidWeight, + +    /// All items in the provided weight collection are zero. +    AllWeightsZero, +     +    /// Too many weights are provided (length greater than `u32::MAX`) +    TooMany, +} + +impl WeightedError { +    fn msg(&self) -> &str { +        match *self { +            WeightedError::NoItem => "No weights provided.", +            WeightedError::InvalidWeight => "A weight is invalid.", +            WeightedError::AllWeightsZero => "All weights are zero.", +            WeightedError::TooMany => "Too many weights (hit u32::MAX)", +        } +    } +} + +#[cfg(feature="std")] +impl ::std::error::Error for WeightedError { +    fn description(&self) -> &str { +        self.msg() +    } +    fn cause(&self) -> Option<&dyn (::std::error::Error)> { +        None +    } +} + +impl fmt::Display for WeightedError { +    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +        write!(f, "{}", self.msg()) +    } +}  | 
