Skip to main content

tor_proto/util/
sink_blocker.rs

1//! Implement [`SinkBlocker`], a wrapper type to allow policy-based blocking of
2//! a [futures::Sink].
3
4#![cfg_attr(not(feature = "circ-padding"), expect(dead_code))]
5
6mod boolean_policy;
7mod counting_policy;
8
9pub(crate) use boolean_policy::BooleanPolicy;
10pub(crate) use counting_policy::CountingPolicy;
11
12use std::{
13    pin::Pin,
14    task::{Context, Poll, Waker},
15};
16
17use futures::Sink;
18use pin_project::pin_project;
19use tor_error::Bug;
20
/// A wrapper for a [`futures::Sink`] that allows its blocking status to be
/// turned on and off according to a policy.
///
/// While the policy is blocking, `poll_ready` on this wrapper returns
/// [`Poll::Pending`] without consulting the inner sink, so attempts to
/// enqueue data on the sink via this `Sink` implementation will wait.
/// Later, when the policy is replaced with a nonblocking one via [`Self::update_policy()`],
/// this sink can be written to again.
///
/// Flushing and closing are always forwarded to the inner sink,
/// regardless of the policy.
///
/// If no policy type is given, [`BooleanPolicy`] is used.
#[pin_project]
pub(crate) struct SinkBlocker<S, P = BooleanPolicy> {
    /// The inner sink.
    #[pin]
    inner: S,
    /// A policy state object, deciding whether we are blocking or not.
    ///
    /// Invariant: Whenever `poll_ready` is called while the Policy is blocking,
    /// we store the context's waker in self.waker.
    /// If later the policy becomes non-blocking,
    /// we alert the `Waker`.
    policy: P,
    /// A waker that we should alert when `policy` transitions from
    /// a blocking to a non-blocking state.
    ///
    /// Stored by `poll_ready` while blocked; taken and woken by
    /// `update_policy` on a blocking-to-non-blocking transition.
    waker: Option<Waker>,
}
44
/// A policy that describes whether cells can be sent on a [`SinkBlocker`].
///
/// Each `Policy` object can be in different states:
/// some states cause the `SinkBlocker` to block traffic,
/// and some cause the `SinkBlocker` to permit traffic.
///
/// The user of a `SinkBlocker` is expected to call
/// [`update_policy()`](SinkBlocker::update_policy) from time to time,
/// when they need to make a manual change in the `SinkBlocker`'s status.
/// This is the only way for a blocked `SinkBlocker` to become unblocked.
///
/// Invariants:
///  - The state of a `Policy` object may transition from
///    non-blocking to blocking.
///  - The state of a `Policy` object may _not_ transition
///    from blocking to non-blocking.
///    (Unblocking happens only by replacing the whole object
///    via `update_policy()`.)
///  - If [`is_blocking()`](Policy::is_blocking) has returned false,
///    and no intervening changes have been made to the `Policy`,
///    [`take_one()`](Policy::take_one) will succeed.
///
/// Note that because of this last invariant,
/// interior mutability is strongly discouraged for implementations of this trait.
pub(crate) trait Policy {
    /// Returns true if this policy is currently blocking.
    ///
    /// Invariant: If this returns true on a given Policy,
    /// it must always return true on that Policy in the future.
    /// (That is, a Policy may become blocked,
    /// but may not become unblocked.)
    fn is_blocking(&self) -> bool;

