// tor_proto/util/sink_blocker/counting_policy.rs
use nonany::NonMaxU32;

use tor_error::{Bug, internal};

6#[derive(Debug, Clone, Copy)]
13pub(crate) struct CountingPolicy {
14 remaining: Option<NonMaxU32>,
18}
19
20const MAX_LIMIT: NonMaxU32 = NonMaxU32::new(u32::MAX - 1).expect("Couldn't construct MAX_LIMIT");
22
23impl CountingPolicy {
24 pub(crate) fn new_unlimited() -> Self {
26 Self { remaining: None }
27 }
28
29 pub(crate) fn new_blocked() -> Self {
31 Self {
32 remaining: Some(
33 const { NonMaxU32::new(0).expect("Couldn't construct NonMaxU32 from zero.") },
34 ),
35 }
36 }
37
38 pub(crate) fn new_limited(n: u32) -> Self {
44 Self {
45 remaining: Some(NonMaxU32::new(n).unwrap_or(MAX_LIMIT)),
46 }
47 }
48
49 fn saturating_add(&self, n: u32) -> Self {
61 match self.remaining {
62 Some(current) => Self::new_limited(current.get().saturating_add(n)),
63 None => Self::new_unlimited(),
64 }
65 }
66}
67
68impl super::Policy for CountingPolicy {
69 fn is_blocking(&self) -> bool {
70 self.remaining.is_some_and(|n| n.get() == 0)
71 }
72
73 fn take_one(&mut self) -> Result<(), Bug> {
81 match &mut self.remaining {
82 None => Ok(()),
84
85 Some(remaining) => {
86 if let Some(n) = remaining.get().checked_sub(1) {
87 *remaining = n
88 .try_into()
89 .expect("Somehow subtracting 1 made us exceed MAX_LIMIT!?");
90 Ok(())
91 } else {
92 Err(internal!(
93 "Tried to take_one() from a blocked CountingPolicy."
94 ))
95 }
96 }
97 }
98 }
99}
100
101impl<S> super::SinkBlocker<S, CountingPolicy> {
102 pub(crate) fn set_blocked(&mut self) {
104 self.update_policy(CountingPolicy::new_blocked());
105 }
106
107 pub(crate) fn set_unlimited(&mut self) {
109 self.update_policy(CountingPolicy::new_unlimited());
113 }
114
115 pub(crate) fn allow_n_additional_items(&mut self, n: u32) {
119 self.update_policy(self.policy.saturating_add(n));
123 }
124
125 pub(crate) fn is_unlimited(&self) -> bool {
127 self.policy.remaining.is_none()
128 }
129}
130
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::util::sink_blocker::Policy as _;

    // Unit tests for CountingPolicy's state transitions: unlimited, blocked,
    // limited, and the saturation behavior at MAX_LIMIT.

    #[test]
    fn counting_unlimited() {
        let mut unlimited = CountingPolicy::new_unlimited();
        assert_eq!(unlimited.is_blocking(), false);
        assert!(unlimited.take_one().is_ok());
        assert!(unlimited.take_one().is_ok());
        assert_eq!(unlimited.is_blocking(), false);
        let u2 = unlimited.saturating_add(99);
        assert!(u2.remaining.is_none());
        // Even the largest possible addition leaves an unlimited policy unlimited.
        let u3 = unlimited.saturating_add(u32::MAX);
        assert!(u3.remaining.is_none());
    }

    #[test]
    fn counting_blocked() {
        let mut blocked = CountingPolicy::new_blocked();
        assert_eq!(blocked.is_blocking(), true);
        assert!(blocked.take_one().is_err());
        // A failed take_one() must not modify the policy.
        assert_eq!(blocked.is_blocking(), true);
        assert_eq!(blocked.remaining.unwrap().get(), 0);
        // Allowing zero more items keeps a blocked policy blocked.
        assert_eq!(blocked.saturating_add(0).is_blocking(), true);

        let mut u2 = blocked.saturating_add(99);
        assert_eq!(u2.remaining.unwrap().get(), 99);
        assert_eq!(u2.is_blocking(), false);
        assert!(u2.take_one().is_ok());
        assert_eq!(u2.remaining.unwrap().get(), 98);
    }

    #[test]
    fn counting_limited() {
        let mut limited = CountingPolicy::new_limited(2);
        assert_eq!(limited.is_blocking(), false);
        assert!(limited.take_one().is_ok());
        assert_eq!(limited.is_blocking(), false);
        assert!(limited.take_one().is_ok());
        assert_eq!(limited.is_blocking(), true);
        assert!(limited.take_one().is_err());

        let limited = CountingPolicy::new_limited(99);
        let lim2 = limited.saturating_add(25);
        assert_eq!(lim2.remaining.unwrap().get(), 25 + 99);
        let lim3 = limited.saturating_add(u32::MAX);
        assert_eq!(lim3.remaining.unwrap(), MAX_LIMIT);

        let limited = CountingPolicy::new_limited(u32::MAX);
        assert_eq!(limited.remaining.unwrap(), MAX_LIMIT);

        // A limit of zero behaves exactly like new_blocked().
        let mut zero = CountingPolicy::new_limited(0);
        assert_eq!(zero.is_blocking(), true);
        assert!(zero.take_one().is_err());
        assert_eq!(zero.remaining.unwrap().get(), 0);
    }

    #[test]
    fn counting_saturates_at_max_limit() {
        let mut at_max = CountingPolicy::new_limited(u32::MAX);
        assert_eq!(at_max.remaining.unwrap(), MAX_LIMIT);
        // Adding to a policy already at MAX_LIMIT clamps instead of overflowing.
        assert_eq!(at_max.saturating_add(1).remaining.unwrap(), MAX_LIMIT);
        assert_eq!(at_max.saturating_add(u32::MAX).remaining.unwrap(), MAX_LIMIT);
        // Taking one item from MAX_LIMIT steps down by exactly one.
        assert!(at_max.take_one().is_ok());
        assert_eq!(at_max.remaining.unwrap().get(), u32::MAX - 2);
        assert_eq!(at_max.is_blocking(), false);
    }
}