diff --git a/benches/combinations_with_replacement.rs b/benches/combinations_with_replacement.rs index 8e4fa3dc3..203532714 100644 --- a/benches/combinations_with_replacement.rs +++ b/benches/combinations_with_replacement.rs @@ -30,11 +30,42 @@ fn comb_replacement_n10_k10(c: &mut Criterion) { }) }); } +fn array_comb_replacement_n10_k5(c: &mut Criterion) { + c.bench_function("array comb replacement n10k5", move |b| { + b.iter(|| { + for i in (0..10).array_combinations_with_replacement::<5>() { + black_box(i); + } + }) + }); +} +fn array_comb_replacement_n5_k10(c: &mut Criterion) { + c.bench_function("array comb replacement n5 k10", move |b| { + b.iter(|| { + for i in (0..5).array_combinations_with_replacement::<10>() { + black_box(i); + } + }) + }); +} + +fn array_comb_replacement_n10_k10(c: &mut Criterion) { + c.bench_function("array comb replacement n10 k10", move |b| { + b.iter(|| { + for i in (0..10).array_combinations_with_replacement::<10>() { + black_box(i); + } + }) + }); +} criterion_group!( benches, comb_replacement_n10_k5, comb_replacement_n5_k10, comb_replacement_n10_k10, + array_comb_replacement_n10_k5, + array_comb_replacement_n5_k10, + array_comb_replacement_n10_k10, ); criterion_main!(benches); diff --git a/benches/specializations.rs b/benches/specializations.rs index e70323f8e..2b9eff6d1 100644 --- a/benches/specializations.rs +++ b/benches/specializations.rs @@ -441,6 +441,30 @@ bench_specializations! { } v.iter().combinations_with_replacement(4) } + array_combinations_with_replacement1 { + { + let v = black_box(vec![0; 4096]); + } + v.iter().array_combinations_with_replacement::<1>() + } + array_combinations_with_replacement2 { + { + let v = black_box(vec![0; 90]); + } + v.iter().array_combinations_with_replacement::<2>() + } + array_combinations_with_replacement3 { + { + let v = black_box(vec![0; 28]); + } + v.iter().array_combinations_with_replacement::<3>() + } + array_combinations_with_replacement4 { + { + let v = black_box(vec![0; 16]); + } + v.iter().array_combinations_with_replacement::<4>() + } permutations1 { { let v = black_box(vec![0; 1024]); diff --git a/src/combinations.rs b/src/combinations.rs index 54a027551..7bf9dfeab 100644 --- a/src/combinations.rs +++ b/src/combinations.rs @@ -1,3 +1,4 @@ +use alloc::boxed::Box; use core::array; use core::borrow::BorrowMut; use std::fmt; @@ -52,7 +53,16 @@ pub trait PoolIndex: BorrowMut<[usize]> { self.borrow().len() } } +impl PoolIndex for Box<[usize]> { + type Item = Vec; + fn extract_item>(&self, pool: &LazyBuffer) -> Vec + where + T: Clone, + { + pool.get_at(self) + } +} impl PoolIndex for Vec { type Item = Vec; diff --git a/src/combinations_with_replacement.rs b/src/combinations_with_replacement.rs index c17e75250..7142681fe 100644 --- a/src/combinations_with_replacement.rs +++ b/src/combinations_with_replacement.rs @@ -1,35 +1,50 @@ use alloc::boxed::Box; -use alloc::vec::Vec; use std::fmt; use std::iter::FusedIterator; use super::lazy_buffer::LazyBuffer; use crate::adaptors::checked_binomial; - +use crate::combinations::PoolIndex; /// An iterator to iterate through all the `n`-length combinations in an iterator, with replacement. /// /// See [`.combinations_with_replacement()`](crate::Itertools::combinations_with_replacement) /// for more information. #[derive(Clone)] #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] -pub struct CombinationsWithReplacement +pub struct CombinationsWithReplacementGeneric where I: Iterator, I::Item: Clone, { - indices: Box<[usize]>, + indices: Idx, pool: LazyBuffer, first: bool, } -impl fmt::Debug for CombinationsWithReplacement +/// Iterator for `Box<[I]>` valued combinations_with_replacement returned by [`.combinations_with_replacement()`](crate::Itertools::combinations_with_replacement) +pub type CombinationsWithReplacement = CombinationsWithReplacementGeneric>; +/// Iterator for const generic combinations_with_replacement returned by [`.array_combinations_with_replacement()`](crate::Itertools::array_combinations_with_replacement) +pub type ArrayCombinationsWithReplacement = + CombinationsWithReplacementGeneric; + +impl fmt::Debug for CombinationsWithReplacementGeneric where I: Iterator + fmt::Debug, I::Item: fmt::Debug + Clone, + Idx: fmt::Debug, { - debug_fmt_fields!(CombinationsWithReplacement, indices, pool, first); + debug_fmt_fields!(CombinationsWithReplacementGeneric, indices, pool, first); } +/// Create a new `ArrayCombinationsWithReplacement`` from a clonable iterator. +pub fn array_combinations_with_replacement( + iter: I, +) -> ArrayCombinationsWithReplacement +where + I::Item: Clone, +{ + ArrayCombinationsWithReplacement::new(iter, [0; K]) +} /// Create a new `CombinationsWithReplacement` from a clonable iterator. pub fn combinations_with_replacement(iter: I, k: usize) -> CombinationsWithReplacement where @@ -37,16 +52,11 @@ where I::Item: Clone, { let indices = alloc::vec![0; k].into_boxed_slice(); - let pool: LazyBuffer = LazyBuffer::new(iter); - CombinationsWithReplacement { - indices, - pool, - first: true, - } + CombinationsWithReplacementGeneric::new(iter, indices) } -impl CombinationsWithReplacement +impl> CombinationsWithReplacementGeneric where I: Iterator, I::Item: Clone, @@ -62,7 +72,8 @@ where // Work out where we need to update our indices let mut increment = None; - for (i, indices_int) in self.indices.iter().enumerate().rev() { + let indices: &mut [usize] = self.indices.borrow_mut(); + for (i, indices_int) in indices.iter().enumerate().rev() { if *indices_int < self.pool.len() - 1 { increment = Some((i, indices_int + 1)); break; @@ -73,39 +84,48 @@ where Some((increment_from, increment_value)) => { // We need to update the rightmost non-max value // and all those to the right - self.indices[increment_from..].fill(increment_value); + indices[increment_from..].fill(increment_value); false } // Otherwise, we're done None => true, } } + /// Constructor with arguments the inner iterator and the initial state for the indices. + fn new(iter: I, indices: Idx) -> Self { + Self { + indices, + pool: LazyBuffer::new(iter), + first: true, + } + } } -impl Iterator for CombinationsWithReplacement +impl Iterator for CombinationsWithReplacementGeneric where I: Iterator, I::Item: Clone, + Idx: PoolIndex, { - type Item = Vec; + type Item = Idx::Item; fn next(&mut self) -> Option { if self.first { // In empty edge cases, stop iterating immediately - if !(self.indices.is_empty() || self.pool.get_next()) { + if !(self.indices.borrow().is_empty() || self.pool.get_next()) { return None; } self.first = false; } else if self.increment_indices() { return None; } - Some(self.pool.get_at(&self.indices)) + Some(self.indices.extract_item(&self.pool)) } fn nth(&mut self, n: usize) -> Option { if self.first { // In empty edge cases, stop iterating immediately - if !(self.indices.is_empty() || self.pool.get_next()) { + if !(self.indices.borrow().is_empty() || self.pool.get_next()) { return None; } self.first = false; @@ -117,13 +137,13 @@ where return None; } } - Some(self.pool.get_at(&self.indices)) + Some(self.indices.extract_item(&self.pool)) } fn size_hint(&self) -> (usize, Option) { let (mut low, mut upp) = self.pool.size_hint(); - low = remaining_for(low, self.first, &self.indices).unwrap_or(usize::MAX); - upp = upp.and_then(|upp| remaining_for(upp, self.first, &self.indices)); + low = remaining_for(low, self.first, self.indices.borrow()).unwrap_or(usize::MAX); + upp = upp.and_then(|upp| remaining_for(upp, self.first, self.indices.borrow())); (low, upp) } @@ -134,14 +154,15 @@ where first, } = self; let n = pool.count(); - remaining_for(n, first, &indices).unwrap() + remaining_for(n, first, indices.borrow()).unwrap() } } -impl FusedIterator for CombinationsWithReplacement +impl FusedIterator for CombinationsWithReplacementGeneric where I: Iterator, I::Item: Clone, + Idx: PoolIndex, { } diff --git a/src/lib.rs b/src/lib.rs index 9f8d1cd1a..7c541ae3e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -153,6 +153,8 @@ pub mod traits { pub use crate::tuple_impl::HomogeneousTuple; } +#[cfg(feature = "use_alloc")] +use crate::combinations_with_replacement::ArrayCombinationsWithReplacement; pub use crate::concat_impl::concat; pub use crate::cons_tuples_impl::cons_tuples; pub use crate::diff::diff_with; @@ -1804,7 +1806,35 @@ pub trait Itertools: Iterator { { combinations_with_replacement::combinations_with_replacement(self, k) } - + /// Return an iterator that iterates over the `k`-length combinations of + /// the elements from an iterator, with replacement. + /// + /// Iterator element type is [Self::Item; K]. The iterator produces a new + /// array per iteration, and clones the iterator elements. + /// + /// ``` + /// use itertools::Itertools; + /// + /// let it = (1..4).array_combinations_with_replacement::<2>(); + /// itertools::assert_equal(it, vec![ + /// [1, 1], + /// [1, 2], + /// [1, 3], + /// [2, 2], + /// [2, 3], + /// [3, 3], + /// ]); + /// ``` + #[cfg(feature = "use_alloc")] + fn array_combinations_with_replacement( + self, + ) -> ArrayCombinationsWithReplacement + where + Self: Sized, + Self::Item: Clone, + { + combinations_with_replacement::array_combinations_with_replacement(self) + } /// Return an iterator adaptor that iterates over all k-permutations of the /// elements from an iterator. /// diff --git a/tests/adaptors_no_collect.rs b/tests/adaptors_no_collect.rs index 977224af2..28b20a441 100644 --- a/tests/adaptors_no_collect.rs +++ b/tests/adaptors_no_collect.rs @@ -49,3 +49,7 @@ fn combinations_no_collect() { fn combinations_with_replacement_no_collect() { no_collect_test(|iter| iter.combinations_with_replacement(5)) } +#[test] +fn array_combinations_with_replacement_no_collect() { + no_collect_test(|iter| iter.array_combinations_with_replacement::<5>()) +} diff --git a/tests/laziness.rs b/tests/laziness.rs index c559d33ad..dfeee68f8 100644 --- a/tests/laziness.rs +++ b/tests/laziness.rs @@ -217,6 +217,11 @@ must_use_tests! { let _ = Panicking.combinations_with_replacement(1); let _ = Panicking.combinations_with_replacement(2); } + array_combinations_with_replacement { + let _ = Panicking.array_combinations_with_replacement::<0>(); + let _ = Panicking.array_combinations_with_replacement::<1>(); + let _ = Panicking.array_combinations_with_replacement::<2>(); + } permutations { let _ = Panicking.permutations(0); let _ = Panicking.permutations(1); diff --git a/tests/quick.rs b/tests/quick.rs index e0632fa47..0af73778a 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -1838,6 +1838,11 @@ quickcheck! { is_fused(a.combinations_with_replacement(3)) } + fn fused_array_combination_with_replacement(a: Iter) -> bool + { + is_fused(a.clone().array_combinations_with_replacement::<1>()) && + is_fused(a.array_combinations_with_replacement::<3>()) + } fn fused_tuple_combination(a: Iter) -> bool { is_fused(a.clone().fuse().tuple_combinations::<(_,)>()) && diff --git a/tests/specializations.rs b/tests/specializations.rs index 44e3cedec..26d1f5367 100644 --- a/tests/specializations.rs +++ b/tests/specializations.rs @@ -299,6 +299,16 @@ quickcheck! { TestResult::passed() } + fn array_combinations_with_replacement(a: Vec) -> TestResult { + if a.len() > 10 { + return TestResult::discard(); + } + test_specializations(&a.iter().array_combinations_with_replacement::<1>()); + test_specializations(&a.iter().array_combinations_with_replacement::<2>()); + test_specializations(&a.iter().array_combinations_with_replacement::<3>()); + + TestResult::passed() + } fn permutations(a: Vec, n: u8) -> TestResult { if n > 3 || a.len() > 8 { return TestResult::discard(); diff --git a/tests/test_std.rs b/tests/test_std.rs index ad391faad..c0ee373aa 100644 --- a/tests/test_std.rs +++ b/tests/test_std.rs @@ -1256,6 +1256,30 @@ fn combinations_with_replacement_range_count() { } } +#[test] +#[cfg(not(miri))] +fn array_combinations_with_replacement() { + // Pool smaller than n + it::assert_equal( + (0..1).array_combinations_with_replacement::<2>(), + vec![[0, 0]], + ); + // Pool larger than n + it::assert_equal( + (0..3).array_combinations_with_replacement::<2>(), + vec![[0, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 2]], + ); + // Zero size + it::assert_equal((0..3).array_combinations_with_replacement::<0>(), vec![[]]); + // Zero size on empty pool + it::assert_equal((0..0).array_combinations_with_replacement::<0>(), vec![[]]); + // Empty pool + it::assert_equal( + (0..0).array_combinations_with_replacement::<2>(), + vec![] as Vec<[_; 2]>, + ); +} + #[test] fn powerset() { it::assert_equal((0..0).powerset(), vec![vec![]]);