    /// Modify this policy in response to having queued one item.
    ///
    /// Requires that `self.is_blocking()` has just returned false.
    /// Returns an error, and does not change `self`, if this _is_ blocked.
    /// (That is, you must only call this function on a non-blocked Policy.)
    ///
    /// `SinkBlocker` calls this after each successful `start_send` on its
    /// inner sink, and panics if it returns an error.
    //
    // Notes: The above rules mean that `take_one` can transition from
    // non-blocking to blocking, but never vice versa.
    fn take_one(&mut self) -> Result<(), Bug>;
}
86
87impl<S, P> SinkBlocker<S, P> {
88    /// Construct a new `SinkBlocker` wrapping a given sink, with a given
89    /// initial blocking policy.
90    pub(crate) fn new(inner: S, policy: P) -> Self {
91        SinkBlocker {
92            inner,
93            policy,
94            waker: None,
95        }
96    }
97
98    /// Return a reference to the inner `Sink` of this object.
99    ///
100    /// See warnings on `as_inner_mut`.
101    pub(crate) fn as_inner(&self) -> &S {
102        &self.inner
103    }
104
105    /// Return a mutable reference to the inner `Sink` of this object.
106    ///
107    /// Note that with this method, it is possible to bypass the blocking features
108    /// of [`SinkBlocker`].  This is an intentional escape hatch.
109    pub(crate) fn as_inner_mut(&mut self) -> &mut S {
110        &mut self.inner
111    }
112}
113
114impl<S, P: Policy> SinkBlocker<S, P> {
115    /// Replace the current [`Policy`] state object with `new_policy`.
116    ///
117    /// This method is used to make a blocked `SinkBlocker` unblocked,
118    /// or vice versa.
119    //
120    // Invariants: If we become unblocked, alerts our `Waker`.
121    //
122    // (This is the only method that can cause us to transition from blocked to
123    // unblocked, so this is the only place where we have to alert the waker.)
124    pub(crate) fn update_policy(&mut self, new_policy: P) {
125        let was_blocking = self.policy.is_blocking();
126        let is_blocking = new_policy.is_blocking();
127        self.policy = new_policy;
128        if was_blocking && !is_blocking {
129            if let Some(waker) = self.waker.take() {
130                waker.wake();
131            }
132        }
133    }
134}
135
136impl<T, S: Sink<T>, P: Policy> Sink<T> for SinkBlocker<S, P> {
137    type Error = S::Error;
138
139    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140        let self_ = self.project();
141        if self_.policy.is_blocking() {
142            // We're blocked.  We're going to store the context's Waker,
143            // so that we can invoke it later when the policy changes.
144            *self_.waker = Some(cx.waker().clone());
145            Poll::Pending
146        } else {
147            // If this returns Ready, great!
148            // If this returns Pending, it will wake up the context when it is
149            // no longer blocked.
150            self_.inner.poll_ready(cx)
151        }
152    }
153
154    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
155        let self_ = self.project();
156        // We're only allowed to call this method if poll_ready succeeded,
157        // so we know that is_blocking() was false.
158        let () = self_.inner.start_send(item)?;
159
160        // (Invoke take_one, to account for this item.)
161        //
162        // Note: Instead of calling expect, perhaps it would be better to have a custom error type
163        // that wraps S::Error and also allows for a Bug.  But that might be overkill, since
164        // we only expect this error to happen in the event of a bug.
165        let _: () = self_.policy.take_one().expect(
166            "take_one failed after is_blocking returned false: bug in Policy or SinkBlocker",
167        );
168        // (Take_one is not allowed to cause us to become unblocked, so we don't
169        // need to invoke the waiter.)
170
171        Ok(())
172    }
173
174    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175        // Note that we want to flush the inner sink,
176        // even if we are blocking attempts to send onto it.
177        self.project().inner.poll_flush(cx)
178    }
179
180    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
181        self.project().inner.poll_close(cx)
182    }
183}
184
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    };

    use super::*;

    use futures::{SinkExt as _, StreamExt as _, channel::mpsc, poll};
    use tor_rtmock::MockRuntime;

    /// Check that a `SinkBlocker` holds items back while blocked and
    /// releases them once unblocked.
    ///
    /// A "Transmitter" task does the following:
    /// - sends 1 and 2 while unblocked;
    /// - blocks the sink, then feeds 3 and 4 into a `Buffer` layered on top;
    /// - checks that a flush is pending while blocked;
    /// - unblocks, then flushes and closes.
    ///
    /// A "Receiver" task checks that items arrive in order, and that no item
    /// >= 3 is observed while the shared `blocked` flag is set.
    #[test]
    fn block_and_unblock() {
        // Try a few different schedulers, to make sure that our logic works for all of them.
        MockRuntime::test_with_various(|runtime| async move {
            // A small bounded channel is the inner sink being wrapped.
            let (tx, mut rx) = mpsc::channel::<u32>(1);
            let tx = SinkBlocker::new(tx, BooleanPolicy::Unblocked);
            // Buffer on top of the SinkBlocker, so that `feed` can accept
            // items even while the SinkBlocker itself is blocked.
            let mut tx = tx.buffer(5);

            // `blocked` mirrors the policy state, so the receiver can check it;
            // `n_received` counts the items the receiver has seen.
            let blocked = Arc::new(AtomicBool::new(false));
            let n_received = Arc::new(AtomicUsize::new(0));

            let blocked_clone = Arc::clone(&blocked);
            let n_received_clone = Arc::clone(&n_received);
            let n_received_clone2 = Arc::clone(&n_received);

            runtime.spawn_identified("Transmitter", async move {
                tx.send(1).await.unwrap();
                tx.send(2).await.unwrap();
                // Set the flag *before* blocking, so the receiver never sees
                // a blocked sink with the flag still clear.
                blocked.store(true, Ordering::SeqCst);
                tx.get_mut().set_blocked();
                // Have to use "feed" here since send would flush, which would block.
                tx.feed(3).await.unwrap();
                tx.feed(4).await.unwrap();
                // Items 3 and 4 are stuck in the buffer, so at most 1 and 2
                // can have been received.
                assert!(dbg!(n_received.load(Ordering::SeqCst)) <= 2);
                // Make sure that we _cannot_ flush right now.
                // (`poll!` polls the flush future exactly once.)
                let flush_future = tx.flush();
                assert!(poll!(flush_future).is_pending());
                // Now note that we're unblocked, and unblock.
                blocked.store(false, Ordering::SeqCst);
                tx.get_mut().set_unblocked();
                // This time we should actually flush.
                tx.flush().await.unwrap();
                // Closing the sink ends the receiver's `while let` loop.
                tx.close().await.unwrap();
            });

            runtime.spawn_identified("Receiver", async move {
                let n_received = n_received_clone;
                let blocked = blocked_clone;
                let mut expected = 1;
                while let Some(val) = rx.next().await {
                    assert_eq!(val, expected);
                    expected += 1;
                    n_received.fetch_add(1, Ordering::SeqCst);
                    // Items fed while blocked must only arrive after unblocking.
                    if val >= 3 {
                        assert_eq!(blocked.load(Ordering::SeqCst), false);
                    }
                }
                dbg!(expected);
            });

            // Run the spawned tasks until the mock runtime can make no more progress.
            runtime.progress_until_stalled().await;

            // All four items must have been delivered.
            assert_eq!(dbg!(n_received_clone2.load(Ordering::SeqCst)), 4);
        });
    }
}