Skip to main content

tor_proto/util/
poll_all.rs

//! [`PollAll`]: a helper for driving multiple futures in lockstep.
2
3use futures::FutureExt as _;
4use smallvec::{SmallVec, smallvec};
5
6use std::future::Future;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
/// A heap-allocated, pinned, `Send` future yielding `T` and borrowing for at most `'a`.
///
/// This is the element type stored by a [`PollAll`].
type BoxedFut<'a, T> = Pin<Box<dyn Send + Future<Output = T> + 'a>>;
12
13/// Helper for driving multiple futures in lockstep.
14///
15/// When `.await`ed, a [`PollAll`] will unconditionally poll *all* of its
16/// underlying futures, in the order they were [`push`](PollAll::push)ed,
17/// until one or more of them resolves.
18/// Any remaining unresolved futures will be dropped.
19/// An empty `PollAll` will resolve immediately, yielding an empty list.
20///
21/// `PollAll` resolves to an *ordered* list of results, obtained from polling
22/// the futures in insertion order. Because some of the futures may not
23/// get a chance to resolve, the number of results will always
24/// be less than or equal to the number of inserted futures.
25///
26/// Because `PollAll` drives the futures in lockstep,
27/// if one future becomes ready, all of the futures will get polled,
28/// even if they didn't generate a wakeup notification.
29///
30/// ### Invariants
31///
32/// All of the futures inserted into this set **must** be cancellation safe.
33#[derive(Default)]
34pub(crate) struct PollAll<'a, const N: usize, T> {
35    /// The futures to drive in lockstep.
36    inner: SmallVec<[BoxedFut<'a, T>; N]>,
37}
38
39impl<'a, const N: usize, T> PollAll<'a, N, T> {
40    /// Create an empty [`PollAll`].
41    pub(crate) fn new() -> Self {
42        Self { inner: smallvec![] }
43    }
44
45    /// Add a future to this [`PollAll`].
46    pub(crate) fn push<S: Future<Output = T> + Send + 'a>(&mut self, item: S) {
47        self.inner.push(Box::pin(item));
48    }
49}
50
51impl<'a, const N: usize, T> Future for PollAll<'a, N, T> {
52    type Output = SmallVec<[T; N]>;
53
54    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
55        let mut results = smallvec![];
56
57        if self.inner.is_empty() {
58            // Nothing to do.
59            return Poll::Ready(results);
60        }
61
62        for fut in self.inner.iter_mut() {
63            match fut.poll_unpin(cx) {
64                Poll::Ready(res) => results.push(res),
65                Poll::Pending => continue,
66            }
67        }
68
69        if results.is_empty() {
70            return Poll::Pending;
71        }
72
73        Poll::Ready(results)
74    }
75}
76
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    use tor_rtmock::MockRuntime;

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Dummy smallvec capacity.
    const RES_COUNT: usize = 5;

    /// Future adapter that records how many times the wrapped future is polled.
    struct PollCounter<F> {
        /// Number of `poll()` calls so far; the caller holds a clone of it.
        count: Arc<AtomicUsize>,
        /// The wrapped future.
        inner: F,
    }

    impl<F> PollCounter<F> {
        /// Wrap `inner`, returning the adapter along with a handle to its poll count.
        fn new(inner: F) -> (Self, Arc<AtomicUsize>) {
            let handle = Arc::new(AtomicUsize::new(0));
            let wrapper = Self {
                count: Arc::clone(&handle),
                inner,
            };
            (wrapper, handle)
        }
    }

    impl<F: Future + Unpin> Future for PollCounter<F> {
        type Output = F::Output;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let _ = self.count.fetch_add(1, Ordering::Relaxed);
            self.inner.poll_unpin(cx)
        }
    }

    /// Future that resolves to `resolve_after` on its `resolve_after`-th
    /// `poll()` call. It asks to be woken immediately on every earlier poll,
    /// and panics if it is polled again after resolving.
    struct ResolveAfter {
        /// The poll() call on which this future resolves.
        resolve_after: usize,
        /// How many times poll() has been called so far.
        poll_count: usize,
    }

    impl ResolveAfter {
        fn new(resolve_after: usize) -> Self {
            Self {
                poll_count: 0,
                resolve_after,
            }
        }
    }

    impl Future for ResolveAfter {
        type Output = usize;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_count += 1;
            let (seen, target) = (self.poll_count, self.resolve_after);

            // TODO MSRV 1.87: Remove this allow.
            #[allow(
                clippy::comparison_chain,
                reason = "This is more readable than a match, and the lint is
                moved to clippy::pedantic in 1.87."
            )]
            if seen == target {
                Poll::Ready(target)
            } else if seen > target {
                panic!("future polled after completion?!");
            } else {
                // Request another poll straight away.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    #[test]
    fn poll_none() {
        MockRuntime::test_with_various(|_| async move {
            let empty = PollAll::<RES_COUNT, ()>::new();
            assert!(empty.await.is_empty());
        });
    }

    #[test]
    fn poll_multiple() {
        MockRuntime::test_with_various(|_| async move {
            let mut poll_all = PollAll::<RES_COUNT, usize>::new();

            // Never resolves, but is still polled on every round.
            let (never_fut, never_count) = PollCounter::new(futures::future::pending::<usize>());
            poll_all.push(never_fut);

            // The two 5s resolve on the 5th poll; the 8 and 9 would need more
            // polls than that, so they won't get a chance to resolve.
            let (futs, counters): (Vec<_>, Vec<_>) = [5, 5, 8, 9]
                .into_iter()
                .map(|polls| PollCounter::new(ResolveAfter::new(polls)))
                .unzip();

            futs.into_iter().for_each(|fut| poll_all.push(fut));

            let res = poll_all.await;
            assert_eq!(&res[..], &[5, 5]);

            // Every future, resolved or not, was polled exactly 5 times.
            assert_eq!(never_count.load(Ordering::Relaxed), 5);
            counters
                .iter()
                .for_each(|count| assert_eq!(count.load(Ordering::Relaxed), 5));
        });
    }
}