1use crate::circuit::CircuitRxSender;
7use crate::client::circuit::padding::{PaddingController, QueuedCellPaddingInfo};
8use crate::{Error, Result};
9use tor_basic_utils::RngExt;
10use tor_cell::chancell::CircId;
11use tor_cell::chancell::msg::DestroyReason;
12
13use crate::circuit::celltypes::CreateResponse;
14use crate::client::circuit::halfcirc::HalfCirc;
15
16use oneshot_fused_workaround as oneshot;
17
18use rand::Rng;
19use rand::distr::Distribution;
20use std::collections::{HashMap, hash_map::Entry};
21use std::ops::{Deref, DerefMut};
22use std::result::Result as StdResult;
23use std::sync::Arc;
24
25#[cfg(feature = "relay")]
26use crate::relay::RelayCirc;
27
/// Which half of the 32-bit circuit-ID space a [`CircMap`] allocates new IDs from.
///
/// The split is on the most significant bit; see `CircIdRange::integer_range`
/// for the exact numeric bounds of each half.
#[derive(Clone, Copy)]
pub(crate) enum CircIdRange {
    /// IDs with the most significant bit clear (zero is never used).
    // NOTE(review): `dead_code` is allowed here, presumably because this
    // variant is only constructed in tests or feature-gated code — confirm
    // and record the reason.
    #[allow(dead_code)]
    Low,
    /// IDs with the most significant bit set.
    High,
}
43
44impl CircIdRange {
45 const fn integer_range(&self) -> std::ops::RangeInclusive<u32> {
48 const MIDPOINT: u32 = 0x8000_0000;
49
50 match self {
51 Self::Low => 1..=(MIDPOINT - 1),
53 Self::High => MIDPOINT..=u32::MAX,
54 }
55 }
56
57 pub(crate) fn is_allowed_for_peer(&self, id: CircId) -> bool {
59 !self.integer_range().contains(&id.into())
63 }
64}
65
66impl rand::distr::Distribution<CircId> for CircIdRange {
67 fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> CircId {
69 let v = rng.gen_range_checked(self.integer_range());
70 let v = v.expect("Unexpected empty range passed to gen_range_checked");
71 CircId::new(v).expect("Unexpected zero value")
72 }
73}
74
/// The state recorded for one circuit ID inside a [`CircMap`].
///
/// Every variant except [`CircEnt::DestroySent`] is counted as "open" by
/// [`CircMap::open_ent_count`].
#[derive(Debug)]
pub(super) enum CircEnt {
    /// A circuit we allocated via `CircMap::add_origin_ent` that has not yet
    /// received its CREATED* reply.
    ///
    /// `CircMap::advance_from_opening` turns this into [`CircEnt::OpenOrigin`].
    Opening {
        /// Returned by `advance_from_opening` so the caller can deliver the
        /// create response.
        create_response_sender: oneshot::Sender<CreateResponse>,
        /// Sender for this circuit's cells — presumably the path by which
        /// incoming cells reach the circuit; TODO confirm.
        cell_sender: CircuitRxSender,
        /// Padding controller for this circuit; notified by
        /// `CircMap::note_cell_flushed`.
        padding_ctrl: PaddingController,
    },

    /// An open circuit that we originated (formerly [`CircEnt::Opening`]).
    OpenOrigin {
        /// Sender for this circuit's cells (carried over from `Opening`).
        cell_sender: CircuitRxSender,
        /// Padding controller for this circuit (carried over from `Opening`).
        padding_ctrl: PaddingController,
    },

    /// An open circuit whose ID was chosen by the peer; added by
    /// `CircMap::add_relay_ent`.
    #[cfg(feature = "relay")]
    OpenRelay {
        /// Held by the map but not otherwise read in this file (hence the
        /// leading underscore). Presumably keeps the circuit alive while the
        /// entry exists — confirm.
        _circ: Arc<RelayCirc>,
        /// Sender for this circuit's cells.
        cell_sender: CircuitRxSender,
        /// Padding controller for this circuit.
        padding_ctrl: PaddingController,
    },

