1use bitvec::prelude::*;
4use derive_deftly::Deftly;
5use oneshot_fused_workaround as oneshot;
6use postage::watch;
7
8use tor_cell::relaycell::msg::AnyRelayMsg;
9use tor_cell::relaycell::{RelayCellFormat, RelayCmd, StreamId, UnparsedRelayMsg, msg};
10use tor_cell::restricted_msg;
11use tor_error::internal;
12use tor_memquota::derive_deftly_template_HasMemoryCost;
13use tor_memquota::mq_queue::{self, MpscSpec};
14use tor_rtcompat::DynTimeProvider;
15
16use crate::circuit::CircHopSyncView;
17use crate::stream::cmdcheck::{AnyCmdChecker, CmdChecker, StreamStatus};
18use crate::stream::{CloseStreamBehavior, StreamComponents};
19use crate::{Error, Result};
20
21use crate::client::stream::DataStream;
23
24use crate::memquota::StreamAccount;
25use crate::stream::StreamMpscSender;
26use crate::stream::flow_ctrl::state::StreamRateLimit;
27use crate::stream::flow_ctrl::xon_xoff::reader::DrainRateRequest;
28use crate::stream::queue::StreamQueueReceiver;
29use crate::util::notify::NotifyReceiver;
30use crate::{HopLocation, HopNum};
31
32use std::mem::size_of;
33
/// A [`CmdChecker`] for incoming data streams.
///
/// It permits only DATA messages (the stream stays open) and END messages
/// (the stream becomes closed); any other command is a stream protocol error.
#[derive(Debug, Default)]
pub(crate) struct InboundDataCmdChecker;
37
restricted_msg! {
    /// The relay messages that can be decoded on an incoming data stream.
    ///
    /// Used by [`InboundDataCmdChecker`] to validate messages it consumes.
    enum IncomingDataStreamMsg:RelayMsg {
        // Only DATA and END are representable; `InboundDataCmdChecker::check_msg`
        // rejects every other command.
        Data, End,
    }
}
45
46impl CmdChecker for InboundDataCmdChecker {
47 fn check_msg(&mut self, msg: &tor_cell::relaycell::UnparsedRelayMsg) -> Result<StreamStatus> {
48 use StreamStatus::*;
49 match msg.cmd() {
50 RelayCmd::DATA => Ok(Open),
51 RelayCmd::END => Ok(Closed),
52 _ => Err(Error::StreamProto(format!(
53 "Unexpected {} on an incoming data stream!",
54 msg.cmd()
55 ))),
56 }
57 }
58
59 fn consume_checked_msg(&mut self, msg: tor_cell::relaycell::UnparsedRelayMsg) -> Result<()> {
60 let _ = msg
61 .decode::<IncomingDataStreamMsg>()
62 .map_err(|err| Error::from_bytes_err(err, "cell on half-closed stream"))?;
63 Ok(())
64 }
65}
66
67impl InboundDataCmdChecker {
68 pub(crate) fn new_connected() -> AnyCmdChecker {
74 Box::new(Self)
75 }
76}
77
/// A pending incoming stream request, together with the state needed to
/// answer it.
///
/// Use [`IncomingStream::accept_data`] to accept the request as a data stream,
/// or [`IncomingStream::reject`] / [`IncomingStream::discard`] to refuse it.
#[derive(Debug)]
pub struct IncomingStream {
    /// Time provider handed to the [`DataStream`] if the request is accepted.
    time_provider: DynTimeProvider,
    /// The message that requested this stream.
    request: IncomingStreamRequest,
    /// The stream's components (target, receiver, XON/XOFF reader control,
    /// memory-quota account); moved into the [`DataStream`] on acceptance.
    components: StreamComponents,
}
96
97impl IncomingStream {
98 pub(crate) fn new(
100 time_provider: DynTimeProvider,
101 request: IncomingStreamRequest,
102 components: StreamComponents,
103 ) -> Self {
104 Self {
105 time_provider,
106 request,
107 components,
108 }
109 }
110
111 pub fn request(&self) -> &IncomingStreamRequest {
113 &self.request
114 }
115
116 pub async fn accept_data(self, message: msg::Connected) -> Result<DataStream> {
119 let Self {
120 time_provider,
121 request,
122 components:
123 StreamComponents {
124 mut target,
125 stream_receiver,
126 xon_xoff_reader_ctrl,
127 memquota,
128 },
129 } = self;
130
131 match request {
132 IncomingStreamRequest::Begin(_) | IncomingStreamRequest::BeginDir(_) => {
133 target.send(message.into()).await?;
134 Ok(DataStream::new_connected(
135 time_provider,
136 stream_receiver,
137 xon_xoff_reader_ctrl,
138 target,
139 memquota,
140 ))
141 }
142 IncomingStreamRequest::Resolve(_) => {
143 Err(internal!("Cannot accept data on a RESOLVE stream").into())
144 }
145 }
146 }
147
148 pub async fn reject(mut self, message: msg::End) -> Result<()> {
150 let rx = self.reject_inner(CloseStreamBehavior::SendEnd(message))?;
151
152 rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
153 }
154
155 fn reject_inner(
159 &mut self,
160 message: CloseStreamBehavior,
161 ) -> Result<oneshot::Receiver<Result<()>>> {
162 self.components.target.close_pending(message)
163 }
164
165 pub async fn discard(mut self) -> Result<()> {
171 let rx = self.reject_inner(CloseStreamBehavior::SendNothing)?;
172
173 rx.await.map_err(|_| Error::CircuitClosed)?.map(|_| ())
174 }
175}
176
restricted_msg! {
    /// A relay message that requests a new incoming stream.
    ///
    /// Only BEGIN, BEGIN_DIR, and RESOLVE messages are representable.
    #[derive(Clone, Debug, Deftly)]
    #[derive_deftly(HasMemoryCost)]
    #[non_exhaustive]
    pub enum IncomingStreamRequest: RelayMsg {
        /// A BEGIN message.
        Begin,
        /// A BEGIN_DIR message.
        BeginDir,
        /// A RESOLVE message.
        Resolve,
    }
}
196
/// A set of relay commands, stored as one bit per possible command byte
/// (256 bits, indexed by `u8::from(RelayCmd)`).
type RelayCmdSet = bitvec::BitArr!(for 256);
202
/// A [`CmdChecker`] for messages that request a new incoming stream.
///
/// It accepts only the relay commands in its configured set, leaving the
/// stream open. Other commands are stream protocol errors.
#[derive(Debug)]
pub(crate) struct IncomingCmdChecker {
    /// The accepted commands: bit `n` is set when the command whose byte
    /// value is `n` is allowed.
    allow_commands: RelayCmdSet,
}
216
217impl IncomingCmdChecker {
218 pub(crate) fn new_any(allow_commands: &[RelayCmd]) -> AnyCmdChecker {
220 let mut array = BitArray::ZERO;
221 for c in allow_commands {
222 array.set(u8::from(*c) as usize, true);
223 }
224 Box::new(Self {
225 allow_commands: array,
226 })
227 }
228}
229
230impl CmdChecker for IncomingCmdChecker {
231 fn check_msg(&mut self, msg: &UnparsedRelayMsg) -> Result<StreamStatus> {
232 if self.allow_commands[u8::from(msg.cmd()) as usize] {
233 Ok(StreamStatus::Open)
234 } else {
235 Err(Error::StreamProto(format!(
236 "Unexpected {} on incoming stream",
237 msg.cmd()
238 )))
239 }
240 }
241
242 fn consume_checked_msg(&mut self, msg: UnparsedRelayMsg) -> Result<()> {
243 let _ = msg
244 .decode::<IncomingStreamRequest>()
245 .map_err(|err| Error::from_bytes_err(err, "invalid message on incoming stream"))?;
246
247 Ok(())
248 }
249}
250
/// A policy object that decides what to do with each incoming stream request.
///
/// Implementors must be `Send + 'static`.
pub trait IncomingStreamRequestFilter: Send + 'static {
    /// Decide what to do with the stream request described by `ctx`.
    ///
    /// `circ` is a synchronous view of circuit/hop state. Presumably this is the
    /// hop the request arrived from; confirm with the caller.
    ///
    /// Returns the chosen [`IncomingStreamRequestDisposition`]. How an `Err`
    /// is handled is up to the caller, which is not shown here.
    fn disposition(
        &mut self,
        ctx: &IncomingStreamRequestContext<'_>,
        circ: &CircHopSyncView<'_>,
    ) -> Result<IncomingStreamRequestDisposition>;
}
266
/// The decision an [`IncomingStreamRequestFilter`] returns for a stream request.
///
/// NOTE(review): the code that acts on these values is not in this section.
/// The variant descriptions below follow the variant names; confirm them at
/// the use site.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum IncomingStreamRequestDisposition {
    /// Accept the request.
    Accept,
    /// Close the circuit instead of answering the request.
    CloseCircuit,
    /// Reject the request, replying with the given END message.
    RejectRequest(msg::End),
}
279
280pub struct IncomingStreamRequestContext<'a> {
282 pub(crate) request: &'a IncomingStreamRequest,
284}
285impl<'a> IncomingStreamRequestContext<'a> {
286 pub fn request(&self) -> &'a IncomingStreamRequest {
288 self.request
289 }
290}
291
/// An incoming stream request, bundled with the per-stream state that goes
/// with it. This is the item type carried by `StreamReqSender`.
///
/// Memory-cost accounting (`HasMemoryCost`): fields marked
/// `indirect_size = "0"` contribute no indirect size, and `msg_tx` is charged
/// `size_of::<AnyRelayMsg>()`.
#[derive(Debug, Deftly)]
#[derive_deftly(HasMemoryCost)]
pub(crate) struct StreamReqInfo {
    /// The message that requested the stream.
    pub(crate) req: IncomingStreamRequest,
    /// The stream ID associated with the request.
    pub(crate) stream_id: StreamId,
    /// The hop associated with the request, if any.
    /// NOTE(review): when this is `None` is decided by the producer; confirm there.
    pub(crate) hop: Option<HopLocation>,
    /// The relay cell format for this stream.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) relay_cell_format: RelayCellFormat,
    /// Receiving end of the stream's message queue.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) receiver: StreamQueueReceiver,
    /// Sender for relay messages belonging to this stream.
    #[deftly(has_memory_cost(indirect_size = "size_of::<AnyRelayMsg>()"))]
    pub(crate) msg_tx: StreamMpscSender<AnyRelayMsg>,
    /// Watch receiver for the stream's rate limit.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) rate_limit_stream: watch::Receiver<StreamRateLimit>,
    /// Receiver for XON/XOFF drain-rate requests.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) drain_rate_request_stream: NotifyReceiver<DrainRateRequest>,
    /// Memory-quota account for this stream.
    #[deftly(has_memory_cost(indirect_size = "0"))]
    pub(crate) memquota: StreamAccount,
}
329
/// Sending half of a memory-quota-tracked MPSC queue of [`StreamReqInfo`]s.
#[cfg(any(feature = "hs-service", feature = "relay"))]
pub(crate) type StreamReqSender = mq_queue::Sender<StreamReqInfo, MpscSpec>;
333
/// The pieces used to handle incoming stream requests: a delivery queue, a
/// command checker, and a filter policy.
#[derive(educe::Educe)]
#[educe(Debug)]
#[cfg(any(feature = "hs-service", feature = "relay"))]
pub(crate) struct IncomingStreamRequestHandler {
    /// Queue on which incoming requests are delivered as [`StreamReqInfo`]s.
    pub(crate) incoming_sender: StreamReqSender,
    /// Optional hop number.
    /// NOTE(review): presumably restricts which hop requests are expected from;
    /// confirm at the use site.
    pub(crate) hop_num: Option<HopNum>,
    /// Command checker applied to incoming messages handled here.
    pub(crate) cmd_checker: AnyCmdChecker,
    /// Policy deciding each request's disposition. Excluded from `Debug`
    /// output, since `dyn IncomingStreamRequestFilter` does not require `Debug`.
    #[educe(Debug(ignore))]
    pub(crate) filter: Box<dyn IncomingStreamRequestFilter>,
}
352
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use tor_cell::relaycell::{
        AnyRelayMsgOuter, RelayCellFormat,
        msg::{Begin, BeginDir, Data, Resolve},
    };

    use super::*;

    #[test]
    fn incoming_cmd_checker() {
        // Encode a message as a V0 relay cell body and parse it back, so the
        // checker sees an `UnparsedRelayMsg`.
        let unparsed = |msg| {
            let body = AnyRelayMsgOuter::new(None, msg)
                .encode(RelayCellFormat::V0, &mut rand::rng())
                .unwrap();
            UnparsedRelayMsg::from_singleton_body(RelayCellFormat::V0, body).unwrap()
        };
        let begin = unparsed(Begin::new("allium.example.com", 443, 0).unwrap().into());
        let begin_dir = unparsed(BeginDir::default().into());
        let resolve = unparsed(Resolve::new("allium.example.com").into());
        let data = unparsed(Data::new(&[1, 2, 3]).unwrap().into());

        let msgs = [&begin, &begin_dir, &resolve, &data];

        // Each case pairs an allow-list with whether each entry of `msgs`
        // must be accepted (`true`) or rejected (`false`).
        let cases: [(&[RelayCmd], [bool; 4]); 3] = [
            (&[], [false, false, false, false]),
            (&[RelayCmd::BEGIN], [true, false, false, false]),
            (
                &[RelayCmd::BEGIN, RelayCmd::BEGIN_DIR, RelayCmd::RESOLVE],
                [true, true, true, false],
            ),
        ];

        for (allowed, expected) in cases {
            let mut checker = IncomingCmdChecker::new_any(allowed);
            for (&m, should_accept) in msgs.iter().zip(expected) {
                if should_accept {
                    assert_eq!(checker.check_msg(m).unwrap(), StreamStatus::Open);
                } else {
                    assert!(checker.check_msg(m).is_err());
                }
            }
        }
    }
}