summaryrefslogtreecommitdiff
path: root/rand/src/distributions/weighted
diff options
context:
space:
mode:
Diffstat (limited to 'rand/src/distributions/weighted')
-rw-r--r--rand/src/distributions/weighted/alias_method.rs499
-rw-r--r--rand/src/distributions/weighted/mod.rs363
2 files changed, 862 insertions, 0 deletions
diff --git a/rand/src/distributions/weighted/alias_method.rs b/rand/src/distributions/weighted/alias_method.rs
new file mode 100644
index 0000000..bdd4ba0
--- /dev/null
+++ b/rand/src/distributions/weighted/alias_method.rs
@@ -0,0 +1,499 @@
+//! This module contains an implementation of alias method for sampling random
+//! indices with probabilities proportional to a collection of weights.
+
+use super::WeightedError;
+#[cfg(not(feature = "std"))]
+use crate::alloc::vec::Vec;
+#[cfg(not(feature = "std"))]
+use crate::alloc::vec;
+use core::fmt;
+use core::iter::Sum;
+use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
+use crate::distributions::uniform::SampleUniform;
+use crate::distributions::Distribution;
+use crate::distributions::Uniform;
+use crate::Rng;
+
+/// A distribution using weighted sampling to pick a discretely selected item.
+///
+/// Sampling a [`WeightedIndex<W>`] distribution returns the index of a randomly
+/// selected element from the vector used to create the [`WeightedIndex<W>`].
+/// The chance of a given element being picked is proportional to the value of
+/// the element. The weights can have any type `W` for which a implementation of
+/// [`Weight`] exists.
+///
+/// # Performance
+///
+/// Given that `n` is the number of items in the vector used to create an
+/// [`WeightedIndex<W>`], [`WeightedIndex<W>`] will require `O(n)` amount of
+/// memory. More specifically it takes up some constant amount of memory plus
+/// the vector used to create it and a [`Vec<u32>`] with capacity `n`.
+///
+/// Time complexity for the creation of a [`WeightedIndex<W>`] is `O(n)`.
+/// Sampling is `O(1)`, it makes a call to [`Uniform<u32>::sample`] and a call
+/// to [`Uniform<W>::sample`].
+///
+/// # Example
+///
+/// ```
+/// use rand::distributions::weighted::alias_method::WeightedIndex;
+/// use rand::prelude::*;
+///
+/// let choices = vec!['a', 'b', 'c'];
+/// let weights = vec![2, 1, 1];
+/// let dist = WeightedIndex::new(weights).unwrap();
+/// let mut rng = thread_rng();
+/// for _ in 0..100 {
+/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
+/// println!("{}", choices[dist.sample(&mut rng)]);
+/// }
+///
+/// let items = [('a', 0), ('b', 3), ('c', 7)];
+/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1).collect()).unwrap();
+/// for _ in 0..100 {
+/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
+/// println!("{}", items[dist2.sample(&mut rng)].0);
+/// }
+/// ```
+///
+/// [`WeightedIndex<W>`]: crate::distributions::weighted::alias_method::WeightedIndex
+/// [`Weight`]: crate::distributions::weighted::alias_method::Weight
+/// [`Vec<u32>`]: Vec
+/// [`Uniform<u32>::sample`]: Distribution::sample
+/// [`Uniform<W>::sample`]: Distribution::sample
+pub struct WeightedIndex<W: Weight> {
+ aliases: Vec<u32>,
+ no_alias_odds: Vec<W>,
+ uniform_index: Uniform<u32>,
+ uniform_within_weight_sum: Uniform<W>,
+}
+
+impl<W: Weight> WeightedIndex<W> {
+ /// Creates a new [`WeightedIndex`].
+ ///
+ /// Returns an error if:
+ /// - The vector is empty.
+ /// - The vector is longer than `u32::MAX`.
+ /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX /
+ /// weights.len()`.
+ /// - The sum of weights is zero.
+ pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
+ let n = weights.len();
+ if n == 0 {
+ return Err(WeightedError::NoItem);
+ } else if n > ::core::u32::MAX as usize {
+ return Err(WeightedError::TooMany);
+ }
+ let n = n as u32;
+
+ let max_weight_size = W::try_from_u32_lossy(n)
+ .map(|n| W::MAX / n)
+ .unwrap_or(W::ZERO);
+ if !weights
+ .iter()
+ .all(|&w| W::ZERO <= w && w <= max_weight_size)
+ {
+ return Err(WeightedError::InvalidWeight);
+ }
+
+ // The sum of weights will represent 100% of no alias odds.
+ let weight_sum = Weight::sum(weights.as_slice());
+ // Prevent floating point overflow due to rounding errors.
+ let weight_sum = if weight_sum > W::MAX {
+ W::MAX
+ } else {
+ weight_sum
+ };
+ if weight_sum == W::ZERO {
+ return Err(WeightedError::AllWeightsZero);
+ }
+
+ // `weight_sum` would have been zero if `try_from_lossy` causes an error here.
+ let n_converted = W::try_from_u32_lossy(n).unwrap();
+
+ let mut no_alias_odds = weights;
+ for odds in no_alias_odds.iter_mut() {
+ *odds *= n_converted;
+ // Prevent floating point overflow due to rounding errors.
+ *odds = if *odds > W::MAX { W::MAX } else { *odds };
+ }
+
+ /// This struct is designed to contain three data structures at once,
+ /// sharing the same memory. More precisely it contains two linked lists
+ /// and an alias map, which will be the output of this method. To keep
+ /// the three data structures from getting in each other's way, it must
+ /// be ensured that a single index is only ever in one of them at the
+ /// same time.
+ struct Aliases {
+ aliases: Vec<u32>,
+ smalls_head: u32,
+ bigs_head: u32,
+ }
+
+ impl Aliases {
+ fn new(size: u32) -> Self {
+ Aliases {
+ aliases: vec![0; size as usize],
+ smalls_head: ::core::u32::MAX,
+ bigs_head: ::core::u32::MAX,
+ }
+ }
+
+ fn push_small(&mut self, idx: u32) {
+ self.aliases[idx as usize] = self.smalls_head;
+ self.smalls_head = idx;
+ }
+
+ fn push_big(&mut self, idx: u32) {
+ self.aliases[idx as usize] = self.bigs_head;
+ self.bigs_head = idx;
+ }
+
+ fn pop_small(&mut self) -> u32 {
+ let popped = self.smalls_head;
+ self.smalls_head = self.aliases[popped as usize];
+ popped
+ }
+
+ fn pop_big(&mut self) -> u32 {
+ let popped = self.bigs_head;
+ self.bigs_head = self.aliases[popped as usize];
+ popped
+ }
+
+ fn smalls_is_empty(&self) -> bool {
+ self.smalls_head == ::core::u32::MAX
+ }
+
+ fn bigs_is_empty(&self) -> bool {
+ self.bigs_head == ::core::u32::MAX
+ }
+
+ fn set_alias(&mut self, idx: u32, alias: u32) {
+ self.aliases[idx as usize] = alias;
+ }
+ }
+
+ let mut aliases = Aliases::new(n);
+
+ // Split indices into those with small weights and those with big weights.
+ for (index, &odds) in no_alias_odds.iter().enumerate() {
+ if odds < weight_sum {
+ aliases.push_small(index as u32);
+ } else {
+ aliases.push_big(index as u32);
+ }
+ }
+
+ // Build the alias map by finding an alias with big weight for each index with
+ // small weight.
+ while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() {
+ let s = aliases.pop_small();
+ let b = aliases.pop_big();
+
+ aliases.set_alias(s, b);
+ no_alias_odds[b as usize] = no_alias_odds[b as usize]
+ - weight_sum
+ + no_alias_odds[s as usize];
+
+ if no_alias_odds[b as usize] < weight_sum {
+ aliases.push_small(b);
+ } else {
+ aliases.push_big(b);
+ }
+ }
+
+ // The remaining indices should have no alias odds of about 100%. This is due to
+ // numeric accuracy. Otherwise they would be exactly 100%.
+ while !aliases.smalls_is_empty() {
+ no_alias_odds[aliases.pop_small() as usize] = weight_sum;
+ }
+ while !aliases.bigs_is_empty() {
+ no_alias_odds[aliases.pop_big() as usize] = weight_sum;
+ }
+
+ // Prepare distributions for sampling. Creating them beforehand improves
+ // sampling performance.
+ let uniform_index = Uniform::new(0, n);
+ let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum);
+
+ Ok(Self {
+ aliases: aliases.aliases,
+ no_alias_odds,
+ uniform_index,
+ uniform_within_weight_sum,
+ })
+ }
+}
+
+impl<W: Weight> Distribution<usize> for WeightedIndex<W> {
+ fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
+ let candidate = rng.sample(self.uniform_index);
+ if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
+ candidate as usize
+ } else {
+ self.aliases[candidate as usize] as usize
+ }
+ }
+}
+
+impl<W: Weight> fmt::Debug for WeightedIndex<W>
+where
+ W: fmt::Debug,
+ Uniform<W>: fmt::Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.debug_struct("WeightedIndex")
+ .field("aliases", &self.aliases)
+ .field("no_alias_odds", &self.no_alias_odds)
+ .field("uniform_index", &self.uniform_index)
+ .field("uniform_within_weight_sum", &self.uniform_within_weight_sum)
+ .finish()
+ }
+}
+
+impl<W: Weight> Clone for WeightedIndex<W>
+where
+ Uniform<W>: Clone,
+{
+ fn clone(&self) -> Self {
+ Self {
+ aliases: self.aliases.clone(),
+ no_alias_odds: self.no_alias_odds.clone(),
+ uniform_index: self.uniform_index.clone(),
+ uniform_within_weight_sum: self.uniform_within_weight_sum.clone(),
+ }
+ }
+}
+
+/// Trait that must be implemented for weights, that are used with
+/// [`WeightedIndex`]. Currently no guarantees on the correctness of
+/// [`WeightedIndex`] are given for custom implementations of this trait.
+pub trait Weight:
+ Sized
+ + Copy
+ + SampleUniform
+ + PartialOrd
+ + Add<Output = Self>
+ + AddAssign
+ + Sub<Output = Self>
+ + SubAssign
+ + Mul<Output = Self>
+ + MulAssign
+ + Div<Output = Self>
+ + DivAssign
+ + Sum
+{
+ /// Maximum number representable by `Self`.
+ const MAX: Self;
+
+ /// Element of `Self` equivalent to 0.
+ const ZERO: Self;
+
+ /// Produce an instance of `Self` from a `u32` value, or return `None` if
+ /// out of range. Loss of precision (where `Self` is a floating point type)
+ /// is acceptable.
+ fn try_from_u32_lossy(n: u32) -> Option<Self>;
+
+ /// Sums all values in slice `values`.
+ fn sum(values: &[Self]) -> Self {
+ values.iter().map(|x| *x).sum()
+ }
+}
+
+macro_rules! impl_weight_for_float {
+ ($T: ident) => {
+ impl Weight for $T {
+ const MAX: Self = ::core::$T::MAX;
+ const ZERO: Self = 0.0;
+
+ fn try_from_u32_lossy(n: u32) -> Option<Self> {
+ Some(n as $T)
+ }
+
+ fn sum(values: &[Self]) -> Self {
+ pairwise_sum(values)
+ }
+ }
+ };
+}
+
+/// In comparison to naive accumulation, the pairwise sum algorithm reduces
+/// rounding errors when there are many floating point values.
+fn pairwise_sum<T: Weight>(values: &[T]) -> T {
+ if values.len() <= 32 {
+ values.iter().map(|x| *x).sum()
+ } else {
+ let mid = values.len() / 2;
+ let (a, b) = values.split_at(mid);
+ pairwise_sum(a) + pairwise_sum(b)
+ }
+}
+
+macro_rules! impl_weight_for_int {
+ ($T: ident) => {
+ impl Weight for $T {
+ const MAX: Self = ::core::$T::MAX;
+ const ZERO: Self = 0;
+
+ fn try_from_u32_lossy(n: u32) -> Option<Self> {
+ let n_converted = n as Self;
+ if n_converted >= Self::ZERO && n_converted as u32 == n {
+ Some(n_converted)
+ } else {
+ None
+ }
+ }
+ }
+ };
+}
+
+impl_weight_for_float!(f64);
+impl_weight_for_float!(f32);
+impl_weight_for_int!(usize);
+#[cfg(not(target_os = "emscripten"))]
+impl_weight_for_int!(u128);
+impl_weight_for_int!(u64);
+impl_weight_for_int!(u32);
+impl_weight_for_int!(u16);
+impl_weight_for_int!(u8);
+impl_weight_for_int!(isize);
+#[cfg(not(target_os = "emscripten"))]
+impl_weight_for_int!(i128);
+impl_weight_for_int!(i64);
+impl_weight_for_int!(i32);
+impl_weight_for_int!(i16);
+impl_weight_for_int!(i8);
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ #[cfg(not(miri))] // Miri is too slow
+ fn test_weighted_index_f32() {
+ test_weighted_index(f32::into);
+
+ // Floating point special cases
+ assert_eq!(
+ WeightedIndex::new(vec![::core::f32::INFINITY]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ assert_eq!(
+ WeightedIndex::new(vec![-0_f32]).unwrap_err(),
+ WeightedError::AllWeightsZero
+ );
+ assert_eq!(
+ WeightedIndex::new(vec![-1_f32]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ assert_eq!(
+ WeightedIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ assert_eq!(
+ WeightedIndex::new(vec![::core::f32::NAN]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ }
+
+ #[cfg(not(target_os = "emscripten"))]
+ #[test]
+ #[cfg(not(miri))] // Miri is too slow
+ fn test_weighted_index_u128() {
+ test_weighted_index(|x: u128| x as f64);
+ }
+
+ #[cfg(all(rustc_1_26, not(target_os = "emscripten")))]
+ #[test]
+ #[cfg(not(miri))] // Miri is too slow
+ fn test_weighted_index_i128() {
+ test_weighted_index(|x: i128| x as f64);
+
+ // Signed integer special cases
+ assert_eq!(
+ WeightedIndex::new(vec![-1_i128]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ assert_eq!(
+ WeightedIndex::new(vec![::core::i128::MIN]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ }
+
+ #[test]
+ #[cfg(not(miri))] // Miri is too slow
+ fn test_weighted_index_u8() {
+ test_weighted_index(u8::into);
+ }
+
+ #[test]
+ #[cfg(not(miri))] // Miri is too slow
+ fn test_weighted_index_i8() {
+ test_weighted_index(i8::into);
+
+ // Signed integer special cases
+ assert_eq!(
+ WeightedIndex::new(vec![-1_i8]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ assert_eq!(
+ WeightedIndex::new(vec![::core::i8::MIN]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ }
+
+ fn test_weighted_index<W: Weight, F: Fn(W) -> f64>(w_to_f64: F)
+ where
+ WeightedIndex<W>: fmt::Debug,
+ {
+ const NUM_WEIGHTS: u32 = 10;
+ const ZERO_WEIGHT_INDEX: u32 = 3;
+ const NUM_SAMPLES: u32 = 15000;
+ let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
+
+ let weights = {
+ let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
+ let random_weight_distribution = crate::distributions::Uniform::new_inclusive(
+ W::ZERO,
+ W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
+ );
+ for _ in 0..NUM_WEIGHTS {
+ weights.push(rng.sample(&random_weight_distribution));
+ }
+ weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
+ weights
+ };
+ let weight_sum = weights.iter().map(|w| *w).sum::<W>();
+ let expected_counts = weights
+ .iter()
+ .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64)
+ .collect::<Vec<f64>>();
+ let weight_distribution = WeightedIndex::new(weights).unwrap();
+
+ let mut counts = vec![0; NUM_WEIGHTS as usize];
+ for _ in 0..NUM_SAMPLES {
+ counts[rng.sample(&weight_distribution)] += 1;
+ }
+
+ assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
+ for (count, expected_count) in counts.into_iter().zip(expected_counts) {
+ let difference = (count as f64 - expected_count).abs();
+ let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;
+ assert!(difference <= max_allowed_difference);
+ }
+
+ assert_eq!(
+ WeightedIndex::<W>::new(vec![]).unwrap_err(),
+ WeightedError::NoItem
+ );
+ assert_eq!(
+ WeightedIndex::new(vec![W::ZERO]).unwrap_err(),
+ WeightedError::AllWeightsZero
+ );
+ assert_eq!(
+ WeightedIndex::new(vec![W::MAX, W::MAX]).unwrap_err(),
+ WeightedError::InvalidWeight
+ );
+ }
+}
diff --git a/rand/src/distributions/weighted/mod.rs b/rand/src/distributions/weighted/mod.rs
new file mode 100644
index 0000000..2711637
--- /dev/null
+++ b/rand/src/distributions/weighted/mod.rs
@@ -0,0 +1,363 @@
+// Copyright 2018 Developers of the Rand project.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+//! Weighted index sampling
+//!
+//! This module provides two implementations for sampling indices:
+//!
+//! * [`WeightedIndex`] allows `O(log N)` sampling
+//! * [`alias_method::WeightedIndex`] allows `O(1)` sampling, but with
+//! much greater set-up cost
+//!
+//! [`alias_method::WeightedIndex`]: alias_method/struct.WeightedIndex.html
+
+pub mod alias_method;
+
+use crate::Rng;
+use crate::distributions::Distribution;
+use crate::distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
+use core::cmp::PartialOrd;
+use core::fmt;
+
+// Note that this whole module is only imported if feature="alloc" is enabled.
+#[cfg(not(feature="std"))] use crate::alloc::vec::Vec;
+
+/// A distribution using weighted sampling to pick a discretely selected
+/// item.
+///
+/// Sampling a `WeightedIndex` distribution returns the index of a randomly
+/// selected element from the iterator used when the `WeightedIndex` was
+/// created. The chance of a given element being picked is proportional to the
+/// value of the element. The weights can use any type `X` for which an
+/// implementation of [`Uniform<X>`] exists.
+///
+/// # Performance
+///
+/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
+/// size is the sum of the size of those objects, possibly plus some alignment.
+///
+/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
+/// weights of type `X`, where `N` is the number of weights. However, since
+/// `Vec` doesn't guarantee a particular growth strategy, additional memory
+/// might be allocated but not used. Since the `WeightedIndex` object also
+/// contains, this might cause additional allocations, though for primitive
+/// types, ['Uniform<X>`] doesn't allocate any memory.
+///
+/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
+/// `N` is the number of weights.
+///
+/// Sampling from `WeightedIndex` will result in a single call to
+/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
+/// will request a single value from the underlying [`RngCore`], though the
+/// exact number depends on the implementaiton of `Uniform<X>::sample`.
+///
+/// # Example
+///
+/// ```
+/// use rand::prelude::*;
+/// use rand::distributions::WeightedIndex;
+///
+/// let choices = ['a', 'b', 'c'];
+/// let weights = [2, 1, 1];
+/// let dist = WeightedIndex::new(&weights).unwrap();
+/// let mut rng = thread_rng();
+/// for _ in 0..100 {
+/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
+/// println!("{}", choices[dist.sample(&mut rng)]);
+/// }
+///
+/// let items = [('a', 0), ('b', 3), ('c', 7)];
+/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
+/// for _ in 0..100 {
+/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
+/// println!("{}", items[dist2.sample(&mut rng)].0);
+/// }
+/// ```
+///
+/// [`Uniform<X>`]: crate::distributions::uniform::Uniform
+/// [`RngCore`]: crate::RngCore
+#[derive(Debug, Clone)]
+pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
+ cumulative_weights: Vec<X>,
+ total_weight: X,
+ weight_distribution: X::Sampler,
+}
+
+impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
+ /// Creates a new a `WeightedIndex` [`Distribution`] using the values
+ /// in `weights`. The weights can use any type `X` for which an
+ /// implementation of [`Uniform<X>`] exists.
+ ///
+ /// Returns an error if the iterator is empty, if any weight is `< 0`, or
+ /// if its total value is 0.
+ ///
+ /// [`Uniform<X>`]: crate::distributions::uniform::Uniform
+ pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
+ where I: IntoIterator,
+ I::Item: SampleBorrow<X>,
+ X: for<'a> ::core::ops::AddAssign<&'a X> +
+ Clone +
+ Default {
+ let mut iter = weights.into_iter();
+ let mut total_weight: X = iter.next()
+ .ok_or(WeightedError::NoItem)?
+ .borrow()
+ .clone();
+
+ let zero = <X as Default>::default();
+ if total_weight < zero {
+ return Err(WeightedError::InvalidWeight);
+ }
+
+ let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
+ for w in iter {
+ if *w.borrow() < zero {
+ return Err(WeightedError::InvalidWeight);
+ }
+ weights.push(total_weight.clone());
+ total_weight += w.borrow();
+ }
+
+ if total_weight == zero {
+ return Err(WeightedError::AllWeightsZero);
+ }
+ let distr = X::Sampler::new(zero, total_weight.clone());
+
+ Ok(WeightedIndex { cumulative_weights: weights, total_weight, weight_distribution: distr })
+ }
+
+ /// Update a subset of weights, without changing the number of weights.
+ ///
+ /// `new_weights` must be sorted by the index.
+ ///
+ /// Using this method instead of `new` might be more efficient if only a small number of
+ /// weights is modified. No allocations are performed, unless the weight type `X` uses
+ /// allocation internally.
+ ///
+ /// In case of error, `self` is not modified.
+ pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
+ where X: for<'a> ::core::ops::AddAssign<&'a X> +
+ for<'a> ::core::ops::SubAssign<&'a X> +
+ Clone +
+ Default {
+ if new_weights.is_empty() {
+ return Ok(());
+ }
+
+ let zero = <X as Default>::default();
+
+ let mut total_weight = self.total_weight.clone();
+
+ // Check for errors first, so we don't modify `self` in case something
+ // goes wrong.
+ let mut prev_i = None;
+ for &(i, w) in new_weights {
+ if let Some(old_i) = prev_i {
+ if old_i >= i {
+ return Err(WeightedError::InvalidWeight);
+ }
+ }
+ if *w < zero {
+ return Err(WeightedError::InvalidWeight);
+ }
+ if i >= self.cumulative_weights.len() + 1 {
+ return Err(WeightedError::TooMany);
+ }
+
+ let mut old_w = if i < self.cumulative_weights.len() {
+ self.cumulative_weights[i].clone()
+ } else {
+ self.total_weight.clone()
+ };
+ if i > 0 {
+ old_w -= &self.cumulative_weights[i - 1];
+ }
+
+ total_weight -= &old_w;
+ total_weight += w;
+ prev_i = Some(i);
+ }
+ if total_weight == zero {
+ return Err(WeightedError::AllWeightsZero);
+ }
+
+ // Update the weights. Because we checked all the preconditions in the
+ // previous loop, this should never panic.
+ let mut iter = new_weights.iter();
+
+ let mut prev_weight = zero.clone();
+ let mut next_new_weight = iter.next();
+ let &(first_new_index, _) = next_new_weight.unwrap();
+ let mut cumulative_weight = if first_new_index > 0 {
+ self.cumulative_weights[first_new_index - 1].clone()
+ } else {
+ zero.clone()
+ };
+ for i in first_new_index..self.cumulative_weights.len() {
+ match next_new_weight {
+ Some(&(j, w)) if i == j => {
+ cumulative_weight += w;
+ next_new_weight = iter.next();
+ },
+ _ => {
+ let mut tmp = self.cumulative_weights[i].clone();
+ tmp -= &prev_weight; // We know this is positive.
+ cumulative_weight += &tmp;
+ }
+ }
+ prev_weight = cumulative_weight.clone();
+ core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
+ }
+
+ self.total_weight = total_weight;
+ self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());
+
+ Ok(())
+ }
+}
+
+impl<X> Distribution<usize> for WeightedIndex<X> where
+ X: SampleUniform + PartialOrd {
+ fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
+ use ::core::cmp::Ordering;
+ let chosen_weight = self.weight_distribution.sample(rng);
+ // Find the first item which has a weight *higher* than the chosen weight.
+ self.cumulative_weights.binary_search_by(
+ |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ #[cfg(not(miri))] // Miri is too slow
+ fn test_weightedindex() {
+ let mut r = crate::test::rng(700);
+ const N_REPS: u32 = 5000;
+ let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
+ let total_weight = weights.iter().sum::<u32>() as f32;
+
+ let verify = |result: [i32; 14]| {
+ for (i, count) in result.iter().enumerate() {
+ let exp = (weights[i] * N_REPS) as f32 / total_weight;
+ let mut err = (*count as f32 - exp).abs();
+ if err != 0.0 {
+ err /= exp;
+ }
+ assert!(err <= 0.25);
+ }
+ };
+
+ // WeightedIndex from vec
+ let mut chosen = [0i32; 14];
+ let distr = WeightedIndex::new(weights.to_vec()).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ // WeightedIndex from slice
+ chosen = [0i32; 14];
+ let distr = WeightedIndex::new(&weights[..]).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ // WeightedIndex from iterator
+ chosen = [0i32; 14];
+ let distr = WeightedIndex::new(weights.iter()).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ for _ in 0..5 {
+ assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
+ assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
+ assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4);
+ }
+
+ assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem);
+ assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero);
+ assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::InvalidWeight);
+ assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::InvalidWeight);
+ assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight);
+ }
+
+ #[test]
+ fn test_update_weights() {
+ let data = [
+ (&[10u32, 2, 3, 4][..],
+ &[(1, &100), (2, &4)][..], // positive change
+ &[10, 100, 4, 4][..]),
+ (&[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
+ &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
+ &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..]),
+ ];
+
+ for (weights, update, expected_weights) in data.into_iter() {
+ let total_weight = weights.iter().sum::<u32>();
+ let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
+ assert_eq!(distr.total_weight, total_weight);
+
+ distr.update_weights(update).unwrap();
+ let expected_total_weight = expected_weights.iter().sum::<u32>();
+ let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
+ assert_eq!(distr.total_weight, expected_total_weight);
+ assert_eq!(distr.total_weight, expected_distr.total_weight);
+ assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
+ }
+ }
+}
+
+/// Error type returned from `WeightedIndex::new`.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum WeightedError {
+ /// The provided weight collection contains no items.
+ NoItem,
+
+ /// A weight is either less than zero, greater than the supported maximum or
+ /// otherwise invalid.
+ InvalidWeight,
+
+ /// All items in the provided weight collection are zero.
+ AllWeightsZero,
+
+ /// Too many weights are provided (length greater than `u32::MAX`)
+ TooMany,
+}
+
+impl WeightedError {
+ fn msg(&self) -> &str {
+ match *self {
+ WeightedError::NoItem => "No weights provided.",
+ WeightedError::InvalidWeight => "A weight is invalid.",
+ WeightedError::AllWeightsZero => "All weights are zero.",
+ WeightedError::TooMany => "Too many weights (hit u32::MAX)",
+ }
+ }
+}
+
+#[cfg(feature="std")]
+impl ::std::error::Error for WeightedError {
+ fn description(&self) -> &str {
+ self.msg()
+ }
+ fn cause(&self) -> Option<&dyn (::std::error::Error)> {
+ None
+ }
+}
+
+impl fmt::Display for WeightedError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{}", self.msg())
+ }
+}