/rust/registry/src/index.crates.io-1949cf8c6b5b557f/rand-0.8.5/src/seq/mod.rs
Line | Count | Source |
1 | | // Copyright 2018 Developers of the Rand project. |
2 | | // |
3 | | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or |
4 | | // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license |
5 | | // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your |
6 | | // option. This file may not be copied, modified, or distributed |
7 | | // except according to those terms. |
8 | | |
9 | | //! Sequence-related functionality |
10 | | //! |
11 | | //! This module provides: |
12 | | //! |
13 | | //! * [`SliceRandom`] slice sampling and mutation |
14 | | //! * [`IteratorRandom`] iterator sampling |
15 | | //! * [`index::sample`] low-level API to choose multiple indices from |
16 | | //! `0..length` |
17 | | //! |
18 | | //! Also see: |
19 | | //! |
20 | | //! * [`crate::distributions::WeightedIndex`] distribution which provides |
21 | | //! weighted index sampling. |
22 | | //! |
23 | | //! In order to make results reproducible across 32-64 bit architectures, all |
24 | | //! `usize` indices are sampled as a `u32` where possible (also providing a |
25 | | //! small performance boost in some cases). |
26 | | |
27 | | |
28 | | #[cfg(feature = "alloc")] |
29 | | #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] |
30 | | pub mod index; |
31 | | |
32 | | #[cfg(feature = "alloc")] use core::ops::Index; |
33 | | |
34 | | #[cfg(feature = "alloc")] use alloc::vec::Vec; |
35 | | |
36 | | #[cfg(feature = "alloc")] |
37 | | use crate::distributions::uniform::{SampleBorrow, SampleUniform}; |
38 | | #[cfg(feature = "alloc")] use crate::distributions::WeightedError; |
39 | | use crate::Rng; |
40 | | |
41 | | /// Extension trait on slices, providing random mutation and sampling methods. |
42 | | /// |
43 | | /// This trait is implemented on all `[T]` slice types, providing several |
44 | | /// methods for choosing and shuffling elements. You must `use` this trait: |
45 | | /// |
46 | | /// ``` |
47 | | /// use rand::seq::SliceRandom; |
48 | | /// |
49 | | /// let mut rng = rand::thread_rng(); |
50 | | /// let mut bytes = "Hello, random!".to_string().into_bytes(); |
51 | | /// bytes.shuffle(&mut rng); |
52 | | /// let str = String::from_utf8(bytes).unwrap(); |
53 | | /// println!("{}", str); |
54 | | /// ``` |
55 | | /// Example output (non-deterministic): |
56 | | /// ```none |
57 | | /// l,nmroHado !le |
58 | | /// ``` |
59 | | pub trait SliceRandom { |
60 | | /// The element type. |
61 | | type Item; |
62 | | |
63 | | /// Returns a reference to one random element of the slice, or `None` if the |
64 | | /// slice is empty. |
65 | | /// |
66 | | /// For slices, complexity is `O(1)`. |
67 | | /// |
68 | | /// # Example |
69 | | /// |
70 | | /// ``` |
71 | | /// use rand::thread_rng; |
72 | | /// use rand::seq::SliceRandom; |
73 | | /// |
74 | | /// let choices = [1, 2, 4, 8, 16, 32]; |
75 | | /// let mut rng = thread_rng(); |
76 | | /// println!("{:?}", choices.choose(&mut rng)); |
77 | | /// assert_eq!(choices[..0].choose(&mut rng), None); |
78 | | /// ``` |
79 | | fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item> |
80 | | where R: Rng + ?Sized; |
81 | | |
82 | | /// Returns a mutable reference to one random element of the slice, or |
83 | | /// `None` if the slice is empty. |
84 | | /// |
85 | | /// For slices, complexity is `O(1)`. |
86 | | fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> |
87 | | where R: Rng + ?Sized; |
88 | | |
89 | | /// Chooses `amount` elements from the slice at random, without repetition, |
90 | | /// and in random order. The returned iterator is appropriate both for |
91 | | /// collection into a `Vec` and filling an existing buffer (see example). |
92 | | /// |
93 | | /// In case this API is not sufficiently flexible, use [`index::sample`]. |
94 | | /// |
95 | | /// For slices, complexity is the same as [`index::sample`]. |
96 | | /// |
97 | | /// # Example |
98 | | /// ``` |
99 | | /// use rand::seq::SliceRandom; |
100 | | /// |
101 | | /// let mut rng = &mut rand::thread_rng(); |
102 | | /// let sample = "Hello, audience!".as_bytes(); |
103 | | /// |
104 | | /// // collect the results into a vector: |
105 | | /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect(); |
106 | | /// |
107 | | /// // store in a buffer: |
108 | | /// let mut buf = [0u8; 5]; |
109 | | /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { |
110 | | /// *slot = *b; |
111 | | /// } |
112 | | /// ``` |
113 | | #[cfg(feature = "alloc")] |
114 | | #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] |
115 | | fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> |
116 | | where R: Rng + ?Sized; |
117 | | |
118 | | /// Similar to [`choose`], but where the likelihood of each outcome may be |
119 | | /// specified. |
120 | | /// |
121 | | /// The specified function `weight` maps each item `x` to a relative |
122 | | /// likelihood `weight(x)`. The probability of each item being selected is |
123 | | /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. |
124 | | /// |
125 | | /// For slices of length `n`, complexity is `O(n)`. |
126 | | /// See also [`choose_weighted_mut`], [`distributions::weighted`]. |
127 | | /// |
128 | | /// # Example |
129 | | /// |
130 | | /// ``` |
131 | | /// use rand::prelude::*; |
132 | | /// |
133 | | /// let choices = [('a', 2), ('b', 1), ('c', 1)]; |
134 | | /// let mut rng = thread_rng(); |
135 | | /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' |
136 | | /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); |
137 | | /// ``` |
138 | | /// [`choose`]: SliceRandom::choose |
139 | | /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut |
140 | | /// [`distributions::weighted`]: crate::distributions::weighted |
141 | | #[cfg(feature = "alloc")] |
142 | | #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] |
143 | | fn choose_weighted<R, F, B, X>( |
144 | | &self, rng: &mut R, weight: F, |
145 | | ) -> Result<&Self::Item, WeightedError> |
146 | | where |
147 | | R: Rng + ?Sized, |
148 | | F: Fn(&Self::Item) -> B, |
149 | | B: SampleBorrow<X>, |
150 | | X: SampleUniform |
151 | | + for<'a> ::core::ops::AddAssign<&'a X> |
152 | | + ::core::cmp::PartialOrd<X> |
153 | | + Clone |
154 | | + Default; |
155 | | |
156 | | /// Similar to [`choose_mut`], but where the likelihood of each outcome may |
157 | | /// be specified. |
158 | | /// |
159 | | /// The specified function `weight` maps each item `x` to a relative |
160 | | /// likelihood `weight(x)`. The probability of each item being selected is |
161 | | /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. |
162 | | /// |
163 | | /// For slices of length `n`, complexity is `O(n)`. |
164 | | /// See also [`choose_weighted`], [`distributions::weighted`]. |
165 | | /// |
166 | | /// [`choose_mut`]: SliceRandom::choose_mut |
167 | | /// [`choose_weighted`]: SliceRandom::choose_weighted |
168 | | /// [`distributions::weighted`]: crate::distributions::weighted |
169 | | #[cfg(feature = "alloc")] |
170 | | #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] |
171 | | fn choose_weighted_mut<R, F, B, X>( |
172 | | &mut self, rng: &mut R, weight: F, |
173 | | ) -> Result<&mut Self::Item, WeightedError> |
174 | | where |
175 | | R: Rng + ?Sized, |
176 | | F: Fn(&Self::Item) -> B, |
177 | | B: SampleBorrow<X>, |
178 | | X: SampleUniform |
179 | | + for<'a> ::core::ops::AddAssign<&'a X> |
180 | | + ::core::cmp::PartialOrd<X> |
181 | | + Clone |
182 | | + Default; |
183 | | |
184 | | /// Similar to [`choose_multiple`], but where the likelihood of each element's |
185 | | /// inclusion in the output may be specified. The elements are returned in an |
186 | | /// arbitrary, unspecified order. |
187 | | /// |
188 | | /// The specified function `weight` maps each item `x` to a relative |
189 | | /// likelihood `weight(x)`. The probability of each item being selected is |
190 | | /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. |
191 | | /// |
192 | | /// If all of the weights are equal, even if they are all zero, each element has |
193 | | /// an equal likelihood of being selected. |
194 | | /// |
195 | | /// The complexity of this method depends on the feature `partition_at_index`. |
196 | | /// If the feature is enabled, then for slices of length `n`, the complexity |
197 | | /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and |
198 | | /// `O(n * log amount)` time. |
199 | | /// |
200 | | /// # Example |
201 | | /// |
202 | | /// ``` |
203 | | /// use rand::prelude::*; |
204 | | /// |
205 | | /// let choices = [('a', 2), ('b', 1), ('c', 1)]; |
206 | | /// let mut rng = thread_rng(); |
207 | | /// // First Draw * Second Draw = total odds |
208 | | /// // ----------------------- |
209 | | /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. |
210 | | /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. |
211 | | /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. |
212 | | /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>()); |
213 | | /// ``` |
214 | | /// [`choose_multiple`]: SliceRandom::choose_multiple |
215 | | // |
216 | | // Note: this is feature-gated on std due to usage of f64::powf. |
217 | | // If necessary, we may use alloc+libm as an alternative (see PR #1089). |
218 | | #[cfg(feature = "std")] |
219 | | #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] |
220 | | fn choose_multiple_weighted<R, F, X>( |
221 | | &self, rng: &mut R, amount: usize, weight: F, |
222 | | ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> |
223 | | where |
224 | | R: Rng + ?Sized, |
225 | | F: Fn(&Self::Item) -> X, |
226 | | X: Into<f64>; |
227 | | |
228 | | /// Shuffle a mutable slice in place. |
229 | | /// |
230 | | /// For slices of length `n`, complexity is `O(n)`. |
231 | | /// |
232 | | /// # Example |
233 | | /// |
234 | | /// ``` |
235 | | /// use rand::seq::SliceRandom; |
236 | | /// use rand::thread_rng; |
237 | | /// |
238 | | /// let mut rng = thread_rng(); |
239 | | /// let mut y = [1, 2, 3, 4, 5]; |
240 | | /// println!("Unshuffled: {:?}", y); |
241 | | /// y.shuffle(&mut rng); |
242 | | /// println!("Shuffled: {:?}", y); |
243 | | /// ``` |
244 | | fn shuffle<R>(&mut self, rng: &mut R) |
245 | | where R: Rng + ?Sized; |
246 | | |
247 | | /// Shuffle a slice in place, but exit early. |
248 | | /// |
249 | | /// Returns two mutable slices from the source slice. The first contains |
250 | | /// `amount` elements randomly permuted. The second has the remaining |
251 | | /// elements that are not fully shuffled. |
252 | | /// |
253 | | /// This is an efficient method to select `amount` elements at random from |
254 | | /// the slice, provided the slice may be mutated. |
255 | | /// |
256 | | /// If you only need to choose elements randomly and `amount > self.len()/2` |
257 | | /// then you may improve performance by taking |
258 | | /// `amount = values.len() - amount` and using only the second slice. |
259 | | /// |
260 | | /// If `amount` is greater than the number of elements in the slice, this |
261 | | /// will perform a full shuffle. |
262 | | /// |
263 | | /// For slices, complexity is `O(m)` where `m = amount`. |
264 | | fn partial_shuffle<R>( |
265 | | &mut self, rng: &mut R, amount: usize, |
266 | | ) -> (&mut [Self::Item], &mut [Self::Item]) |
267 | | where R: Rng + ?Sized; |
268 | | } |
269 | | |
270 | | /// Extension trait on iterators, providing random sampling methods. |
271 | | /// |
272 | | /// This trait is implemented on all iterators `I` where `I: Iterator + Sized` |
273 | | /// and provides methods for |
274 | | /// choosing one or more elements. You must `use` this trait: |
275 | | /// |
276 | | /// ``` |
277 | | /// use rand::seq::IteratorRandom; |
278 | | /// |
279 | | /// let mut rng = rand::thread_rng(); |
280 | | /// |
281 | | /// let faces = "😀😎😐😕😠😢"; |
282 | | /// println!("I am {}!", faces.chars().choose(&mut rng).unwrap()); |
283 | | /// ``` |
284 | | /// Example output (non-deterministic): |
285 | | /// ```none |
286 | | /// I am 😀! |
287 | | /// ``` |
288 | | pub trait IteratorRandom: Iterator + Sized { |
289 | | /// Choose one element at random from the iterator. |
290 | | /// |
291 | | /// Returns `None` if and only if the iterator is empty. |
292 | | /// |
293 | | /// This method uses [`Iterator::size_hint`] for optimisation. With an |
294 | | /// accurate hint and where [`Iterator::nth`] is a constant-time operation |
295 | | /// this method can offer `O(1)` performance. Where no size hint is |
296 | | /// available, complexity is `O(n)` where `n` is the iterator length. |
297 | | /// Partial hints (where `lower > 0`) also improve performance. |
298 | | /// |
299 | | /// Note that the output values and the number of RNG samples used |
300 | | /// depends on size hints. In particular, `Iterator` combinators that don't |
301 | | /// change the values yielded but change the size hints may result in |
302 | | /// `choose` returning different elements. If you want consistent results |
303 | | /// and RNG usage consider using [`IteratorRandom::choose_stable`]. |
304 | 0 | fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item> |
305 | 0 | where R: Rng + ?Sized { |
306 | 0 | let (mut lower, mut upper) = self.size_hint(); |
307 | 0 | let mut consumed = 0; |
308 | 0 | let mut result = None; |
309 | | |
310 | | // Handling for this condition outside the loop allows the optimizer to eliminate the loop |
311 | | // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. |
312 | | // seq_iter_choose_from_1000. |
313 | 0 | if upper == Some(lower) { |
314 | 0 | return if lower == 0 { |
315 | 0 | None |
316 | | } else { |
317 | 0 | self.nth(gen_index(rng, lower)) |
318 | | }; |
319 | 0 | } |
320 | | |
321 | | // Continue until the iterator is exhausted |
322 | | loop { |
323 | 0 | if lower > 1 { |
324 | 0 | let ix = gen_index(rng, lower + consumed); |
325 | 0 | let skip = if ix < lower { |
326 | 0 | result = self.nth(ix); |
327 | 0 | lower - (ix + 1) |
328 | | } else { |
329 | 0 | lower |
330 | | }; |
331 | 0 | if upper == Some(lower) { |
332 | 0 | return result; |
333 | 0 | } |
334 | 0 | consumed += lower; |
335 | 0 | if skip > 0 { |
336 | 0 | self.nth(skip - 1); |
337 | 0 | } |
338 | | } else { |
339 | 0 | let elem = self.next(); |
340 | 0 | if elem.is_none() { |
341 | 0 | return result; |
342 | 0 | } |
343 | 0 | consumed += 1; |
344 | 0 | if gen_index(rng, consumed) == 0 { |
345 | 0 | result = elem; |
346 | 0 | } |
347 | | } |
348 | | |
349 | 0 | let hint = self.size_hint(); |
350 | 0 | lower = hint.0; |
351 | 0 | upper = hint.1; |
352 | | } |
353 | 0 | } |
354 | | |
355 | | /// Choose one element at random from the iterator. |
356 | | /// |
357 | | /// Returns `None` if and only if the iterator is empty. |
358 | | /// |
359 | | /// This method is very similar to [`choose`] except that the result |
360 | | /// only depends on the length of the iterator and the values produced by |
361 | | /// `rng`. Notably for any iterator of a given length this will make the |
362 | | /// same requests to `rng` and if the same sequence of values are produced |
363 | | /// the same index will be selected from `self`. This may be useful if you |
364 | | /// need consistent results no matter what type of iterator you are working |
365 | | /// with. If you do not need this stability prefer [`choose`]. |
366 | | /// |
367 | | /// Note that this method still uses [`Iterator::size_hint`] to skip |
368 | | /// constructing elements where possible, however the selection and `rng` |
369 | | /// calls are the same in the face of this optimization. If you want to |
370 | | /// force every element to be created regardless call `.inspect(|e| ())`. |
371 | | /// |
372 | | /// [`choose`]: IteratorRandom::choose |
373 | 0 | fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item> |
374 | 0 | where R: Rng + ?Sized { |
375 | 0 | let mut consumed = 0; |
376 | 0 | let mut result = None; |
377 | | |
378 | | loop { |
379 | | // Currently the only way to skip elements is `nth()`. So we need to |
380 | | // store what index to access next here. |
381 | | // This should be replaced by `advance_by()` once it is stable: |
382 | | // https://github.com/rust-lang/rust/issues/77404 |
383 | 0 | let mut next = 0; |
384 | | |
385 | 0 | let (lower, _) = self.size_hint(); |
386 | 0 | if lower >= 2 { |
387 | 0 | let highest_selected = (0..lower) |
388 | 0 | .filter(|ix| gen_index(rng, consumed+ix+1) == 0) |
389 | 0 | .last(); |
390 | | |
391 | 0 | consumed += lower; |
392 | 0 | next = lower; |
393 | | |
394 | 0 | if let Some(ix) = highest_selected { |
395 | 0 | result = self.nth(ix); |
396 | 0 | next -= ix + 1; |
397 | 0 | debug_assert!(result.is_some(), "iterator shorter than size_hint().0"); |
398 | 0 | } |
399 | 0 | } |
400 | | |
401 | 0 | let elem = self.nth(next); |
402 | 0 | if elem.is_none() { |
403 | 0 | return result |
404 | 0 | } |
405 | | |
406 | 0 | if gen_index(rng, consumed+1) == 0 { |
407 | 0 | result = elem; |
408 | 0 | } |
409 | 0 | consumed += 1; |
410 | | } |
411 | 0 | } |
412 | | |
413 | | /// Collects values at random from the iterator into a supplied buffer |
414 | | /// until that buffer is filled. |
415 | | /// |
416 | | /// Although the elements are selected randomly, the order of elements in |
417 | | /// the buffer is neither stable nor fully random. If random ordering is |
418 | | /// desired, shuffle the result. |
419 | | /// |
420 | | /// Returns the number of elements added to the buffer. This equals the length |
421 | | /// of the buffer unless the iterator contains insufficient elements, in which |
422 | | /// case this equals the number of elements available. |
423 | | /// |
424 | | /// Complexity is `O(n)` where `n` is the length of the iterator. |
425 | | /// For slices, prefer [`SliceRandom::choose_multiple`]. |
426 | 0 | fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize |
427 | 0 | where R: Rng + ?Sized { |
428 | 0 | let amount = buf.len(); |
429 | 0 | let mut len = 0; |
430 | 0 | while len < amount { |
431 | 0 | if let Some(elem) = self.next() { |
432 | 0 | buf[len] = elem; |
433 | 0 | len += 1; |
434 | 0 | } else { |
435 | | // Iterator exhausted; stop early |
436 | 0 | return len; |
437 | | } |
438 | | } |
439 | | |
440 | | // Continue, since the iterator was not exhausted |
441 | 0 | for (i, elem) in self.enumerate() { |
442 | 0 | let k = gen_index(rng, i + 1 + amount); |
443 | 0 | if let Some(slot) = buf.get_mut(k) { |
444 | 0 | *slot = elem; |
445 | 0 | } |
446 | | } |
447 | 0 | len |
448 | 0 | } |
449 | | |
450 | | /// Collects `amount` values at random from the iterator into a vector. |
451 | | /// |
452 | | /// This is equivalent to `choose_multiple_fill` except for the result type. |
453 | | /// |
454 | | /// Although the elements are selected randomly, the order of elements in |
455 | | /// the buffer is neither stable nor fully random. If random ordering is |
456 | | /// desired, shuffle the result. |
457 | | /// |
458 | | /// The length of the returned vector equals `amount` unless the iterator |
459 | | /// contains insufficient elements, in which case it equals the number of |
460 | | /// elements available. |
461 | | /// |
462 | | /// Complexity is `O(n)` where `n` is the length of the iterator. |
463 | | /// For slices, prefer [`SliceRandom::choose_multiple`]. |
464 | | #[cfg(feature = "alloc")] |
465 | | #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] |
466 | 0 | fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item> |
467 | 0 | where R: Rng + ?Sized { |
468 | 0 | let mut reservoir = Vec::with_capacity(amount); |
469 | 0 | reservoir.extend(self.by_ref().take(amount)); |
470 | | |
471 | | // Continue unless the iterator was exhausted |
472 | | // |
473 | | // note: this prevents iterators that "restart" from causing problems. |
474 | | // If the iterator stops once, then so do we. |
475 | 0 | if reservoir.len() == amount { |
476 | 0 | for (i, elem) in self.enumerate() { |
477 | 0 | let k = gen_index(rng, i + 1 + amount); |
478 | 0 | if let Some(slot) = reservoir.get_mut(k) { |
479 | 0 | *slot = elem; |
480 | 0 | } |
481 | | } |
482 | 0 | } else { |
483 | 0 | // Don't hang onto extra memory. There is a corner case where |
484 | 0 | // `amount` was much less than `self.len()`. |
485 | 0 | reservoir.shrink_to_fit(); |
486 | 0 | } |
487 | 0 | reservoir |
488 | 0 | } |
489 | | } |
490 | | |
491 | | |
492 | | impl<T> SliceRandom for [T] { |
493 | | type Item = T; |
494 | | |
495 | 0 | fn choose<R>(&self, rng: &mut R) -> Option<&Self::Item> |
496 | 0 | where R: Rng + ?Sized { |
497 | 0 | if self.is_empty() { |
498 | 0 | None |
499 | | } else { |
500 | 0 | Some(&self[gen_index(rng, self.len())]) |
501 | | } |
502 | 0 | } |
503 | | |
504 | 0 | fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Item> |
505 | 0 | where R: Rng + ?Sized { |
506 | 0 | if self.is_empty() { |
507 | 0 | None |
508 | | } else { |
509 | 0 | let len = self.len(); |
510 | 0 | Some(&mut self[gen_index(rng, len)]) |
511 | | } |
512 | 0 | } |
513 | | |
514 | | #[cfg(feature = "alloc")] |
515 | 0 | fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item> |
516 | 0 | where R: Rng + ?Sized { |
517 | 0 | let amount = ::core::cmp::min(amount, self.len()); |
518 | 0 | SliceChooseIter { |
519 | 0 | slice: self, |
520 | 0 | _phantom: Default::default(), |
521 | 0 | indices: index::sample(rng, self.len(), amount).into_iter(), |
522 | 0 | } |
523 | 0 | } |
524 | | |
525 | | #[cfg(feature = "alloc")] |
526 | 0 | fn choose_weighted<R, F, B, X>( |
527 | 0 | &self, rng: &mut R, weight: F, |
528 | 0 | ) -> Result<&Self::Item, WeightedError> |
529 | 0 | where |
530 | 0 | R: Rng + ?Sized, |
531 | 0 | F: Fn(&Self::Item) -> B, |
532 | 0 | B: SampleBorrow<X>, |
533 | 0 | X: SampleUniform |
534 | 0 | + for<'a> ::core::ops::AddAssign<&'a X> |
535 | 0 | + ::core::cmp::PartialOrd<X> |
536 | 0 | + Clone |
537 | 0 | + Default, |
538 | | { |
539 | | use crate::distributions::{Distribution, WeightedIndex}; |
540 | 0 | let distr = WeightedIndex::new(self.iter().map(weight))?; |
541 | 0 | Ok(&self[distr.sample(rng)]) |
542 | 0 | } |
543 | | |
544 | | #[cfg(feature = "alloc")] |
545 | 0 | fn choose_weighted_mut<R, F, B, X>( |
546 | 0 | &mut self, rng: &mut R, weight: F, |
547 | 0 | ) -> Result<&mut Self::Item, WeightedError> |
548 | 0 | where |
549 | 0 | R: Rng + ?Sized, |
550 | 0 | F: Fn(&Self::Item) -> B, |
551 | 0 | B: SampleBorrow<X>, |
552 | 0 | X: SampleUniform |
553 | 0 | + for<'a> ::core::ops::AddAssign<&'a X> |
554 | 0 | + ::core::cmp::PartialOrd<X> |
555 | 0 | + Clone |
556 | 0 | + Default, |
557 | | { |
558 | | use crate::distributions::{Distribution, WeightedIndex}; |
559 | 0 | let distr = WeightedIndex::new(self.iter().map(weight))?; |
560 | 0 | Ok(&mut self[distr.sample(rng)]) |
561 | 0 | } |
562 | | |
563 | | #[cfg(feature = "std")] |
564 | 0 | fn choose_multiple_weighted<R, F, X>( |
565 | 0 | &self, rng: &mut R, amount: usize, weight: F, |
566 | 0 | ) -> Result<SliceChooseIter<Self, Self::Item>, WeightedError> |
567 | 0 | where |
568 | 0 | R: Rng + ?Sized, |
569 | 0 | F: Fn(&Self::Item) -> X, |
570 | 0 | X: Into<f64>, |
571 | | { |
572 | 0 | let amount = ::core::cmp::min(amount, self.len()); |
573 | | Ok(SliceChooseIter { |
574 | 0 | slice: self, |
575 | 0 | _phantom: Default::default(), |
576 | 0 | indices: index::sample_weighted( |
577 | 0 | rng, |
578 | 0 | self.len(), |
579 | 0 | |idx| weight(&self[idx]).into(), |
580 | 0 | amount, |
581 | 0 | )? |
582 | 0 | .into_iter(), |
583 | | }) |
584 | 0 | } |
585 | | |
586 | 0 | fn shuffle<R>(&mut self, rng: &mut R) |
587 | 0 | where R: Rng + ?Sized { |
588 | 0 | for i in (1..self.len()).rev() { |
589 | 0 | // invariant: elements with index > i have been locked in place. |
590 | 0 | self.swap(i, gen_index(rng, i + 1)); |
591 | 0 | } |
592 | 0 | } |
593 | | |
594 | 0 | fn partial_shuffle<R>( |
595 | 0 | &mut self, rng: &mut R, amount: usize, |
596 | 0 | ) -> (&mut [Self::Item], &mut [Self::Item]) |
597 | 0 | where R: Rng + ?Sized { |
598 | | // This applies Durstenfeld's algorithm for the |
599 | | // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) |
600 | | // for an unbiased permutation, but exits early after choosing `amount` |
601 | | // elements. |
602 | | |
603 | 0 | let len = self.len(); |
604 | 0 | let end = if amount >= len { 0 } else { len - amount }; |
605 | | |
606 | 0 | for i in (end..len).rev() { |
607 | 0 | // invariant: elements with index > i have been locked in place. |
608 | 0 | self.swap(i, gen_index(rng, i + 1)); |
609 | 0 | } |
610 | 0 | let r = self.split_at_mut(end); |
611 | 0 | (r.1, r.0) |
612 | 0 | } |
613 | | } |
614 | | |
615 | | impl<I> IteratorRandom for I where I: Iterator + Sized {} |
616 | | |
617 | | |
618 | | /// An iterator over multiple slice elements. |
619 | | /// |
620 | | /// This struct is created by |
621 | | /// [`SliceRandom::choose_multiple`](trait.SliceRandom.html#tymethod.choose_multiple). |
622 | | #[cfg(feature = "alloc")] |
623 | | #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] |
624 | | #[derive(Debug)] |
625 | | pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { |
626 | | slice: &'a S, |
627 | | _phantom: ::core::marker::PhantomData<T>, |
628 | | indices: index::IndexVecIntoIter, |
629 | | } |
630 | | |
631 | | #[cfg(feature = "alloc")] |
632 | | impl<'a, S: Index<usize, Output = T> + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { |
633 | | type Item = &'a T; |
634 | | |
635 | 0 | fn next(&mut self) -> Option<Self::Item> { |
636 | | // TODO: investigate using SliceIndex::get_unchecked when stable |
637 | 0 | self.indices.next().map(|i| &self.slice[i as usize]) |
638 | 0 | } |
639 | | |
640 | 0 | fn size_hint(&self) -> (usize, Option<usize>) { |
641 | 0 | (self.indices.len(), Some(self.indices.len())) |
642 | 0 | } |
643 | | } |
644 | | |
645 | | #[cfg(feature = "alloc")] |
646 | | impl<'a, S: Index<usize, Output = T> + ?Sized + 'a, T: 'a> ExactSizeIterator |
647 | | for SliceChooseIter<'a, S, T> |
648 | | { |
649 | 0 | fn len(&self) -> usize { |
650 | 0 | self.indices.len() |
651 | 0 | } |
652 | | } |
653 | | |
654 | | |
655 | | // Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where |
656 | | // possible, primarily in order to produce the same output on 32-bit and 64-bit |
657 | | // platforms. |
658 | | #[inline] |
659 | 0 | fn gen_index<R: Rng + ?Sized>(rng: &mut R, ubound: usize) -> usize { |
660 | 0 | if ubound <= (core::u32::MAX as usize) { |
661 | 0 | rng.gen_range(0..ubound as u32) as usize |
662 | | } else { |
663 | 0 | rng.gen_range(0..ubound) |
664 | | } |
665 | 0 | } |
666 | | |
667 | | |
668 | | #[cfg(test)] |
669 | | mod test { |
670 | | use super::*; |
671 | | #[cfg(feature = "alloc")] use crate::Rng; |
672 | | #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec; |
673 | | |
674 | | #[test] |
675 | | fn test_slice_choose() { |
676 | | let mut r = crate::test::rng(107); |
677 | | let chars = [ |
678 | | 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', |
679 | | ]; |
680 | | let mut chosen = [0i32; 14]; |
681 | | // The below all use a binomial distribution with n=1000, p=1/14. |
682 | | // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5 |
683 | | for _ in 0..1000 { |
684 | | let picked = *chars.choose(&mut r).unwrap(); |
685 | | chosen[(picked as usize) - ('a' as usize)] += 1; |
686 | | } |
687 | | for count in chosen.iter() { |
688 | | assert!(40 < *count && *count < 106); |
689 | | } |
690 | | |
691 | | chosen.iter_mut().for_each(|x| *x = 0); |
692 | | for _ in 0..1000 { |
693 | | *chosen.choose_mut(&mut r).unwrap() += 1; |
694 | | } |
695 | | for count in chosen.iter() { |
696 | | assert!(40 < *count && *count < 106); |
697 | | } |
698 | | |
699 | | let mut v: [isize; 0] = []; |
700 | | assert_eq!(v.choose(&mut r), None); |
701 | | assert_eq!(v.choose_mut(&mut r), None); |
702 | | } |
703 | | |
704 | | #[test] |
705 | | fn value_stability_slice() { |
706 | | let mut r = crate::test::rng(413); |
707 | | let chars = [ |
708 | | 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', |
709 | | ]; |
710 | | let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; |
711 | | |
712 | | assert_eq!(chars.choose(&mut r), Some(&'l')); |
713 | | assert_eq!(nums.choose_mut(&mut r), Some(&mut 10)); |
714 | | |
715 | | #[cfg(feature = "alloc")] |
716 | | assert_eq!( |
717 | | &chars |
718 | | .choose_multiple(&mut r, 8) |
719 | | .cloned() |
720 | | .collect::<Vec<char>>(), |
721 | | &['d', 'm', 'b', 'n', 'c', 'k', 'h', 'e'] |
722 | | ); |
723 | | |
724 | | #[cfg(feature = "alloc")] |
725 | | assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'f')); |
726 | | #[cfg(feature = "alloc")] |
727 | | assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 5)); |
728 | | |
729 | | let mut r = crate::test::rng(414); |
730 | | nums.shuffle(&mut r); |
731 | | assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]); |
732 | | nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; |
733 | | let res = nums.partial_shuffle(&mut r, 6); |
734 | | assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]); |
735 | | assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]); |
736 | | } |
737 | | |
738 | | #[derive(Clone)] |
739 | | struct UnhintedIterator<I: Iterator + Clone> { |
740 | | iter: I, |
741 | | } |
742 | | impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> { |
743 | | type Item = I::Item; |
744 | | |
745 | | fn next(&mut self) -> Option<Self::Item> { |
746 | | self.iter.next() |
747 | | } |
748 | | } |
749 | | |
750 | | #[derive(Clone)] |
751 | | struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> { |
752 | | iter: I, |
753 | | chunk_remaining: usize, |
754 | | chunk_size: usize, |
755 | | hint_total_size: bool, |
756 | | } |
757 | | impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> { |
758 | | type Item = I::Item; |
759 | | |
760 | | fn next(&mut self) -> Option<Self::Item> { |
761 | | if self.chunk_remaining == 0 { |
762 | | self.chunk_remaining = ::core::cmp::min(self.chunk_size, self.iter.len()); |
763 | | } |
764 | | self.chunk_remaining = self.chunk_remaining.saturating_sub(1); |
765 | | |
766 | | self.iter.next() |
767 | | } |
768 | | |
769 | | fn size_hint(&self) -> (usize, Option<usize>) { |
770 | | ( |
771 | | self.chunk_remaining, |
772 | | if self.hint_total_size { |
773 | | Some(self.iter.len()) |
774 | | } else { |
775 | | None |
776 | | }, |
777 | | ) |
778 | | } |
779 | | } |
780 | | |
781 | | #[derive(Clone)] |
782 | | struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> { |
783 | | iter: I, |
784 | | window_size: usize, |
785 | | hint_total_size: bool, |
786 | | } |
787 | | impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> { |
788 | | type Item = I::Item; |
789 | | |
790 | | fn next(&mut self) -> Option<Self::Item> { |
791 | | self.iter.next() |
792 | | } |
793 | | |
794 | | fn size_hint(&self) -> (usize, Option<usize>) { |
795 | | ( |
796 | | ::core::cmp::min(self.iter.len(), self.window_size), |
797 | | if self.hint_total_size { |
798 | | Some(self.iter.len()) |
799 | | } else { |
800 | | None |
801 | | }, |
802 | | ) |
803 | | } |
804 | | } |
805 | | |
806 | | #[test] |
807 | | #[cfg_attr(miri, ignore)] // Miri is too slow |
808 | | fn test_iterator_choose() { |
809 | | let r = &mut crate::test::rng(109); |
810 | | fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) { |
811 | | let mut chosen = [0i32; 9]; |
812 | | for _ in 0..1000 { |
813 | | let picked = iter.clone().choose(r).unwrap(); |
814 | | chosen[picked] += 1; |
815 | | } |
816 | | for count in chosen.iter() { |
817 | | // Samples should follow Binomial(1000, 1/9) |
818 | | // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x |
819 | | // Note: have seen 153, which is unlikely but not impossible. |
820 | | assert!( |
821 | | 72 < *count && *count < 154, |
822 | | "count not close to 1000/9: {}", |
823 | | count |
824 | | ); |
825 | | } |
826 | | } |
827 | | |
828 | | test_iter(r, 0..9); |
829 | | test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); |
830 | | #[cfg(feature = "alloc")] |
831 | | test_iter(r, (0..9).collect::<Vec<_>>().into_iter()); |
832 | | test_iter(r, UnhintedIterator { iter: 0..9 }); |
833 | | test_iter(r, ChunkHintedIterator { |
834 | | iter: 0..9, |
835 | | chunk_size: 4, |
836 | | chunk_remaining: 4, |
837 | | hint_total_size: false, |
838 | | }); |
839 | | test_iter(r, ChunkHintedIterator { |
840 | | iter: 0..9, |
841 | | chunk_size: 4, |
842 | | chunk_remaining: 4, |
843 | | hint_total_size: true, |
844 | | }); |
845 | | test_iter(r, WindowHintedIterator { |
846 | | iter: 0..9, |
847 | | window_size: 2, |
848 | | hint_total_size: false, |
849 | | }); |
850 | | test_iter(r, WindowHintedIterator { |
851 | | iter: 0..9, |
852 | | window_size: 2, |
853 | | hint_total_size: true, |
854 | | }); |
855 | | |
856 | | assert_eq!((0..0).choose(r), None); |
857 | | assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); |
858 | | } |
859 | | |
860 | | #[test] |
861 | | #[cfg_attr(miri, ignore)] // Miri is too slow |
862 | | fn test_iterator_choose_stable() { |
863 | | let r = &mut crate::test::rng(109); |
864 | | fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) { |
865 | | let mut chosen = [0i32; 9]; |
866 | | for _ in 0..1000 { |
867 | | let picked = iter.clone().choose_stable(r).unwrap(); |
868 | | chosen[picked] += 1; |
869 | | } |
870 | | for count in chosen.iter() { |
871 | | // Samples should follow Binomial(1000, 1/9) |
872 | | // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x |
873 | | // Note: have seen 153, which is unlikely but not impossible. |
874 | | assert!( |
875 | | 72 < *count && *count < 154, |
876 | | "count not close to 1000/9: {}", |
877 | | count |
878 | | ); |
879 | | } |
880 | | } |
881 | | |
882 | | test_iter(r, 0..9); |
883 | | test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); |
884 | | #[cfg(feature = "alloc")] |
885 | | test_iter(r, (0..9).collect::<Vec<_>>().into_iter()); |
886 | | test_iter(r, UnhintedIterator { iter: 0..9 }); |
887 | | test_iter(r, ChunkHintedIterator { |
888 | | iter: 0..9, |
889 | | chunk_size: 4, |
890 | | chunk_remaining: 4, |
891 | | hint_total_size: false, |
892 | | }); |
893 | | test_iter(r, ChunkHintedIterator { |
894 | | iter: 0..9, |
895 | | chunk_size: 4, |
896 | | chunk_remaining: 4, |
897 | | hint_total_size: true, |
898 | | }); |
899 | | test_iter(r, WindowHintedIterator { |
900 | | iter: 0..9, |
901 | | window_size: 2, |
902 | | hint_total_size: false, |
903 | | }); |
904 | | test_iter(r, WindowHintedIterator { |
905 | | iter: 0..9, |
906 | | window_size: 2, |
907 | | hint_total_size: true, |
908 | | }); |
909 | | |
910 | | assert_eq!((0..0).choose(r), None); |
911 | | assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); |
912 | | } |
913 | | |
914 | | #[test] |
915 | | #[cfg_attr(miri, ignore)] // Miri is too slow |
916 | | fn test_iterator_choose_stable_stability() { |
917 | | fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] { |
918 | | let r = &mut crate::test::rng(109); |
919 | | let mut chosen = [0i32; 9]; |
920 | | for _ in 0..1000 { |
921 | | let picked = iter.clone().choose_stable(r).unwrap(); |
922 | | chosen[picked] += 1; |
923 | | } |
924 | | chosen |
925 | | } |
926 | | |
927 | | let reference = test_iter(0..9); |
928 | | assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference); |
929 | | |
930 | | #[cfg(feature = "alloc")] |
931 | | assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference); |
932 | | assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); |
933 | | assert_eq!(test_iter(ChunkHintedIterator { |
934 | | iter: 0..9, |
935 | | chunk_size: 4, |
936 | | chunk_remaining: 4, |
937 | | hint_total_size: false, |
938 | | }), reference); |
939 | | assert_eq!(test_iter(ChunkHintedIterator { |
940 | | iter: 0..9, |
941 | | chunk_size: 4, |
942 | | chunk_remaining: 4, |
943 | | hint_total_size: true, |
944 | | }), reference); |
945 | | assert_eq!(test_iter(WindowHintedIterator { |
946 | | iter: 0..9, |
947 | | window_size: 2, |
948 | | hint_total_size: false, |
949 | | }), reference); |
950 | | assert_eq!(test_iter(WindowHintedIterator { |
951 | | iter: 0..9, |
952 | | window_size: 2, |
953 | | hint_total_size: true, |
954 | | }), reference); |
955 | | } |
956 | | |
957 | | #[test] |
958 | | #[cfg_attr(miri, ignore)] // Miri is too slow |
959 | | fn test_shuffle() { |
960 | | let mut r = crate::test::rng(108); |
961 | | let empty: &mut [isize] = &mut []; |
962 | | empty.shuffle(&mut r); |
963 | | let mut one = [1]; |
964 | | one.shuffle(&mut r); |
965 | | let b: &[_] = &[1]; |
966 | | assert_eq!(one, b); |
967 | | |
968 | | let mut two = [1, 2]; |
969 | | two.shuffle(&mut r); |
970 | | assert!(two == [1, 2] || two == [2, 1]); |
971 | | |
972 | | fn move_last(slice: &mut [usize], pos: usize) { |
973 | | // use slice[pos..].rotate_left(1); once we can use that |
974 | | let last_val = slice[pos]; |
975 | | for i in pos..slice.len() - 1 { |
976 | | slice[i] = slice[i + 1]; |
977 | | } |
978 | | *slice.last_mut().unwrap() = last_val; |
979 | | } |
980 | | let mut counts = [0i32; 24]; |
981 | | for _ in 0..10000 { |
982 | | let mut arr: [usize; 4] = [0, 1, 2, 3]; |
983 | | arr.shuffle(&mut r); |
984 | | let mut permutation = 0usize; |
985 | | let mut pos_value = counts.len(); |
986 | | for i in 0..4 { |
987 | | pos_value /= 4 - i; |
988 | | let pos = arr.iter().position(|&x| x == i).unwrap(); |
989 | | assert!(pos < (4 - i)); |
990 | | permutation += pos * pos_value; |
991 | | move_last(&mut arr, pos); |
992 | | assert_eq!(arr[3], i); |
993 | | } |
994 | | for (i, &a) in arr.iter().enumerate() { |
995 | | assert_eq!(a, i); |
996 | | } |
997 | | counts[permutation] += 1; |
998 | | } |
999 | | for count in counts.iter() { |
1000 | | // Binomial(10000, 1/24) with average 416.667 |
1001 | | // Octave: binocdf(n, 10000, 1/24) |
1002 | | // 99.9% chance samples lie within this range: |
1003 | | assert!(352 <= *count && *count <= 483, "count: {}", count); |
1004 | | } |
1005 | | } |
1006 | | |
1007 | | #[test] |
1008 | | fn test_partial_shuffle() { |
1009 | | let mut r = crate::test::rng(118); |
1010 | | |
1011 | | let mut empty: [u32; 0] = []; |
1012 | | let res = empty.partial_shuffle(&mut r, 10); |
1013 | | assert_eq!((res.0.len(), res.1.len()), (0, 0)); |
1014 | | |
1015 | | let mut v = [1, 2, 3, 4, 5]; |
1016 | | let res = v.partial_shuffle(&mut r, 2); |
1017 | | assert_eq!((res.0.len(), res.1.len()), (2, 3)); |
1018 | | assert!(res.0[0] != res.0[1]); |
1019 | | // First elements are only modified if selected, so at least one isn't modified: |
1020 | | assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); |
1021 | | } |
1022 | | |
1023 | | #[test] |
1024 | | #[cfg(feature = "alloc")] |
1025 | | fn test_sample_iter() { |
1026 | | let min_val = 1; |
1027 | | let max_val = 100; |
1028 | | |
1029 | | let mut r = crate::test::rng(401); |
1030 | | let vals = (min_val..max_val).collect::<Vec<i32>>(); |
1031 | | let small_sample = vals.iter().choose_multiple(&mut r, 5); |
1032 | | let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); |
1033 | | |
1034 | | assert_eq!(small_sample.len(), 5); |
1035 | | assert_eq!(large_sample.len(), vals.len()); |
1036 | | // no randomization happens when amount >= len |
1037 | | assert_eq!(large_sample, vals.iter().collect::<Vec<_>>()); |
1038 | | |
1039 | | assert!(small_sample |
1040 | | .iter() |
1041 | | .all(|e| { **e >= min_val && **e <= max_val })); |
1042 | | } |
1043 | | |
1044 | | #[test] |
1045 | | #[cfg(feature = "alloc")] |
1046 | | #[cfg_attr(miri, ignore)] // Miri is too slow |
1047 | | fn test_weighted() { |
1048 | | let mut r = crate::test::rng(406); |
1049 | | const N_REPS: u32 = 3000; |
1050 | | let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; |
1051 | | let total_weight = weights.iter().sum::<u32>() as f32; |
1052 | | |
1053 | | let verify = |result: [i32; 14]| { |
1054 | | for (i, count) in result.iter().enumerate() { |
1055 | | let exp = (weights[i] * N_REPS) as f32 / total_weight; |
1056 | | let mut err = (*count as f32 - exp).abs(); |
1057 | | if err != 0.0 { |
1058 | | err /= exp; |
1059 | | } |
1060 | | assert!(err <= 0.25); |
1061 | | } |
1062 | | }; |
1063 | | |
1064 | | // choose_weighted |
1065 | | fn get_weight<T>(item: &(u32, T)) -> u32 { |
1066 | | item.0 |
1067 | | } |
1068 | | let mut chosen = [0i32; 14]; |
1069 | | let mut items = [(0u32, 0usize); 14]; // (weight, index) |
1070 | | for (i, item) in items.iter_mut().enumerate() { |
1071 | | *item = (weights[i], i); |
1072 | | } |
1073 | | for _ in 0..N_REPS { |
1074 | | let item = items.choose_weighted(&mut r, get_weight).unwrap(); |
1075 | | chosen[item.1] += 1; |
1076 | | } |
1077 | | verify(chosen); |
1078 | | |
1079 | | // choose_weighted_mut |
1080 | | let mut items = [(0u32, 0i32); 14]; // (weight, count) |
1081 | | for (i, item) in items.iter_mut().enumerate() { |
1082 | | *item = (weights[i], 0); |
1083 | | } |
1084 | | for _ in 0..N_REPS { |
1085 | | items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; |
1086 | | } |
1087 | | for (ch, item) in chosen.iter_mut().zip(items.iter()) { |
1088 | | *ch = item.1; |
1089 | | } |
1090 | | verify(chosen); |
1091 | | |
1092 | | // Check error cases |
1093 | | let empty_slice = &mut [10][0..0]; |
1094 | | assert_eq!( |
1095 | | empty_slice.choose_weighted(&mut r, |_| 1), |
1096 | | Err(WeightedError::NoItem) |
1097 | | ); |
1098 | | assert_eq!( |
1099 | | empty_slice.choose_weighted_mut(&mut r, |_| 1), |
1100 | | Err(WeightedError::NoItem) |
1101 | | ); |
1102 | | assert_eq!( |
1103 | | ['x'].choose_weighted_mut(&mut r, |_| 0), |
1104 | | Err(WeightedError::AllWeightsZero) |
1105 | | ); |
1106 | | assert_eq!( |
1107 | | [0, -1].choose_weighted_mut(&mut r, |x| *x), |
1108 | | Err(WeightedError::InvalidWeight) |
1109 | | ); |
1110 | | assert_eq!( |
1111 | | [-1, 0].choose_weighted_mut(&mut r, |x| *x), |
1112 | | Err(WeightedError::InvalidWeight) |
1113 | | ); |
1114 | | } |
1115 | | |
1116 | | #[test] |
1117 | | fn value_stability_choose() { |
1118 | | fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> { |
1119 | | let mut rng = crate::test::rng(411); |
1120 | | iter.choose(&mut rng) |
1121 | | } |
1122 | | |
1123 | | assert_eq!(choose([].iter().cloned()), None); |
1124 | | assert_eq!(choose(0..100), Some(33)); |
1125 | | assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); |
1126 | | assert_eq!( |
1127 | | choose(ChunkHintedIterator { |
1128 | | iter: 0..100, |
1129 | | chunk_size: 32, |
1130 | | chunk_remaining: 32, |
1131 | | hint_total_size: false, |
1132 | | }), |
1133 | | Some(39) |
1134 | | ); |
1135 | | assert_eq!( |
1136 | | choose(ChunkHintedIterator { |
1137 | | iter: 0..100, |
1138 | | chunk_size: 32, |
1139 | | chunk_remaining: 32, |
1140 | | hint_total_size: true, |
1141 | | }), |
1142 | | Some(39) |
1143 | | ); |
1144 | | assert_eq!( |
1145 | | choose(WindowHintedIterator { |
1146 | | iter: 0..100, |
1147 | | window_size: 32, |
1148 | | hint_total_size: false, |
1149 | | }), |
1150 | | Some(90) |
1151 | | ); |
1152 | | assert_eq!( |
1153 | | choose(WindowHintedIterator { |
1154 | | iter: 0..100, |
1155 | | window_size: 32, |
1156 | | hint_total_size: true, |
1157 | | }), |
1158 | | Some(90) |
1159 | | ); |
1160 | | } |
1161 | | |
1162 | | #[test] |
1163 | | fn value_stability_choose_stable() { |
1164 | | fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> { |
1165 | | let mut rng = crate::test::rng(411); |
1166 | | iter.choose_stable(&mut rng) |
1167 | | } |
1168 | | |
1169 | | assert_eq!(choose([].iter().cloned()), None); |
1170 | | assert_eq!(choose(0..100), Some(40)); |
1171 | | assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); |
1172 | | assert_eq!( |
1173 | | choose(ChunkHintedIterator { |
1174 | | iter: 0..100, |
1175 | | chunk_size: 32, |
1176 | | chunk_remaining: 32, |
1177 | | hint_total_size: false, |
1178 | | }), |
1179 | | Some(40) |
1180 | | ); |
1181 | | assert_eq!( |
1182 | | choose(ChunkHintedIterator { |
1183 | | iter: 0..100, |
1184 | | chunk_size: 32, |
1185 | | chunk_remaining: 32, |
1186 | | hint_total_size: true, |
1187 | | }), |
1188 | | Some(40) |
1189 | | ); |
1190 | | assert_eq!( |
1191 | | choose(WindowHintedIterator { |
1192 | | iter: 0..100, |
1193 | | window_size: 32, |
1194 | | hint_total_size: false, |
1195 | | }), |
1196 | | Some(40) |
1197 | | ); |
1198 | | assert_eq!( |
1199 | | choose(WindowHintedIterator { |
1200 | | iter: 0..100, |
1201 | | window_size: 32, |
1202 | | hint_total_size: true, |
1203 | | }), |
1204 | | Some(40) |
1205 | | ); |
1206 | | } |
1207 | | |
1208 | | #[test] |
1209 | | fn value_stability_choose_multiple() { |
1210 | | fn do_test<I: Iterator<Item = u32>>(iter: I, v: &[u32]) { |
1211 | | let mut rng = crate::test::rng(412); |
1212 | | let mut buf = [0u32; 8]; |
1213 | | assert_eq!(iter.choose_multiple_fill(&mut rng, &mut buf), v.len()); |
1214 | | assert_eq!(&buf[0..v.len()], v); |
1215 | | } |
1216 | | |
1217 | | do_test(0..4, &[0, 1, 2, 3]); |
1218 | | do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); |
1219 | | do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); |
1220 | | |
1221 | | #[cfg(feature = "alloc")] |
1222 | | { |
1223 | | fn do_test<I: Iterator<Item = u32>>(iter: I, v: &[u32]) { |
1224 | | let mut rng = crate::test::rng(412); |
1225 | | assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); |
1226 | | } |
1227 | | |
1228 | | do_test(0..4, &[0, 1, 2, 3]); |
1229 | | do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); |
1230 | | do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); |
1231 | | } |
1232 | | } |
1233 | | |
1234 | | #[test] |
1235 | | #[cfg(feature = "std")] |
1236 | | fn test_multiple_weighted_edge_cases() { |
1237 | | use super::*; |
1238 | | |
1239 | | let mut rng = crate::test::rng(413); |
1240 | | |
1241 | | // Case 1: One of the weights is 0 |
1242 | | let choices = [('a', 2), ('b', 1), ('c', 0)]; |
1243 | | for _ in 0..100 { |
1244 | | let result = choices |
1245 | | .choose_multiple_weighted(&mut rng, 2, |item| item.1) |
1246 | | .unwrap() |
1247 | | .collect::<Vec<_>>(); |
1248 | | |
1249 | | assert_eq!(result.len(), 2); |
1250 | | assert!(!result.iter().any(|val| val.0 == 'c')); |
1251 | | } |
1252 | | |
1253 | | // Case 2: All of the weights are 0 |
1254 | | let choices = [('a', 0), ('b', 0), ('c', 0)]; |
1255 | | |
1256 | | assert_eq!(choices |
1257 | | .choose_multiple_weighted(&mut rng, 2, |item| item.1) |
1258 | | .unwrap().count(), 2); |
1259 | | |
1260 | | // Case 3: Negative weights |
1261 | | let choices = [('a', -1), ('b', 1), ('c', 1)]; |
1262 | | assert_eq!( |
1263 | | choices |
1264 | | .choose_multiple_weighted(&mut rng, 2, |item| item.1) |
1265 | | .unwrap_err(), |
1266 | | WeightedError::InvalidWeight |
1267 | | ); |
1268 | | |
1269 | | // Case 4: Empty list |
1270 | | let choices = []; |
1271 | | assert_eq!(choices |
1272 | | .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) |
1273 | | .unwrap().count(), 0); |
1274 | | |
1275 | | // Case 5: NaN weights |
1276 | | let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)]; |
1277 | | assert_eq!( |
1278 | | choices |
1279 | | .choose_multiple_weighted(&mut rng, 2, |item| item.1) |
1280 | | .unwrap_err(), |
1281 | | WeightedError::InvalidWeight |
1282 | | ); |
1283 | | |
1284 | | // Case 6: +infinity weights |
1285 | | let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; |
1286 | | for _ in 0..100 { |
1287 | | let result = choices |
1288 | | .choose_multiple_weighted(&mut rng, 2, |item| item.1) |
1289 | | .unwrap() |
1290 | | .collect::<Vec<_>>(); |
1291 | | assert_eq!(result.len(), 2); |
1292 | | assert!(result.iter().any(|val| val.0 == 'a')); |
1293 | | } |
1294 | | |
1295 | | // Case 7: -infinity weights |
1296 | | let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; |
1297 | | assert_eq!( |
1298 | | choices |
1299 | | .choose_multiple_weighted(&mut rng, 2, |item| item.1) |
1300 | | .unwrap_err(), |
1301 | | WeightedError::InvalidWeight |
1302 | | ); |
1303 | | |
1304 | | // Case 8: -0 weights |
1305 | | let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; |
1306 | | assert!(choices |
1307 | | .choose_multiple_weighted(&mut rng, 2, |item| item.1) |
1308 | | .is_ok()); |
1309 | | } |
1310 | | |
1311 | | #[test] |
1312 | | #[cfg(feature = "std")] |
1313 | | fn test_multiple_weighted_distributions() { |
1314 | | use super::*; |
1315 | | |
1316 | | // The theoretical probabilities of the different outcomes are: |
1317 | | // AB: 0.5 * 0.5 = 0.250 |
1318 | | // AC: 0.5 * 0.5 = 0.250 |
1319 | | // BA: 0.25 * 0.67 = 0.167 |
1320 | | // BC: 0.25 * 0.33 = 0.082 |
1321 | | // CA: 0.25 * 0.67 = 0.167 |
1322 | | // CB: 0.25 * 0.33 = 0.082 |
1323 | | let choices = [('a', 2), ('b', 1), ('c', 1)]; |
1324 | | let mut rng = crate::test::rng(414); |
1325 | | |
1326 | | let mut results = [0i32; 3]; |
1327 | | let expected_results = [4167, 4167, 1666]; |
1328 | | for _ in 0..10000 { |
1329 | | let result = choices |
1330 | | .choose_multiple_weighted(&mut rng, 2, |item| item.1) |
1331 | | .unwrap() |
1332 | | .collect::<Vec<_>>(); |
1333 | | |
1334 | | assert_eq!(result.len(), 2); |
1335 | | |
1336 | | match (result[0].0, result[1].0) { |
1337 | | ('a', 'b') | ('b', 'a') => { |
1338 | | results[0] += 1; |
1339 | | } |
1340 | | ('a', 'c') | ('c', 'a') => { |
1341 | | results[1] += 1; |
1342 | | } |
1343 | | ('b', 'c') | ('c', 'b') => { |
1344 | | results[2] += 1; |
1345 | | } |
1346 | | (_, _) => panic!("unexpected result"), |
1347 | | } |
1348 | | } |
1349 | | |
1350 | | let mut diffs = results |
1351 | | .iter() |
1352 | | .zip(&expected_results) |
1353 | | .map(|(a, b)| (a - b).abs()); |
1354 | | assert!(!diffs.any(|deviation| deviation > 100)); |
1355 | | } |
1356 | | } |