diff options
Diffstat (limited to 'rand/src/distributions/weighted')
-rw-r--r-- | rand/src/distributions/weighted/alias_method.rs | 499 | ||||
-rw-r--r-- | rand/src/distributions/weighted/mod.rs | 363 |
2 files changed, 0 insertions, 862 deletions
diff --git a/rand/src/distributions/weighted/alias_method.rs b/rand/src/distributions/weighted/alias_method.rs deleted file mode 100644 index bdd4ba0..0000000 --- a/rand/src/distributions/weighted/alias_method.rs +++ /dev/null @@ -1,499 +0,0 @@ -//! This module contains an implementation of alias method for sampling random -//! indices with probabilities proportional to a collection of weights. - -use super::WeightedError; -#[cfg(not(feature = "std"))] -use crate::alloc::vec::Vec; -#[cfg(not(feature = "std"))] -use crate::alloc::vec; -use core::fmt; -use core::iter::Sum; -use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; -use crate::distributions::uniform::SampleUniform; -use crate::distributions::Distribution; -use crate::distributions::Uniform; -use crate::Rng; - -/// A distribution using weighted sampling to pick a discretely selected item. -/// -/// Sampling a [`WeightedIndex<W>`] distribution returns the index of a randomly -/// selected element from the vector used to create the [`WeightedIndex<W>`]. -/// The chance of a given element being picked is proportional to the value of -/// the element. The weights can have any type `W` for which a implementation of -/// [`Weight`] exists. -/// -/// # Performance -/// -/// Given that `n` is the number of items in the vector used to create an -/// [`WeightedIndex<W>`], [`WeightedIndex<W>`] will require `O(n)` amount of -/// memory. More specifically it takes up some constant amount of memory plus -/// the vector used to create it and a [`Vec<u32>`] with capacity `n`. -/// -/// Time complexity for the creation of a [`WeightedIndex<W>`] is `O(n)`. -/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call -/// to [`Uniform<W>::sample`]. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::weighted::alias_method::WeightedIndex; -/// use rand::prelude::*; -/// -/// let choices = vec!['a', 'b', 'c']; -/// let weights = vec![2, 1, 1]; -/// let dist = WeightedIndex::new(weights).unwrap(); -/// let mut rng = thread_rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`WeightedIndex<W>`]: crate::distributions::weighted::alias_method::WeightedIndex -/// [`Weight`]: crate::distributions::weighted::alias_method::Weight -/// [`Vec<u32>`]: Vec -/// [`Uniform<u32>::sample`]: Distribution::sample -/// [`Uniform<W>::sample`]: Distribution::sample -pub struct WeightedIndex<W: Weight> { - aliases: Vec<u32>, - no_alias_odds: Vec<W>, - uniform_index: Uniform<u32>, - uniform_within_weight_sum: Uniform<W>, -} - -impl<W: Weight> WeightedIndex<W> { - /// Creates a new [`WeightedIndex`]. - /// - /// Returns an error if: - /// - The vector is empty. - /// - The vector is longer than `u32::MAX`. - /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / - /// weights.len()`. - /// - The sum of weights is zero. - pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> { - let n = weights.len(); - if n == 0 { - return Err(WeightedError::NoItem); - } else if n > ::core::u32::MAX as usize { - return Err(WeightedError::TooMany); - } - let n = n as u32; - - let max_weight_size = W::try_from_u32_lossy(n) - .map(|n| W::MAX / n) - .unwrap_or(W::ZERO); - if !weights - .iter() - .all(|&w| W::ZERO <= w && w <= max_weight_size) - { - return Err(WeightedError::InvalidWeight); - } - - // The sum of weights will represent 100% of no alias odds. - let weight_sum = Weight::sum(weights.as_slice()); - // Prevent floating point overflow due to rounding errors. - let weight_sum = if weight_sum > W::MAX { - W::MAX - } else { - weight_sum - }; - if weight_sum == W::ZERO { - return Err(WeightedError::AllWeightsZero); - } - - // `weight_sum` would have been zero if `try_from_lossy` causes an error here. - let n_converted = W::try_from_u32_lossy(n).unwrap(); - - let mut no_alias_odds = weights; - for odds in no_alias_odds.iter_mut() { - *odds *= n_converted; - // Prevent floating point overflow due to rounding errors. - *odds = if *odds > W::MAX { W::MAX } else { *odds }; - } - - /// This struct is designed to contain three data structures at once, - /// sharing the same memory. More precisely it contains two linked lists - /// and an alias map, which will be the output of this method. To keep - /// the three data structures from getting in each other's way, it must - /// be ensured that a single index is only ever in one of them at the - /// same time. - struct Aliases { - aliases: Vec<u32>, - smalls_head: u32, - bigs_head: u32, - } - - impl Aliases { - fn new(size: u32) -> Self { - Aliases { - aliases: vec![0; size as usize], - smalls_head: ::core::u32::MAX, - bigs_head: ::core::u32::MAX, - } - } - - fn push_small(&mut self, idx: u32) { - self.aliases[idx as usize] = self.smalls_head; - self.smalls_head = idx; - } - - fn push_big(&mut self, idx: u32) { - self.aliases[idx as usize] = self.bigs_head; - self.bigs_head = idx; - } - - fn pop_small(&mut self) -> u32 { - let popped = self.smalls_head; - self.smalls_head = self.aliases[popped as usize]; - popped - } - - fn pop_big(&mut self) -> u32 { - let popped = self.bigs_head; - self.bigs_head = self.aliases[popped as usize]; - popped - } - - fn smalls_is_empty(&self) -> bool { - self.smalls_head == ::core::u32::MAX - } - - fn bigs_is_empty(&self) -> bool { - self.bigs_head == ::core::u32::MAX - } - - fn set_alias(&mut self, idx: u32, alias: u32) { - self.aliases[idx as usize] = alias; - } - } - - let mut aliases = Aliases::new(n); - - // Split indices into those with small weights and those with big weights. - for (index, &odds) in no_alias_odds.iter().enumerate() { - if odds < weight_sum { - aliases.push_small(index as u32); - } else { - aliases.push_big(index as u32); - } - } - - // Build the alias map by finding an alias with big weight for each index with - // small weight. - while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() { - let s = aliases.pop_small(); - let b = aliases.pop_big(); - - aliases.set_alias(s, b); - no_alias_odds[b as usize] = no_alias_odds[b as usize] - - weight_sum - + no_alias_odds[s as usize]; - - if no_alias_odds[b as usize] < weight_sum { - aliases.push_small(b); - } else { - aliases.push_big(b); - } - } - - // The remaining indices should have no alias odds of about 100%. This is due to - // numeric accuracy. Otherwise they would be exactly 100%. - while !aliases.smalls_is_empty() { - no_alias_odds[aliases.pop_small() as usize] = weight_sum; - } - while !aliases.bigs_is_empty() { - no_alias_odds[aliases.pop_big() as usize] = weight_sum; - } - - // Prepare distributions for sampling. Creating them beforehand improves - // sampling performance. - let uniform_index = Uniform::new(0, n); - let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum); - - Ok(Self { - aliases: aliases.aliases, - no_alias_odds, - uniform_index, - uniform_within_weight_sum, - }) - } -} - -impl<W: Weight> Distribution<usize> for WeightedIndex<W> { - fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { - let candidate = rng.sample(self.uniform_index); - if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] { - candidate as usize - } else { - self.aliases[candidate as usize] as usize - } - } -} - -impl<W: Weight> fmt::Debug for WeightedIndex<W> -where - W: fmt::Debug, - Uniform<W>: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("WeightedIndex") - .field("aliases", &self.aliases) - .field("no_alias_odds", &self.no_alias_odds) - .field("uniform_index", &self.uniform_index) - .field("uniform_within_weight_sum", &self.uniform_within_weight_sum) - .finish() - } -} - -impl<W: Weight> Clone for WeightedIndex<W> -where - Uniform<W>: Clone, -{ - fn clone(&self) -> Self { - Self { - aliases: self.aliases.clone(), - no_alias_odds: self.no_alias_odds.clone(), - uniform_index: self.uniform_index.clone(), - uniform_within_weight_sum: self.uniform_within_weight_sum.clone(), - } - } -} - -/// Trait that must be implemented for weights, that are used with -/// [`WeightedIndex`]. Currently no guarantees on the correctness of -/// [`WeightedIndex`] are given for custom implementations of this trait. -pub trait Weight: - Sized - + Copy - + SampleUniform - + PartialOrd - + Add<Output = Self> - + AddAssign - + Sub<Output = Self> - + SubAssign - + Mul<Output = Self> - + MulAssign - + Div<Output = Self> - + DivAssign - + Sum -{ - /// Maximum number representable by `Self`. - const MAX: Self; - - /// Element of `Self` equivalent to 0. - const ZERO: Self; - - /// Produce an instance of `Self` from a `u32` value, or return `None` if - /// out of range. Loss of precision (where `Self` is a floating point type) - /// is acceptable. - fn try_from_u32_lossy(n: u32) -> Option<Self>; - - /// Sums all values in slice `values`. - fn sum(values: &[Self]) -> Self { - values.iter().map(|x| *x).sum() - } -} - -macro_rules! impl_weight_for_float { - ($T: ident) => { - impl Weight for $T { - const MAX: Self = ::core::$T::MAX; - const ZERO: Self = 0.0; - - fn try_from_u32_lossy(n: u32) -> Option<Self> { - Some(n as $T) - } - - fn sum(values: &[Self]) -> Self { - pairwise_sum(values) - } - } - }; -} - -/// In comparison to naive accumulation, the pairwise sum algorithm reduces -/// rounding errors when there are many floating point values. -fn pairwise_sum<T: Weight>(values: &[T]) -> T { - if values.len() <= 32 { - values.iter().map(|x| *x).sum() - } else { - let mid = values.len() / 2; - let (a, b) = values.split_at(mid); - pairwise_sum(a) + pairwise_sum(b) - } -} - -macro_rules! impl_weight_for_int { - ($T: ident) => { - impl Weight for $T { - const MAX: Self = ::core::$T::MAX; - const ZERO: Self = 0; - - fn try_from_u32_lossy(n: u32) -> Option<Self> { - let n_converted = n as Self; - if n_converted >= Self::ZERO && n_converted as u32 == n { - Some(n_converted) - } else { - None - } - } - } - }; -} - -impl_weight_for_float!(f64); -impl_weight_for_float!(f32); -impl_weight_for_int!(usize); -#[cfg(not(target_os = "emscripten"))] -impl_weight_for_int!(u128); -impl_weight_for_int!(u64); -impl_weight_for_int!(u32); -impl_weight_for_int!(u16); -impl_weight_for_int!(u8); -impl_weight_for_int!(isize); -#[cfg(not(target_os = "emscripten"))] -impl_weight_for_int!(i128); -impl_weight_for_int!(i64); -impl_weight_for_int!(i32); -impl_weight_for_int!(i16); -impl_weight_for_int!(i8); - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[cfg(not(miri))] // Miri is too slow - fn test_weighted_index_f32() { - test_weighted_index(f32::into); - - // Floating point special cases - assert_eq!( - WeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(vec![-0_f32]).unwrap_err(), - WeightedError::AllWeightsZero - ); - assert_eq!( - WeightedIndex::new(vec![-1_f32]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight - ); - } - - #[cfg(not(target_os = "emscripten"))] - #[test] - #[cfg(not(miri))] // Miri is too slow - fn test_weighted_index_u128() { - test_weighted_index(|x: u128| x as f64); - } - - #[cfg(all(rustc_1_26, not(target_os = "emscripten")))] - #[test] - #[cfg(not(miri))] // Miri is too slow - fn test_weighted_index_i128() { - test_weighted_index(|x: i128| x as f64); - - // Signed integer special cases - assert_eq!( - WeightedIndex::new(vec![-1_i128]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(), - WeightedError::InvalidWeight - ); - } - - #[test] - #[cfg(not(miri))] // Miri is too slow - fn test_weighted_index_u8() { - test_weighted_index(u8::into); - } - - #[test] - #[cfg(not(miri))] // Miri is too slow - fn test_weighted_index_i8() { - test_weighted_index(i8::into); - - // Signed integer special cases - assert_eq!( - WeightedIndex::new(vec![-1_i8]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(), - WeightedError::InvalidWeight - ); - } - - fn test_weighted_index<W: Weight, F: Fn(W) -> f64>(w_to_f64: F) - where - WeightedIndex<W>: fmt::Debug, - { - const NUM_WEIGHTS: u32 = 10; - const ZERO_WEIGHT_INDEX: u32 = 3; - const NUM_SAMPLES: u32 = 15000; - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - - let weights = { - let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize); - let random_weight_distribution = crate::distributions::Uniform::new_inclusive( - W::ZERO, - W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), - ); - for _ in 0..NUM_WEIGHTS { - weights.push(rng.sample(&random_weight_distribution)); - } - weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO; - weights - }; - let weight_sum = weights.iter().map(|w| *w).sum::<W>(); - let expected_counts = weights - .iter() - .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) - .collect::<Vec<f64>>(); - let weight_distribution = WeightedIndex::new(weights).unwrap(); - - let mut counts = vec![0; NUM_WEIGHTS as usize]; - for _ in 0..NUM_SAMPLES { - counts[rng.sample(&weight_distribution)] += 1; - } - - assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0); - for (count, expected_count) in counts.into_iter().zip(expected_counts) { - let difference = (count as f64 - expected_count).abs(); - let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; - assert!(difference <= max_allowed_difference); - } - - assert_eq!( - WeightedIndex::<W>::new(vec![]).unwrap_err(), - WeightedError::NoItem - ); - assert_eq!( - WeightedIndex::new(vec![W::ZERO]).unwrap_err(), - WeightedError::AllWeightsZero - ); - assert_eq!( - WeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), - WeightedError::InvalidWeight - ); - } -} diff --git a/rand/src/distributions/weighted/mod.rs b/rand/src/distributions/weighted/mod.rs deleted file mode 100644 index 2711637..0000000 --- a/rand/src/distributions/weighted/mod.rs +++ /dev/null @@ -1,363 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or -// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license -// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Weighted index sampling -//! -//! This module provides two implementations for sampling indices: -//! -//! * [`WeightedIndex`] allows `O(log N)` sampling -//! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with -//! much greater set-up cost -//! -//! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html - -pub mod alias_method; - -use crate::Rng; -use crate::distributions::Distribution; -use crate::distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow}; -use core::cmp::PartialOrd; -use core::fmt; - -// Note that this whole module is only imported if feature="alloc" is enabled. -#[cfg(not(feature="std"))] use crate::alloc::vec::Vec; - -/// A distribution using weighted sampling to pick a discretely selected -/// item. -/// -/// Sampling a `WeightedIndex` distribution returns the index of a randomly -/// selected element from the iterator used when the `WeightedIndex` was -/// created. The chance of a given element being picked is proportional to the -/// value of the element. The weights can use any type `X` for which an -/// implementation of [`Uniform<X>`] exists. -/// -/// # Performance -/// -/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its -/// size is the sum of the size of those objects, possibly plus some alignment. -/// -/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1` -/// weights of type `X`, where `N` is the number of weights. However, since -/// `Vec` doesn't guarantee a particular growth strategy, additional memory -/// might be allocated but not used. Since the `WeightedIndex` object also -/// contains, this might cause additional allocations, though for primitive -/// types, ['Uniform<X>`] doesn't allocate any memory. -/// -/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. -/// -/// Sampling from `WeightedIndex` will result in a single call to -/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically -/// will request a single value from the underlying [`RngCore`], though the -/// exact number depends on the implementaiton of `Uniform<X>::sample`. -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::WeightedIndex; -/// -/// let choices = ['a', 'b', 'c']; -/// let weights = [2, 1, 1]; -/// let dist = WeightedIndex::new(&weights).unwrap(); -/// let mut rng = thread_rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`Uniform<X>`]: crate::distributions::uniform::Uniform -/// [`RngCore`]: crate::RngCore -#[derive(Debug, Clone)] -pub struct WeightedIndex<X: SampleUniform + PartialOrd> { - cumulative_weights: Vec<X>, - total_weight: X, - weight_distribution: X::Sampler, -} - -impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { - /// Creates a new a `WeightedIndex` [`Distribution`] using the values - /// in `weights`. The weights can use any type `X` for which an - /// implementation of [`Uniform<X>`] exists. - /// - /// Returns an error if the iterator is empty, if any weight is `< 0`, or - /// if its total value is 0. - /// - /// [`Uniform<X>`]: crate::distributions::uniform::Uniform - pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> - where I: IntoIterator, - I::Item: SampleBorrow<X>, - X: for<'a> ::core::ops::AddAssign<&'a X> + - Clone + - Default { - let mut iter = weights.into_iter(); - let mut total_weight: X = iter.next() - .ok_or(WeightedError::NoItem)? - .borrow() - .clone(); - - let zero = <X as Default>::default(); - if total_weight < zero { - return Err(WeightedError::InvalidWeight); - } - - let mut weights = Vec::<X>::with_capacity(iter.size_hint().0); - for w in iter { - if *w.borrow() < zero { - return Err(WeightedError::InvalidWeight); - } - weights.push(total_weight.clone()); - total_weight += w.borrow(); - } - - if total_weight == zero { - return Err(WeightedError::AllWeightsZero); - } - let distr = X::Sampler::new(zero, total_weight.clone()); - - Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr }) - } - - /// Update a subset of weights, without changing the number of weights. - /// - /// `new_weights` must be sorted by the index. - /// - /// Using this method instead of `new` might be more efficient if only a small number of - /// weights is modified. No allocations are performed, unless the weight type `X` uses - /// allocation internally. - /// - /// In case of error, `self` is not modified. - pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> - where X: for<'a> ::core::ops::AddAssign<&'a X> + - for<'a> ::core::ops::SubAssign<&'a X> + - Clone + - Default { - if new_weights.is_empty() { - return Ok(()); - } - - let zero = <X as Default>::default(); - - let mut total_weight = self.total_weight.clone(); - - // Check for errors first, so we don't modify `self` in case something - // goes wrong. - let mut prev_i = None; - for &(i, w) in new_weights { - if let Some(old_i) = prev_i { - if old_i >= i { - return Err(WeightedError::InvalidWeight); - } - } - if *w < zero { - return Err(WeightedError::InvalidWeight); - } - if i >= self.cumulative_weights.len() + 1 { - return Err(WeightedError::TooMany); - } - - let mut old_w = if i < self.cumulative_weights.len() { - self.cumulative_weights[i].clone() - } else { - self.total_weight.clone() - }; - if i > 0 { - old_w -= &self.cumulative_weights[i - 1]; - } - - total_weight -= &old_w; - total_weight += w; - prev_i = Some(i); - } - if total_weight == zero { - return Err(WeightedError::AllWeightsZero); - } - - // Update the weights. Because we checked all the preconditions in the - // previous loop, this should never panic. - let mut iter = new_weights.iter(); - - let mut prev_weight = zero.clone(); - let mut next_new_weight = iter.next(); - let &(first_new_index, _) = next_new_weight.unwrap(); - let mut cumulative_weight = if first_new_index > 0 { - self.cumulative_weights[first_new_index - 1].clone() - } else { - zero.clone() - }; - for i in first_new_index..self.cumulative_weights.len() { - match next_new_weight { - Some(&(j, w)) if i == j => { - cumulative_weight += w; - next_new_weight = iter.next(); - }, - _ => { - let mut tmp = self.cumulative_weights[i].clone(); - tmp -= &prev_weight; // We know this is positive. - cumulative_weight += &tmp; - } - } - prev_weight = cumulative_weight.clone(); - core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); - } - - self.total_weight = total_weight; - self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); - - Ok(()) - } -} - -impl<X> Distribution<usize> for WeightedIndex<X> where - X: SampleUniform + PartialOrd { - fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { - use ::core::cmp::Ordering; - let chosen_weight = self.weight_distribution.sample(rng); - // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights.binary_search_by( - |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[cfg(not(miri))] // Miri is too slow - fn test_weightedindex() { - let mut r = crate::test::rng(700); - const N_REPS: u32 = 5000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::<u32>() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // WeightedIndex from vec - let mut chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from slice - chosen = [0i32; 14]; - let distr = WeightedIndex::new(&weights[..]).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from iterator - chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.iter()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - for _ in 0..5 { - assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); - assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); - assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4); - } - - assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem); - assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero); - assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::InvalidWeight); - assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight); - assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight); - } - - #[test] - fn test_update_weights() { - let data = [ - (&[10u32, 2, 3, 4][..], - &[(1, &100), (2, &4)][..], // positive change - &[10, 100, 4, 4][..]), - (&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], - &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element - &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]), - ]; - - for (weights, update, expected_weights) in data.into_iter() { - let total_weight = weights.iter().sum::<u32>(); - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, total_weight); - - distr.update_weights(update).unwrap(); - let expected_total_weight = expected_weights.iter().sum::<u32>(); - let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, expected_total_weight); - assert_eq!(distr.total_weight, expected_distr.total_weight); - assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); - } - } -} - -/// Error type returned from `WeightedIndex::new`. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WeightedError { - /// The provided weight collection contains no items. - NoItem, - - /// A weight is either less than zero, greater than the supported maximum or - /// otherwise invalid. - InvalidWeight, - - /// All items in the provided weight collection are zero. - AllWeightsZero, - - /// Too many weights are provided (length greater than `u32::MAX`) - TooMany, -} - -impl WeightedError { - fn msg(&self) -> &str { - match *self { - WeightedError::NoItem => "No weights provided.", - WeightedError::InvalidWeight => "A weight is invalid.", - WeightedError::AllWeightsZero => "All weights are zero.", - WeightedError::TooMany => "Too many weights (hit u32::MAX)", - } - } -} - -#[cfg(feature="std")] -impl ::std::error::Error for WeightedError { - fn description(&self) -> &str { - self.msg() - } - fn cause(&self) -> Option<&dyn (::std::error::Error)> { - None - } -} - -impl fmt::Display for WeightedError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.msg()) - } -} |