summaryrefslogtreecommitdiff
path: root/rand/src/distributions/weighted.rs
diff options
context:
space:
mode:
authorDaniel Mueller <deso@posteo.net>2019-01-02 21:14:10 -0800
committerDaniel Mueller <deso@posteo.net>2019-01-02 21:14:10 -0800
commitecf3474223ca3d16a10f12dc2272e3b0ed72c1bb (patch)
tree03134a683791176b49ef5c92e8d6acd24c3b5a9b /rand/src/distributions/weighted.rs
parent686f61b75055ecb02baf9d9449525ae447a3bed1 (diff)
downloadnitrocli-ecf3474223ca3d16a10f12dc2272e3b0ed72c1bb.tar.gz
nitrocli-ecf3474223ca3d16a10f12dc2272e3b0ed72c1bb.tar.bz2
Update nitrokey crate to 0.2.3
This change updates the nitrokey crate to version 0.2.3. This version bumps the rand crate used to 0.6.1, which in turn requires an additional set of dependencies. Import subrepo nitrokey/:nitrokey at b3e2adc5bb1300441ca74cc7672617c042f3ea31 Import subrepo rand/:rand at 73613ff903512e9503e41cc8ba9eae76269dc598 Import subrepo rustc_version/:rustc_version at 0294f2ba2018bf7be672abd53db351ce5055fa02 Import subrepo semver-parser/:semver-parser at 750da9b11a04125231b1fb293866ca036845acee Import subrepo semver/:semver at 5eb6db94fa03f4d5c64a625a56188f496be47598
Diffstat (limited to 'rand/src/distributions/weighted.rs')
-rw-r--r--rand/src/distributions/weighted.rs232
1 files changed, 232 insertions, 0 deletions
diff --git a/rand/src/distributions/weighted.rs b/rand/src/distributions/weighted.rs
new file mode 100644
index 0000000..01c8fe6
--- /dev/null
+++ b/rand/src/distributions/weighted.rs
@@ -0,0 +1,232 @@
+// Copyright 2018 Developers of the Rand project.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use Rng;
+use distributions::Distribution;
+use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
+use ::core::cmp::PartialOrd;
+use core::fmt;
+
+// Note that this whole module is only imported if feature="alloc" is enabled.
+#[cfg(not(feature="std"))] use alloc::vec::Vec;
+
+/// A distribution using weighted sampling to pick a discretely selected
+/// item.
+///
+/// Sampling a `WeightedIndex` distribution returns the index of a randomly
+/// selected element from the iterator used when the `WeightedIndex` was
+/// created. The chance of a given element being picked is proportional to the
+/// value of the element. The weights can use any type `X` for which an
+/// implementation of [`Uniform<X>`] exists.
+///
+/// # Performance
+///
+/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
+/// size is the sum of the size of those objects, possibly plus some alignment.
+///
+/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
+/// weights of type `X`, where `N` is the number of weights. However, since
+/// `Vec` doesn't guarantee a particular growth strategy, additional memory
+/// might be allocated but not used. Since the `WeightedIndex` object also
+/// contains, this might cause additional allocations, though for primitive
+/// types, ['Uniform<X>`] doesn't allocate any memory.
+///
+/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
+/// `N` is the number of weights.
+///
+/// Sampling from `WeightedIndex` will result in a single call to
+/// [`Uniform<X>::sample`], which typically will request a single value from
+/// the underlying [`RngCore`], though the exact number depends on the
+/// implementaiton of [`Uniform<X>::sample`].
+///
+/// # Example
+///
+/// ```
+/// use rand::prelude::*;
+/// use rand::distributions::WeightedIndex;
+///
+/// let choices = ['a', 'b', 'c'];
+/// let weights = [2, 1, 1];
+/// let dist = WeightedIndex::new(&weights).unwrap();
+/// let mut rng = thread_rng();
+/// for _ in 0..100 {
+/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
+/// println!("{}", choices[dist.sample(&mut rng)]);
+/// }
+///
+/// let items = [('a', 0), ('b', 3), ('c', 7)];
+/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
+/// for _ in 0..100 {
+/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
+/// println!("{}", items[dist2.sample(&mut rng)].0);
+/// }
+/// ```
+///
+/// [`Uniform<X>`]: struct.Uniform.html
+/// [`Uniform<X>::sample`]: struct.Uniform.html#method.sample
+/// [`RngCore`]: ../trait.RngCore.html
+#[derive(Debug, Clone)]
+pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
+ cumulative_weights: Vec<X>,
+ weight_distribution: X::Sampler,
+}
+
+impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
+ /// Creates a new a `WeightedIndex` [`Distribution`] using the values
+ /// in `weights`. The weights can use any type `X` for which an
+ /// implementation of [`Uniform<X>`] exists.
+ ///
+ /// Returns an error if the iterator is empty, if any weight is `< 0`, or
+ /// if its total value is 0.
+ ///
+ /// [`Distribution`]: trait.Distribution.html
+ /// [`Uniform<X>`]: struct.Uniform.html
+ pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
+ where I: IntoIterator,
+ I::Item: SampleBorrow<X>,
+ X: for<'a> ::core::ops::AddAssign<&'a X> +
+ Clone +
+ Default {
+ let mut iter = weights.into_iter();
+ let mut total_weight: X = iter.next()
+ .ok_or(WeightedError::NoItem)?
+ .borrow()
+ .clone();
+
+ let zero = <X as Default>::default();
+ if total_weight < zero {
+ return Err(WeightedError::NegativeWeight);
+ }
+
+ let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
+ for w in iter {
+ if *w.borrow() < zero {
+ return Err(WeightedError::NegativeWeight);
+ }
+ weights.push(total_weight.clone());
+ total_weight += w.borrow();
+ }
+
+ if total_weight == zero {
+ return Err(WeightedError::AllWeightsZero);
+ }
+ let distr = X::Sampler::new(zero, total_weight);
+
+ Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
+ }
+}
+
+impl<X> Distribution<usize> for WeightedIndex<X> where
+ X: SampleUniform + PartialOrd {
+ fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
+ use ::core::cmp::Ordering;
+ let chosen_weight = self.weight_distribution.sample(rng);
+ // Find the first item which has a weight *higher* than the chosen weight.
+ self.cumulative_weights.binary_search_by(
+ |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_weightedindex() {
+ let mut r = ::test::rng(700);
+ const N_REPS: u32 = 5000;
+ let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
+ let total_weight = weights.iter().sum::<u32>() as f32;
+
+ let verify = |result: [i32; 14]| {
+ for (i, count) in result.iter().enumerate() {
+ let exp = (weights[i] * N_REPS) as f32 / total_weight;
+ let mut err = (*count as f32 - exp).abs();
+ if err != 0.0 {
+ err /= exp;
+ }
+ assert!(err <= 0.25);
+ }
+ };
+
+ // WeightedIndex from vec
+ let mut chosen = [0i32; 14];
+ let distr = WeightedIndex::new(weights.to_vec()).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ // WeightedIndex from slice
+ chosen = [0i32; 14];
+ let distr = WeightedIndex::new(&weights[..]).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ // WeightedIndex from iterator
+ chosen = [0i32; 14];
+ let distr = WeightedIndex::new(weights.iter()).unwrap();
+ for _ in 0..N_REPS {
+ chosen[distr.sample(&mut r)] += 1;
+ }
+ verify(chosen);
+
+ for _ in 0..5 {
+ assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
+ assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
+ assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4);
+ }
+
+ assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem);
+ assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero);
+ assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight);
+ assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight);
+ assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight);
+ }
+}
+
+/// Error type returned from `WeightedIndex::new`.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum WeightedError {
+ /// The provided iterator contained no items.
+ NoItem,
+
+ /// A weight lower than zero was used.
+ NegativeWeight,
+
+ /// All items in the provided iterator had a weight of zero.
+ AllWeightsZero,
+}
+
+impl WeightedError {
+ fn msg(&self) -> &str {
+ match *self {
+ WeightedError::NoItem => "No items found",
+ WeightedError::NegativeWeight => "Item has negative weight",
+ WeightedError::AllWeightsZero => "All items had weight zero",
+ }
+ }
+}
+
+#[cfg(feature="std")]
+impl ::std::error::Error for WeightedError {
+ fn description(&self) -> &str {
+ self.msg()
+ }
+ fn cause(&self) -> Option<&::std::error::Error> {
+ None
+ }
+}
+
+impl fmt::Display for WeightedError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "{}", self.msg())
+ }
+}