tor_proto/util/sink_blocker.rs
//! Implement [`SinkBlocker`], a wrapper type to allow policy-based blocking of
//! a [`futures::Sink`].
3
4#![cfg_attr(not(feature = "circ-padding"), expect(dead_code))]
5
6mod boolean_policy;
7mod counting_policy;
8
9pub(crate) use boolean_policy::BooleanPolicy;
10pub(crate) use counting_policy::CountingPolicy;
11
12use std::{
13 pin::Pin,
14 task::{Context, Poll, Waker},
15};
16
17use futures::Sink;
18use pin_project::pin_project;
19use tor_error::Bug;
20
21/// A wrapper for a [`futures::Sink`] that allows its blocking status to be
22/// turned on and off according to a policy.
23///
24/// While the policy is blocking, attempts to enqueue data on the sink
25/// via this `Sink` trait will return [`Poll::Pending`].
26/// Later, when the policy is replaced with a nonblocking one via [`Self::update_policy()`]
27/// this sink can be written to again.
28#[pin_project]
29pub(crate) struct SinkBlocker<S, P = BooleanPolicy> {
30 /// The inner sink.
31 #[pin]
32 inner: S,
33 /// A policy state object, deciding whether we are blocking or not.
34 ///
35 /// Invariant: Whenever we try to send with a blocking Policy,
36 /// we store the context's waker in self.waker.
37 /// If later the policy becomes non-blocking,
38 /// we alert the `Waker`.
39 policy: P,
40 /// A waker that we should alert when `policy` transitions from
41 /// a blocking to a non-blocking state.
42 waker: Option<Waker>,
43}
44
45/// A policy that describes whether cells can be sent on a [`SinkBlocker`].
46///
47/// Each `Policy` object can be in different states:
48/// some states cause the `SinkBlocker` to block traffic,
49/// and some cause the `SinkBlocker` to permit traffic.
50///
51/// The user of a `SinkBlocker` is expected to call
52/// [`update_policy()`](SinkBlocker::update_policy) from time to time,
53/// when they need to make a manual change in the `SinkBlocker`'s status.
54/// This is the only way for a blocked `SinkBlocker` to become unblocked.
55///
56/// Invariants:
57/// - The state of a `Policy` object may transition from
58/// non-blocking to blocking.
59/// - The state of a `Policy` object may _not_ transition
60/// from blocking to non-blocking.
61/// - If [`is_blocking()`](Policy::is_blocking) has returned false,
62/// and no intervening changes have been made to the `Policy`,
63/// [`take_one()`](Policy::take_one) will succeed.
64///
65/// Note that because of this last invariant,
66/// interior mutability is strongly discouraged for implementations of this trait.
67pub(crate) trait Policy {
68 /// Returns true if this policy is currently blocking.
69 ///
70 /// Invariant: If this returns true on a given Policy,
71 /// it must always return true on that Policy in the future.
72 /// (That is, a Policy may become blocked,
73 /// but may not become unblocked.)
74 fn is_blocking(&self) -> bool;
75
76 /// Modify this policy in response to having queued one item.
77 ///
78 /// Requires that `self.is_blocking()` has just returned false.
79 /// Returns an error, and does not change `self`, if this _is_ blocked.
80 /// (That is, you must only call this function on a non-blocked Policy.)
81 //
82 // Notes: The above rules mean that `take_one` can transition from
83 // unblocking to blocking, but never vice versa.
84 fn take_one(&mut self) -> Result<(), Bug>;
85}
86
87impl<S, P> SinkBlocker<S, P> {
88 /// Construct a new `SinkBlocker` wrapping a given sink, with a given
89 /// initial blocking policy.
90 pub(crate) fn new(inner: S, policy: P) -> Self {
91 SinkBlocker {
92 inner,
93 policy,
94 waker: None,
95 }
96 }
97
98 /// Return a reference to the inner `Sink` of this object.
99 ///
100 /// See warnings on `as_inner_mut`.
101 pub(crate) fn as_inner(&self) -> &S {
102 &self.inner
103 }
104
105 /// Return a mutable reference to the inner `Sink` of this object.
106 ///
107 /// Note that with this method, it is possible to bypass the blocking features
108 /// of [`SinkBlocker`]. This is an intentional escape hatch.
109 pub(crate) fn as_inner_mut(&mut self) -> &mut S {
110 &mut self.inner
111 }
112}
113
114impl<S, P: Policy> SinkBlocker<S, P> {
115 /// Replace the current [`Policy`] state object with `new_policy`.
116 ///
117 /// This method is used to make a blocked `SinkBlocker` unblocked,
118 /// or vice versa.
119 //
120 // Invariants: If we become unblocked, alerts our `Waker`.
121 //
122 // (This is the only method that can cause us to transition from blocked to
123 // unblocked, so this is the only place where we have to alert the waker.)
124 pub(crate) fn update_policy(&mut self, new_policy: P) {
125 let was_blocking = self.policy.is_blocking();
126 let is_blocking = new_policy.is_blocking();
127 self.policy = new_policy;
128 if was_blocking && !is_blocking {
129 if let Some(waker) = self.waker.take() {
130 waker.wake();
131 }
132 }
133 }
134}
135
136impl<T, S: Sink<T>, P: Policy> Sink<T> for SinkBlocker<S, P> {
137 type Error = S::Error;
138
139 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140 let self_ = self.project();
141 if self_.policy.is_blocking() {
142 // We're blocked. We're going to store the context's Waker,
143 // so that we can invoke it later when the policy changes.
144 *self_.waker = Some(cx.waker().clone());
145 Poll::Pending
146 } else {
147 // If this returns Ready, great!
148 // If this returns Pending, it will wake up the context when it is
149 // no longer blocked.
150 self_.inner.poll_ready(cx)
151 }
152 }
153
154 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
155 let self_ = self.project();
156 // We're only allowed to call this method if poll_ready succeeded,
157 // so we know that is_blocking() was false.
158 let () = self_.inner.start_send(item)?;
159
160 // (Invoke take_one, to account for this item.)
161 //
162 // Note: Instead of calling expect, perhaps it would be better to have a custom error type
163 // that wraps S::Error and also allows for a Bug. But that might be overkill, since
164 // we only expect this error to happen in the event of a bug.
165 let _: () = self_.policy.take_one().expect(
166 "take_one failed after is_blocking returned false: bug in Policy or SinkBlocker",
167 );
168 // (Take_one is not allowed to cause us to become unblocked, so we don't
169 // need to invoke the waiter.)
170
171 Ok(())
172 }
173
174 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175 // Note that we want to flush the inner sink,
176 // even if we are blocking attempts to send onto it.
177 self.project().inner.poll_flush(cx)
178 }
179
180 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
181 self.project().inner.poll_close(cx)
182 }
183}
184
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    };

    use super::*;

    use futures::{SinkExt as _, StreamExt as _, channel::mpsc, poll};
    use tor_rtmock::MockRuntime;

    /// Send through a buffered `SinkBlocker`, block it, check that queued items
    /// are held back until it is unblocked, and that all four items arrive.
    #[test]
    fn block_and_unblock() {
        // Run under several schedulers, so that we know the logic holds for each of them.
        MockRuntime::test_with_various(|runtime| async move {
            let (raw_tx, mut rx) = mpsc::channel::<u32>(1);
            let mut tx = SinkBlocker::new(raw_tx, BooleanPolicy::Unblocked).buffer(5);

            // State shared between the two tasks and the final check.
            let blocked = Arc::new(AtomicBool::new(false));
            let n_received = Arc::new(AtomicUsize::new(0));
            let n_received_clone2 = Arc::clone(&n_received);

            // The receiving side gets its own handles on the shared state.
            let receiver = {
                let blocked = Arc::clone(&blocked);
                let n_received = Arc::clone(&n_received);
                async move {
                    let mut expected = 1;
                    loop {
                        let Some(val) = rx.next().await else {
                            break;
                        };
                        assert_eq!(val, expected);
                        expected += 1;
                        n_received.fetch_add(1, Ordering::SeqCst);
                        // Items 3 and 4 were queued while blocked, so they must
                        // only ever show up once the transmitter has unblocked.
                        if val >= 3 {
                            assert_eq!(blocked.load(Ordering::SeqCst), false);
                        }
                    }
                    dbg!(expected);
                }
            };

            let transmitter = async move {
                for item in [1, 2] {
                    tx.send(item).await.unwrap();
                }
                blocked.store(true, Ordering::SeqCst);
                tx.get_mut().set_blocked();
                // "feed" rather than "send": send would also flush, and flushing would block.
                for item in [3, 4] {
                    tx.feed(item).await.unwrap();
                }
                assert!(dbg!(n_received.load(Ordering::SeqCst)) <= 2);
                // A flush must _not_ be able to complete right now.
                let flush_future = tx.flush();
                assert!(poll!(flush_future).is_pending());
                // Record that we are unblocked, and then actually unblock.
                blocked.store(false, Ordering::SeqCst);
                tx.get_mut().set_unblocked();
                // Now the flush should really go through.
                tx.flush().await.unwrap();
                tx.close().await.unwrap();
            };

            runtime.spawn_identified("Transmitter", transmitter);
            runtime.spawn_identified("Receiver", receiver);

            runtime.progress_until_stalled().await;

            assert_eq!(dbg!(n_received_clone2.load(Ordering::SeqCst)), 4);
        });
    }
}