    /// Recorded by `CircMap::destroy_sent`. While present, the ID stays
    /// reserved: `add_origin_ent` will not hand it out again. Presumably a
    /// DESTROY was sent for the circuit — confirm against callers.
    DestroySent(HalfCirc),
}
129
130pub(super) struct MutCircEnt<'a> {
136 value: &'a mut CircEnt,
138 open_count: &'a mut usize,
141 was_open: bool,
143}
144
145impl<'a> Drop for MutCircEnt<'a> {
146 fn drop(&mut self) {
147 let is_open = !matches!(self.value, CircEnt::DestroySent(_));
148 match (self.was_open, is_open) {
149 (false, true) => *self.open_count = self.open_count.saturating_add(1),
150 (true, false) => *self.open_count = self.open_count.saturating_sub(1),
151 (_, _) => (),
152 };
153 }
154}
155
156impl<'a> Deref for MutCircEnt<'a> {
157 type Target = CircEnt;
158 fn deref(&self) -> &Self::Target {
159 self.value
160 }
161}
162
163impl<'a> DerefMut for MutCircEnt<'a> {
164 fn deref_mut(&mut self) -> &mut Self::Target {
165 self.value
166 }
167}
168
169pub(super) struct CircMap {
171 m: HashMap<CircId, CircEnt>,
173 range: CircIdRange,
175 open_count: usize,
177}
178
179impl CircMap {
180 pub(super) fn new(idrange: CircIdRange) -> Self {
182 CircMap {
183 m: HashMap::new(),
184 range: idrange,
185 open_count: 0,
186 }
187 }
188
189 pub(super) fn add_origin_ent<R: Rng>(
195 &mut self,
196 rng: &mut R,
197 createdsink: oneshot::Sender<CreateResponse>,
198 sink: CircuitRxSender,
199 padding_ctrl: PaddingController,
200 ) -> Result<CircId> {
201 const N_ATTEMPTS: usize = 16;
206 let iter = self.range.sample_iter(rng).take(N_ATTEMPTS);
207 let circ_ent = CircEnt::Opening {
208 create_response_sender: createdsink,
209 cell_sender: sink,
210 padding_ctrl,
211 };
212 for id in iter {
213 let ent = self.m.entry(id);
214 if let Entry::Vacant(_) = &ent {
215 ent.or_insert(circ_ent);
216 self.open_count += 1;
217 return Ok(id);
218 }
219 }
220 Err(Error::IdRangeFull)
221 }
222
223 #[cfg(feature = "relay")]
228 pub(super) fn add_relay_ent(
229 &mut self,
230 circ_id: CircId,
231 circ: Arc<RelayCirc>,
232 sink: CircuitRxSender,
233 padding_ctrl: PaddingController,
234 ) -> StdResult<(), DestroyReason> {
235 if !self.range.is_allowed_for_peer(circ_id) {
237 return Err(DestroyReason::PROTOCOL);
238 }
239
240 let circ_ent = CircEnt::OpenRelay {
241 _circ: circ,
242 cell_sender: sink,
243 padding_ctrl,
244 };
245
246 if let Entry::Vacant(ent) = self.m.entry(circ_id) {
247 ent.insert(circ_ent);
248 self.open_count += 1;
249 Ok(())
250 } else {
251 Err(DestroyReason::PROTOCOL)
252 }
253 }
254
255 #[cfg(test)]
258 pub(super) fn put_unchecked(&mut self, id: CircId, ent: CircEnt) {
259 self.m.insert(id, ent);
260 }
261
262 pub(super) fn get_mut(&mut self, id: CircId) -> Option<MutCircEnt> {
264 let open_count = &mut self.open_count;
265 self.m.get_mut(&id).map(move |ent| MutCircEnt {
266 open_count,
267 was_open: !matches!(ent, CircEnt::DestroySent(_)),
268 value: ent,
269 })
270 }
271
272 pub(super) fn note_cell_flushed(&mut self, id: CircId, info: QueuedCellPaddingInfo) {
274 let padding_ctrl = match self.m.get(&id) {
275 Some(CircEnt::Opening { padding_ctrl, .. }) => padding_ctrl,
276 Some(CircEnt::OpenOrigin { padding_ctrl, .. }) => padding_ctrl,
277 #[cfg(feature = "relay")]
278 Some(CircEnt::OpenRelay { padding_ctrl, .. }) => padding_ctrl,
279 Some(CircEnt::DestroySent(..)) | None => return,
280 };
281 padding_ctrl.flushed_relay_cell(info);
282 }
283
284 pub(super) fn advance_from_opening(
287 &mut self,
288 id: CircId,
289 ) -> Result<oneshot::Sender<CreateResponse>> {
290 let ok = matches!(self.m.get(&id), Some(CircEnt::Opening { .. }));
295 if ok {
296 if let Some(CircEnt::Opening {
297 create_response_sender: oneshot,
298 cell_sender: sink,
299 padding_ctrl,
300 }) = self.m.remove(&id)
301 {
302 self.m.insert(
303 id,
304 CircEnt::OpenOrigin {
305 cell_sender: sink,
306 padding_ctrl,
307 },
308 );
309 Ok(oneshot)
310 } else {
311 panic!("internal error: inconsistent circuit state");
312 }
313 } else {
314 Err(Error::ChanProto(
315 "Unexpected CREATED* cell not on opening circuit".into(),
316 ))
317 }
318 }
319
320 pub(super) fn destroy_sent(&mut self, id: CircId, hs: HalfCirc) {
324 if let Some(replaced) = self.m.insert(id, CircEnt::DestroySent(hs)) {
325 if !matches!(replaced, CircEnt::DestroySent(_)) {
326 self.open_count = self.open_count.saturating_sub(1);
328 }
329 }
330 }
331
332 pub(super) fn remove(&mut self, id: CircId) -> Option<CircEnt> {
334 self.m.remove(&id).map(|removed| {
335 if !matches!(removed, CircEnt::DestroySent(_)) {
336 self.open_count = self.open_count.saturating_sub(1);
337 }
338 removed
339 })
340 }
341
342 pub(super) fn open_ent_count(&self) -> usize {
344 self.open_count
345 }
346
347 }
350
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::{client::circuit::padding::new_padding, fake_mpsc};
    use tor_basic_utils::test_rng::testing_rng;
    use tor_rtcompat::DynTimeProvider;

    /// Exercise ID allocation, state transitions, and `open_ent_count`
    /// bookkeeping on two maps that use opposite ID ranges.
    #[test]
    fn circmap_basics() {
        let mut map_low = CircMap::new(CircIdRange::Low);
        let mut map_high = CircMap::new(CircIdRange::High);
        let mut ids_low: Vec<CircId> = Vec::new();
        let mut ids_high: Vec<CircId> = Vec::new();
        let mut rng = testing_rng();
        tor_rtcompat::test_with_one_runtime!(|runtime| async {
            let (padding_ctrl, _padding_stream) = new_padding(DynTimeProvider::new(runtime));

            assert!(map_low.get_mut(CircId::new(77).unwrap()).is_none());

            for _ in 0..128 {
                let (csnd, _) = oneshot::channel();
                let (snd, _) = fake_mpsc(8);
                let id_low = map_low
                    .add_origin_ent(&mut rng, csnd, snd, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(id_low) > 0);
                assert!(u32::from(id_low) < 0x80000000);
                assert!(!ids_low.contains(&id_low));
                ids_low.push(id_low);

                assert!(matches!(
                    *map_low.get_mut(id_low).unwrap(),
                    CircEnt::Opening { .. }
                ));

                let (csnd, _) = oneshot::channel();
                let (snd, _) = fake_mpsc(8);
                let id_high = map_high
                    .add_origin_ent(&mut rng, csnd, snd, padding_ctrl.clone())
                    .unwrap();
                assert!(u32::from(id_high) >= 0x80000000);
                assert!(!ids_high.contains(&id_high));
                ids_high.push(id_high);
            }

            // Freshly added entries are Opening, which counts as open.
            assert_eq!(128, map_low.open_ent_count());
            assert_eq!(128, map_high.open_ent_count());

            // Removing an open entry decrements the count; removing it again
            // finds nothing and leaves the count alone.
            assert!(map_low.get_mut(ids_low[0]).is_some());
            assert!(map_low.remove(ids_low[0]).is_some());
            assert!(map_low.get_mut(ids_low[0]).is_none());
            assert_eq!(127, map_low.open_ent_count());
            assert!(map_low.remove(ids_low[0]).is_none());
            assert_eq!(127, map_low.open_ent_count());

            // Marking an unused ID as DestroySent adds a non-open entry.
            map_low.destroy_sent(CircId::new(256).unwrap(), HalfCirc::new(1));
            assert_eq!(127, map_low.open_ent_count());

            // Advance an Opening entry to OpenOrigin; the count is unchanged.
            assert!(map_high.get_mut(ids_high[0]).is_some());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::Opening { .. }
            ));
            let adv = map_high.advance_from_opening(ids_high[0]);
            assert!(adv.is_ok());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::OpenOrigin { .. }
            ));
            assert_eq!(128, map_high.open_ent_count());

            // Advancing an entry that is no longer Opening fails and leaves
            // the entry in place.
            let adv = map_high.advance_from_opening(ids_high[0]);
            assert!(adv.is_err());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::OpenOrigin { .. }
            ));
            assert_eq!(128, map_high.open_ent_count());

            // Advancing an unknown ID fails and does not create an entry.
            let adv = map_high.advance_from_opening(CircId::new(77).unwrap());
            assert!(adv.is_err());
            assert!(map_high.get_mut(CircId::new(77).unwrap()).is_none());

            // Replacing an open entry with DestroySent decrements the count.
            map_high.destroy_sent(ids_high[0], HalfCirc::new(1));
            assert_eq!(127, map_high.open_ent_count());
            assert!(matches!(
                *map_high.get_mut(ids_high[0]).unwrap(),
                CircEnt::DestroySent(_)
            ));
            assert_eq!(127, map_high.open_ent_count());

            // A state change made through `get_mut` is accounted for when the
            // handle is dropped.
            *map_high.get_mut(ids_high[1]).unwrap() = CircEnt::DestroySent(HalfCirc::new(1));
            assert_eq!(126, map_high.open_ent_count());

            // Removing a DestroySent entry leaves the count unchanged.
            assert!(map_high.remove(ids_high[1]).is_some());
            assert_eq!(126, map_high.open_ent_count());
        });
    }
}