tor_proto/util/
poll_all.rs1use futures::FutureExt as _;
4use smallvec::{SmallVec, smallvec};
5
6use std::future::Future;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10type BoxedFut<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
12
13#[derive(Default)]
34pub(crate) struct PollAll<'a, const N: usize, T> {
35 inner: SmallVec<[BoxedFut<'a, T>; N]>,
37}
38
39impl<'a, const N: usize, T> PollAll<'a, N, T> {
40 pub(crate) fn new() -> Self {
42 Self { inner: smallvec![] }
43 }
44
45 pub(crate) fn push<S: Future<Output = T> + Send + 'a>(&mut self, item: S) {
47 self.inner.push(Box::pin(item));
48 }
49}
50
51impl<'a, const N: usize, T> Future for PollAll<'a, N, T> {
52 type Output = SmallVec<[T; N]>;
53
54 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
55 let mut results = smallvec![];
56
57 if self.inner.is_empty() {
58 return Poll::Ready(results);
60 }
61
62 for fut in self.inner.iter_mut() {
63 match fut.poll_unpin(cx) {
64 Poll::Ready(res) => results.push(res),
65 Poll::Pending => continue,
66 }
67 }
68
69 if results.is_empty() {
70 return Poll::Pending;
71 }
72
73 Poll::Ready(results)
74 }
75}
76
// Unit tests for `PollAll`, driven under the mock runtime.
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    use tor_rtmock::MockRuntime;

    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Inline capacity (`N`) used for the `PollAll` instances under test.
    const RES_COUNT: usize = 5;

    /// Wrapper future that counts how many times it has been polled.
    struct PollCounter<F> {
        /// Shared counter, incremented on every call to `poll`.
        count: Arc<AtomicUsize>,
        /// The wrapped future, to which polling is delegated.
        inner: F,
    }

    /// Future that is ready on its `resolve_after`-th poll, yielding
    /// `resolve_after`, and panics if polled again after that.
    struct ResolveAfter {
        /// The (1-based) poll on which this future resolves.
        resolve_after: usize,
        /// Number of times this future has been polled so far.
        poll_count: usize,
    }

    impl ResolveAfter {
        /// Creates a future that resolves on poll number `resolve_after`.
        fn new(resolve_after: usize) -> Self {
            Self {
                resolve_after,
                poll_count: 0,
            }
        }
    }

    impl<F> PollCounter<F> {
        /// Wraps `inner`, returning the wrapper and a handle to its poll counter.
        fn new(inner: F) -> (Self, Arc<AtomicUsize>) {
            let count = Arc::new(AtomicUsize::new(0));
            let poll_counter = Self {
                count: Arc::clone(&count),
                inner,
            };

            (poll_counter, count)
        }
    }

    impl<F: Future + Unpin> Future for PollCounter<F> {
        type Output = F::Output;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Count the poll before delegating, so that a poll which returns
            // `Ready` is counted too.
            let _ = self.count.fetch_add(1, Ordering::Relaxed);
            self.inner.poll_unpin(cx)
        }
    }

    impl Future for ResolveAfter {
        type Output = usize;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_count += 1;

            #[allow(
                clippy::comparison_chain,
                reason = "This is more readable than a match, and the lint is
                 moved to clippy::pedantic in 1.87."
            )]
            if self.poll_count == self.resolve_after {
                Poll::Ready(self.resolve_after)
            } else if self.poll_count > self.resolve_after {
                panic!("future polled after completion?!");
            } else {
                // Wake immediately so the executor polls the task again;
                // otherwise nothing would ever reschedule it.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    // An empty `PollAll` must resolve (with no results) rather than hang.
    #[test]
    fn poll_none() {
        MockRuntime::test_with_various(|_| async move {
            assert!(PollAll::<RES_COUNT, ()>::new().await.is_empty());
        });
    }

    // `PollAll` resolves on the first poll where any inner future is ready,
    // returning every output that was ready on that poll.
    #[test]
    fn poll_multiple() {
        MockRuntime::test_with_various(|_| async move {
            let mut poll_all = PollAll::<RES_COUNT, usize>::new();

            // A future that never resolves (and never wakes the task).
            let (never_fut, never_count) = PollCounter::new(futures::future::pending::<usize>());
            poll_all.push(never_fut);

            let (futures, counters): (Vec<_>, Vec<_>) = [
                PollCounter::new(ResolveAfter::new(5)),
                PollCounter::new(ResolveAfter::new(5)),
                PollCounter::new(ResolveAfter::new(8)),
                PollCounter::new(ResolveAfter::new(9)),
            ]
            .into_iter()
            .unzip();

            for fut in futures {
                poll_all.push(fut);
            }

            // Both `ResolveAfter::new(5)` futures become ready on the 5th poll.
            // The later ones are still pending at that point.
            let res = poll_all.await;
            assert_eq!(&res[..], &[5, 5]);

            // Every inner future is polled once per `PollAll` poll, so all of
            // them (including the never-ready one) were polled exactly 5 times.
            assert_eq!(never_count.load(Ordering::Relaxed), 5);
            for counter in counters {
                assert_eq!(counter.load(Ordering::Relaxed), 5);
            }
        });
    }
}