tor_proto/stream/
queue.rs1use std::fmt::Debug;
14use std::pin::Pin;
15use std::sync::{Arc, Mutex};
16use std::task::{Context, Poll};
17
18use futures::{Sink, SinkExt, Stream};
19use tor_async_utils::SinkTrySend;
20use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
21use tor_async_utils::stream_peek::StreamUnobtrusivePeeker;
22use tor_cell::relaycell::UnparsedRelayMsg;
23use tor_memquota::mq_queue::{self, ChannelSpec, MpscSpec, MpscUnboundedSpec};
24use tor_rtcompat::DynTimeProvider;
25
26use crate::memquota::{SpecificAccount, StreamAccount};
27
28#[cfg(feature = "flowctl-cc")]
31type Spec = MpscUnboundedSpec;
33#[cfg(not(feature = "flowctl-cc"))]
34type Spec = MpscSpec;
36
37pub(crate) fn stream_queue(
39 #[cfg(not(feature = "flowctl-cc"))] size: usize,
40 memquota: &StreamAccount,
41 time_prov: &DynTimeProvider,
42) -> Result<(StreamQueueSender, StreamQueueReceiver), tor_memquota::Error> {
43 let (sender, receiver) = {
44 cfg_if::cfg_if! {
45 if #[cfg(not(feature = "flowctl-cc"))] {
46 MpscSpec::new(size).new_mq(time_prov.clone(), memquota.as_raw_account())?
47 } else {
48 MpscUnboundedSpec::new().new_mq(time_prov.clone(), memquota.as_raw_account())?
49 }
50 }
51 };
52
53 let receiver = StreamUnobtrusivePeeker::new(receiver);
54 let counter = Arc::new(Mutex::new(0));
55 Ok((
56 StreamQueueSender {
57 sender,
58 counter: Arc::clone(&counter),
59 },
60 StreamQueueReceiver { receiver, counter },
61 ))
62}
63
/// Test-only helper: build a stream queue backed by a no-op memory account and
/// a mock runtime.
///
/// # Panics
///
/// Panics if `stream_queue` returns an error.
#[cfg(test)]
pub(crate) fn fake_stream_queue(
    #[cfg(not(feature = "flowctl-cc"))] size: usize,
) -> (StreamQueueSender, StreamQueueReceiver) {
    let noop_account = StreamAccount::new_noop();
    let mock_time = DynTimeProvider::new(tor_rtmock::MockRuntime::default());

    let created = stream_queue(
        #[cfg(not(feature = "flowctl-cc"))]
        size,
        &noop_account,
        &mock_time,
    );
    created.expect("create fake stream queue")
}
83
84#[derive(Debug)]
86#[pin_project::pin_project]
87pub(crate) struct StreamQueueSender {
88 #[pin]
90 sender: mq_queue::Sender<UnparsedRelayMsg, Spec>,
91 counter: Arc<Mutex<usize>>,
93}
94
95#[derive(Debug)]
97#[pin_project::pin_project]
98pub(crate) struct StreamQueueReceiver {
99 #[pin]
106 receiver: StreamUnobtrusivePeeker<mq_queue::Receiver<UnparsedRelayMsg, Spec>>,
107 counter: Arc<Mutex<usize>>,
109}
110
111impl StreamQueueSender {
112 pub(crate) fn approx_stream_bytes(&self) -> usize {
117 *self.counter.lock().expect("poisoned")
118 }
119}
120
121impl StreamQueueReceiver {
122 pub(crate) fn approx_stream_bytes(&self) -> usize {
127 *self.counter.lock().expect("poisoned")
128 }
129}
130
131impl Sink<UnparsedRelayMsg> for StreamQueueSender {
132 type Error = <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as Sink<UnparsedRelayMsg>>::Error;
133
134 fn poll_ready(
135 mut self: Pin<&mut Self>,
136 cx: &mut Context<'_>,
137 ) -> Poll<std::result::Result<(), Self::Error>> {
138 self.sender.poll_ready_unpin(cx)
139 }
140
141 fn start_send(
142 mut self: Pin<&mut Self>,
143 item: UnparsedRelayMsg,
144 ) -> std::result::Result<(), Self::Error> {
145 let mut self_ = self.as_mut().project();
146
147 let stream_data_len = data_len(&item);
148
149 let mut counter = self_.counter.lock().expect("poisoned");
153
154 self_.sender.start_send_unpin(item)?;
155
156 *counter = counter
157 .checked_add(stream_data_len.into())
158 .expect("queue has more than `usize::MAX` bytes?!");
159
160 Ok(())
161 }
162
163 fn poll_flush(
164 mut self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 ) -> Poll<std::result::Result<(), Self::Error>> {
167 self.sender.poll_flush_unpin(cx)
168 }
169
170 fn poll_close(
171 mut self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 ) -> Poll<std::result::Result<(), Self::Error>> {
174 self.sender.poll_close_unpin(cx)
175 }
176}
177
178impl SinkTrySend<UnparsedRelayMsg> for StreamQueueSender {
179 type Error =
180 <mq_queue::Sender<UnparsedRelayMsg, MpscSpec> as SinkTrySend<UnparsedRelayMsg>>::Error;
181
182 fn try_send_or_return(
183 mut self: Pin<&mut Self>,
184 item: UnparsedRelayMsg,
185 ) -> Result<
186 (),
187 (
188 <Self as SinkTrySend<UnparsedRelayMsg>>::Error,
189 UnparsedRelayMsg,
190 ),
191 > {
192 let self_ = self.as_mut().project();
193
194 let stream_data_len = data_len(&item);
195
196 let mut counter = self_.counter.lock().expect("poisoned");
198
199 self_.sender.try_send_or_return(item)?;
200
201 *counter = counter
202 .checked_add(stream_data_len.into())
203 .expect("queue has more than `usize::MAX` bytes?!");
204
205 Ok(())
206 }
207}
208
209impl Stream for StreamQueueReceiver {
210 type Item = UnparsedRelayMsg;
211
212 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
213 let self_ = self.as_mut().project();
214
215 let mut counter = self_.counter.lock().expect("poisoned");
219
220 let item = match self_.receiver.poll_next(cx) {
221 Poll::Ready(Some(x)) => x,
222 Poll::Ready(None) => return Poll::Ready(None),
223 Poll::Pending => return Poll::Pending,
224 };
225
226 let stream_data_len = data_len(&item);
227
228 if stream_data_len != 0 {
229 *counter = counter
230 .checked_sub(stream_data_len.into())
231 .expect("we've removed more bytes than we've added?!");
232 }
233
234 Poll::Ready(Some(item))
235 }
236}
237
238impl UnobtrusivePeekableStream for StreamQueueReceiver {
239 fn unobtrusive_peek_mut<'s>(
240 self: Pin<&'s mut Self>,
241 ) -> Option<&'s mut <Self as futures::Stream>::Item> {
242 self.project().receiver.unobtrusive_peek_mut()
243 }
244}
245
246fn data_len(item: &UnparsedRelayMsg) -> u16 {
255 item.data_len().unwrap_or(0)
256}