[go: up one dir, main page]

core/iter/adapters/
filter.rs

1use core::array;
2use core::mem::MaybeUninit;
3use core::ops::ControlFlow;
4
5use crate::fmt;
6use crate::iter::adapters::SourceIter;
7use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen};
8use crate::num::NonZero;
9use crate::ops::Try;
10
11/// An iterator that filters the elements of `iter` with `predicate`.
12///
13/// This `struct` is created by the [`filter`] method on [`Iterator`]. See its
14/// documentation for more.
15///
16/// [`filter`]: Iterator::filter
17/// [`Iterator`]: trait.Iterator.html
18#[must_use = "iterators are lazy and do nothing unless consumed"]
19#[stable(feature = "rust1", since = "1.0.0")]
20#[derive(Clone)]
21pub struct Filter<I, P> {
22    // Used for `SplitWhitespace` and `SplitAsciiWhitespace` `as_str` methods
23    pub(crate) iter: I,
24    predicate: P,
25}
26impl<I, P> Filter<I, P> {
27    pub(in crate::iter) fn new(iter: I, predicate: P) -> Filter<I, P> {
28        Filter { iter, predicate }
29    }
30}
31
32impl<I, P> Filter<I, P>
33where
34    I: Iterator,
35    P: FnMut(&I::Item) -> bool,
36{
37    #[inline]
38    fn next_chunk_dropless<const N: usize>(
39        &mut self,
40    ) -> Result<[I::Item; N], array::IntoIter<I::Item, N>> {
41        let mut array: [MaybeUninit<I::Item>; N] = [const { MaybeUninit::uninit() }; N];
42        let mut initialized = 0;
43
44        let result = self.iter.try_for_each(|element| {
45            let idx = initialized;
46            // branchless index update combined with unconditionally copying the value even when
47            // it is filtered reduces branching and dependencies in the loop.
48            initialized = idx + (self.predicate)(&element) as usize;
49            // SAFETY: Loop conditions ensure the index is in bounds.
50            unsafe { array.get_unchecked_mut(idx) }.write(element);
51
52            if initialized < N { ControlFlow::Continue(()) } else { ControlFlow::Break(()) }
53        });
54
55        match result {
56            ControlFlow::Break(()) => {
57                // SAFETY: The loop above is only explicitly broken when the array has been fully initialized
58                Ok(unsafe { MaybeUninit::array_assume_init(array) })
59            }
60            ControlFlow::Continue(()) => {
61                // SAFETY: The range is in bounds since the loop breaks when reaching N elements.
62                Err(unsafe { array::IntoIter::new_unchecked(array, 0..initialized) })
63            }
64        }
65    }
66}
67
68#[stable(feature = "core_impl_debug", since = "1.9.0")]
69impl<I: fmt::Debug, P> fmt::Debug for Filter<I, P> {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("Filter").field("iter", &self.iter).finish()
72    }
73}
74
/// Wraps a fold step so that only items accepted by `predicate` reach `fold`.
///
/// Items for which `predicate` returns `false` leave the accumulator unchanged.
fn filter_fold<T, Acc>(
    mut predicate: impl FnMut(&T) -> bool,
    mut fold: impl FnMut(Acc, T) -> Acc,
) -> impl FnMut(Acc, T) -> Acc {
    move |acc, item| {
        if !predicate(&item) {
            return acc;
        }
        fold(acc, item)
    }
}
81
82fn filter_try_fold<'a, T, Acc, R: Try<Output = Acc>>(
83    predicate: &'a mut impl FnMut(&T) -> bool,
84    mut fold: impl FnMut(Acc, T) -> R + 'a,
85) -> impl FnMut(Acc, T) -> R + 'a {
86    move |acc, item| if predicate(&item) { fold(acc, item) } else { try { acc } }
87}
88
89#[stable(feature = "rust1", since = "1.0.0")]
90impl<I: Iterator, P> Iterator for Filter<I, P>
91where
92    P: FnMut(&I::Item) -> bool,
93{
94    type Item = I::Item;
95
96    #[inline]
97    fn next(&mut self) -> Option<I::Item> {
98        self.iter.find(&mut self.predicate)
99    }
100
101    #[inline]
102    fn next_chunk<const N: usize>(
103        &mut self,
104    ) -> Result<[Self::Item; N], array::IntoIter<Self::Item, N>> {
105        // avoid codegen for the dead branch
106        let fun = const {
107            if crate::mem::needs_drop::<I::Item>() {
108                array::iter_next_chunk::<I::Item, N>
109            } else {
110                Self::next_chunk_dropless::<N>
111            }
112        };
113
114        fun(self)
115    }
116
117    #[inline]
118    fn size_hint(&self) -> (usize, Option<usize>) {
119        let (_, upper) = self.iter.size_hint();
120        (0, upper) // can't know a lower bound, due to the predicate
121    }
122
123    // this special case allows the compiler to make `.filter(_).count()`
124    // branchless. Barring perfect branch prediction (which is unattainable in
125    // the general case), this will be much faster in >90% of cases (containing
126    // virtually all real workloads) and only a tiny bit slower in the rest.
127    //
128    // Having this specialization thus allows us to write `.filter(p).count()`
129    // where we would otherwise write `.map(|x| p(x) as usize).sum()`, which is
130    // less readable and also less backwards-compatible to Rust before 1.10.
131    //
132    // Using the branchless version will also simplify the LLVM byte code, thus
133    // leaving more budget for LLVM optimizations.
134    #[inline]
135    fn count(self) -> usize {
136        #[inline]
137        fn to_usize<T>(mut predicate: impl FnMut(&T) -> bool) -> impl FnMut(T) -> usize {
138            move |x| predicate(&x) as usize
139        }
140
141        let before = self.iter.size_hint().1.unwrap_or(usize::MAX);
142        let total = self.iter.map(to_usize(self.predicate)).sum();
143        // SAFETY: `total` and `before` came from the same iterator of type `I`
144        unsafe {
145            <I as SpecAssumeCount>::assume_count_le_upper_bound(total, before);
146        }
147        total
148    }
149
150    #[inline]
151    fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
152    where
153        Self: Sized,
154        Fold: FnMut(Acc, Self::Item) -> R,
155        R: Try<Output = Acc>,
156    {
157        self.iter.try_fold(init, filter_try_fold(&mut self.predicate, fold))
158    }
159
160    #[inline]
161    fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
162    where
163        Fold: FnMut(Acc, Self::Item) -> Acc,
164    {
165        self.iter.fold(init, filter_fold(self.predicate, fold))
166    }
167}
168
169#[stable(feature = "rust1", since = "1.0.0")]
170impl<I: DoubleEndedIterator, P> DoubleEndedIterator for Filter<I, P>
171where
172    P: FnMut(&I::Item) -> bool,
173{
174    #[inline]
175    fn next_back(&mut self) -> Option<I::Item> {
176        self.iter.rfind(&mut self.predicate)
177    }
178
179    #[inline]
180    fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
181    where
182        Self: Sized,
183        Fold: FnMut(Acc, Self::Item) -> R,
184        R: Try<Output = Acc>,
185    {
186        self.iter.try_rfold(init, filter_try_fold(&mut self.predicate, fold))
187    }
188
189    #[inline]
190    fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
191    where
192        Fold: FnMut(Acc, Self::Item) -> Acc,
193    {
194        self.iter.rfold(init, filter_fold(self.predicate, fold))
195    }
196}
197
198#[stable(feature = "fused", since = "1.26.0")]
199impl<I: FusedIterator, P> FusedIterator for Filter<I, P> where P: FnMut(&I::Item) -> bool {}
200
201#[unstable(issue = "none", feature = "trusted_fused")]
202unsafe impl<I: TrustedFused, F> TrustedFused for Filter<I, F> {}
203
204#[unstable(issue = "none", feature = "inplace_iteration")]
205unsafe impl<P, I> SourceIter for Filter<I, P>
206where
207    I: SourceIter,
208{
209    type Source = I::Source;
210
211    #[inline]
212    unsafe fn as_inner(&mut self) -> &mut I::Source {
213        // SAFETY: unsafe function forwarding to unsafe function with the same requirements
214        unsafe { SourceIter::as_inner(&mut self.iter) }
215    }
216}
217
218#[unstable(issue = "none", feature = "inplace_iteration")]
219unsafe impl<I: InPlaceIterable, P> InPlaceIterable for Filter<I, P> {
220    const EXPAND_BY: Option<NonZero<usize>> = I::EXPAND_BY;
221    const MERGE_BY: Option<NonZero<usize>> = I::MERGE_BY;
222}
223
/// Lets `Filter::count` pass the relation `count <= upper` on to implementations, which
/// decide whether the iterator type's `size_hint` can be trusted enough to assume it.
trait SpecAssumeCount {
    /// # Safety
    ///
    /// `count` must be a number of items actually read from the iterator.
    ///
    /// `upper` must either:
    /// - have come from `size_hint().1` on the iterator, or
    /// - be `usize::MAX`, which will vacuously do nothing.
    unsafe fn assume_count_le_upper_bound(count: usize, upper: usize);
}
234
235impl<I: Iterator> SpecAssumeCount for I {
236    #[inline]
237    #[rustc_inherit_overflow_checks]
238    default unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
239        // In the default we can't trust the `upper` for soundness
240        // because it came from an untrusted `size_hint`.
241
242        // In debug mode we might as well check that the size_hint wasn't too small
243        let _ = upper - count;
244    }
245}
246
247impl<I: TrustedLen> SpecAssumeCount for I {
248    #[inline]
249    unsafe fn assume_count_le_upper_bound(count: usize, upper: usize) {
250        // SAFETY: The `upper` is trusted because it came from a `TrustedLen` iterator.
251        unsafe { crate::hint::assert_unchecked(count <= upper) }
252    }
253}