Skip to main content

tor_proto/stream/
queue.rs

1//! Queues for stream messages.
2//!
3//! While these are technically "channels", we call them "queues" to indicate that they're mostly
4//! just dumb pipes. They do some tracking (memquota and size), but nothing else. The higher-level
5//! object is [`StreamReceiver`](crate::stream::raw::StreamReceiver) which tracks SENDME and END
6//! messages. So the idea is that the "queue" (ex: [`StreamQueueReceiver`]) just holds data and the
7//! "channel" (ex: `StreamReceiver`) adds the Tor logic.
8//!
//! The main purpose of these types is to let us count how many bytes of stream data are
//! stored for the stream. Ideally we'd use a channel type that tracks and reports this as part of
//! its implementation, but popular channel implementations don't seem to do that.
12
13use std::fmt::Debug;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex};
16use std::task::{Context, Poll};
17
18use futures::{Sink, SinkExt, Stream};
19use tor_async_utils::SinkTrySend;
20use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
21use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
22use tor_cell::relaycell::UnparsedRelayMsg;
23use tor_memquota::mq_queue::{self, ChannelSpec, MpscSpec, MpscUnboundedSpec};
24use tor_rtcompat::DynTimeProvider;
25
26use crate::memquota::{SpecificAccount, StreamAccount};
27
28// TODO(arti#534): remove these type aliases when we remove the "flowctl-cc" feature,
29// and just use `MpscUnboundedSpec` everywhere
30#[cfg(feature = "flowctl-cc")]
31/// Alias for the memquota mpsc spec.
32type Spec = MpscUnboundedSpec;
33#[cfg(not(feature = "flowctl-cc"))]
34/// Alias for the memquota mpsc spec.
35type Spec = MpscSpec;
36
37/// Create a new stream queue for incoming messages.
38pub(crate) fn stream_queue(
39    #[cfg(not(feature = "flowctl-cc"))] size: usize,
40    memquota: &StreamAccount,
41    time_prov: &DynTimeProvider,
42) -> Result<(StreamQueueSender, StreamQueueReceiver), tor_memquota::Error> {
43    let (sender, receiver) = {
44        cfg_if::cfg_if! {
45            if #[cfg(not(feature = "flowctl-cc"))] {
46                MpscSpec::new(size).new_mq(time_prov.clone(), memquota.as_raw_account())?
47            } else {
48                MpscUnboundedSpec::new().new_mq(time_prov.clone(), memquota.as_raw_account())?
49            }
50        }
51    };
52
53    let receiver = StreamUnobtrusivePeeker::new(receiver);
54    let counter = Arc::new(Mutex::new(0));
55    Ok((
56        StreamQueueSender {
57            sender,
58            counter: Arc::clone(&counter),
59        },
60        StreamQueueReceiver { receiver, counter },
61    ))
62}
63
/// Test helper: build a stream queue charged to a no-op memquota account and driven by a
/// fresh mock time provider.
///
/// # Panics
///
/// Panics if the underlying queue cannot be created.
#[cfg(test)]
pub(crate) fn fake_stream_queue(
    #[cfg(not(feature = "flowctl-cc"))] size: usize,
) -> (StreamQueueSender, StreamQueueReceiver) {
    // The no-op account ignores data ages, so any time provider works here.
    //
    // Note that this mock clock is unrelated to any mocked time the calling test may be
    // using, so this shortcut is only acceptable because the no-op account never reads it.
    // It spares callers from threading a runtime through to this helper.
    let account = StreamAccount::new_noop();
    let time_prov = DynTimeProvider::new(tor_rtmock::MockRuntime::default());

    stream_queue(
        #[cfg(not(feature = "flowctl-cc"))]
        size,
        &account,
        &time_prov,
    )
    .expect("create fake stream queue")
}
83
84/// The sending end of a channel of incoming stream messages.
85#[derive(Debug)]
86#[pin_project::pin_project]
87pub(crate) struct StreamQueueSender {
88    /// The inner sender.
89    #[pin]
90    sender: mq_queue::Sender<UnparsedRelayMsg, Spec>,
91    /// Number of bytes within the queue.
92    counter: Arc<Mutex<usize>>,
93}
94
95/// The receiving end of a channel of incoming stream messages.
96#[derive(Debug)]
97#[pin_project::pin_project]
98pub(crate) struct StreamQueueReceiver {
99    /// The inner receiver.
100    ///
101    /// We add the [`StreamUnobtrusivePeeker`] here so that peeked messages are included in
102    /// `counter`.
103    // TODO(arti#534): the possible extra msg held by the `StreamUnobtrusivePeeker` isn't tracked by
104    // memquota
105    #[pin]
106    receiver: StreamUnobtrusivePeeker<mq_queue::Receiver<UnparsedRelayMsg, Spec>>,
107    /// Number of bytes within the queue.
108    counter: Arc<Mutex<usize>>,
109}
110
111impl StreamQueueSender {
112    /// Get the approximate number of data bytes queued for this stream.
113    ///
114    /// As messages can be dequeued at any time, the return value may be larger than the actual
115    /// number of bytes queued for this stream.
116    pub(crate) fn approx_stream_bytes(&self) -> usize {
117        *self.counter.lock().expect("poisoned")
118    }
119}
120
121impl StreamQueueReceiver {
122    /// Get the approximate number of data bytes queued for this stream.
123    ///
124    /// As messages can be enqueued at any time, the return value may be smaller than the actual
125    /// number of bytes queued for this stream.
126    pub(crate) fn approx_stream_bytes(&self) -> usize {
127        *self.counter.lock().expect("poisoned")
128    }
129}
130
131impl Sink<UnparsedRelayMsg> for StreamQueueSender {
132    type Error = <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as Sink<UnparsedRelayMsg>>::Error;
133
134    fn poll_ready(
135        mut self: Pin<&mut Self>,
136        cx: &mut Context<'_>,
137    ) -> Poll<std::result::Result<(), Self::Error>> {
138        self.sender.poll_ready_unpin(cx)
139    }
140
141    fn start_send(
142        mut self: Pin<&mut Self>,
143        item: UnparsedRelayMsg,
144    ) -> std::result::Result<(), Self::Error> {
145        let mut self_ = self.as_mut().project();
146
147        let stream_data_len = data_len(&item);
148
149        // This lock ensures that us sending the item and the counter increase are done
150        // "atomically", so that the receiver doesn't see the item and try to decrement the
151        // counter before we've incremented the counter, which could cause an underflow.
152        let mut counter = self_.counter.lock().expect("poisoned");
153
154        self_.sender.start_send_unpin(item)?;
155
156        *counter = counter
157            .checked_add(stream_data_len.into())
158            .expect("queue has more than `usize::MAX` bytes?!");
159
160        Ok(())
161    }
162
163    fn poll_flush(
164        mut self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166    ) -> Poll<std::result::Result<(), Self::Error>> {
167        self.sender.poll_flush_unpin(cx)
168    }
169
170    fn poll_close(
171        mut self: Pin<&mut Self>,
172        cx: &mut Context<'_>,
173    ) -> Poll<std::result::Result<(), Self::Error>> {
174        self.sender.poll_close_unpin(cx)
175    }
176}
177
178impl SinkTrySend<UnparsedRelayMsg> for StreamQueueSender {
179    type Error =
180        <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as SinkTrySend<UnparsedRelayMsg>>::Error;
181
182    fn try_send_or_return(
183        mut self: Pin<&mut Self>,
184        item: UnparsedRelayMsg,
185    ) -> Result<
186        (),
187        (
188            <Self as SinkTrySend<UnparsedRelayMsg>>::Error,
189            UnparsedRelayMsg,
190        ),
191    > {
192        let self_ = self.as_mut().project();
193
194        let stream_data_len = data_len(&item);
195
196        // See comments in `StreamQueueSender::start_send`.
197        let mut counter = self_.counter.lock().expect("poisoned");
198
199        self_.sender.try_send_or_return(item)?;
200
201        *counter = counter
202            .checked_add(stream_data_len.into())
203            .expect("queue has more than `usize::MAX` bytes?!");
204
205        Ok(())
206    }
207}
208
209impl Stream for StreamQueueReceiver {
210    type Item = UnparsedRelayMsg;
211
212    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213        let self_ = self.as_mut().project();
214
215        // This lock ensures that us receiving the item and the counter decrease are done
216        // "atomically", so that the sender doesn't send a new item and try to increase the
217        // counter before we've decreased the counter, which could cause an overflow.
218        let mut counter = self_.counter.lock().expect("poisoned");
219
220        let item = match self_.receiver.poll_next(cx) {
221            Poll::Ready(Some(x)) => x,
222            Poll::Ready(None) => return Poll::Ready(None),
223            Poll::Pending => return Poll::Pending,
224        };
225
226        let stream_data_len = data_len(&item);
227
228        if stream_data_len != 0 {
229            *counter = counter
230                .checked_sub(stream_data_len.into())
231                .expect("we've removed more bytes than we've added?!");
232        }
233
234        Poll::Ready(Some(item))
235    }
236}
237
238impl UnobtrusivePeekableStream for StreamQueueReceiver {
239    fn unobtrusive_peek_mut<'s>(
240        self: Pin<&'s mut Self>,
241    ) -> Option<&'s mut <Self as futures::Stream>::Item> {
242        self.project().receiver.unobtrusive_peek_mut()
243    }
244}
245
246/// The `length` field of the message, or 0 if not a data message.
247///
248/// If the RELAY_DATA message had an invalid length field, we just ignore the message.
249/// The receiver will find out eventually when it tries to parse the message.
250/// We could return an error here, but for now I think it's best not to behave as if this
251/// queue is performing any validation.
252///
253/// This is its own function so that all parts of the code use the same logic.
254fn data_len(item: &UnparsedRelayMsg) -> u16 {
255    item.data_len().unwrap_or(0)
256}