tor_circmgr/timeouts/
estimator.rs1use crate::TimeoutStateHandle;
4use crate::timeouts::{
5 Action, TimeoutEstimator,
6 pareto::{ParetoTimeoutEstimator, ParetoTimeoutState},
7 readonly::ReadonlyTimeoutEstimator,
8};
9use std::sync::Mutex;
10use std::time::Duration;
11use tor_error::warn_report;
12use tor_netdir::params::NetParameters;
13use tracing::{debug, warn};
14
15pub(crate) struct Estimator {
18 inner: Mutex<Box<dyn TimeoutEstimator + Send + 'static>>,
20}
21
22impl Estimator {
23 #[cfg(test)]
25 pub(crate) fn new(est: impl TimeoutEstimator + Send + 'static) -> Self {
26 Self {
27 inner: Mutex::new(Box::new(est)),
28 }
29 }
30
31 pub(crate) fn from_storage(storage: &TimeoutStateHandle) -> Self {
34 let (_, est) = estimator_from_storage(storage);
35 Self {
36 inner: Mutex::new(est),
37 }
38 }
39
40 pub(crate) fn upgrade_to_owning_storage(&self, storage: &TimeoutStateHandle) {
43 let (readonly, est) = estimator_from_storage(storage);
44 if readonly {
45 warn!("Unable to upgrade to owned persistent storage.");
46 return;
47 }
48 *self.inner.lock().expect("Timeout estimator lock poisoned") = est;
49 }
50
51 pub(crate) fn reload_readonly_from_storage(&self, storage: &TimeoutStateHandle) {
54 if let Ok(Some(v)) = storage.load() {
55 let est = ReadonlyTimeoutEstimator::from_state(&v);
56 *self.inner.lock().expect("Timeout estimator lock poisoned") = Box::new(est);
57 } else {
58 debug!("Unable to reload timeout state.");
59 }
60 }
61
62 pub(crate) fn note_hop_completed(&self, hop: u8, delay: Duration, is_last: bool) {
71 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
72
73 inner.note_hop_completed(hop, delay, is_last);
74 }
75
76 pub(crate) fn note_circ_timeout(&self, hop: u8, delay: Duration) {
84 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
85 inner.note_circ_timeout(hop, delay);
86 }
87
88 pub(crate) fn timeouts(&self, action: &Action) -> (Duration, Duration) {
98 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
99
100 inner.timeouts(action)
101 }
102
103 pub(crate) fn learning_timeouts(&self) -> bool {
106 let inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
107 inner.learning_timeouts()
108 }
109
110 pub(crate) fn update_params(&self, params: &NetParameters) {
113 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
114 inner.update_params(params);
115 }
116
117 pub(crate) fn save_state(&self, storage: &TimeoutStateHandle) -> crate::Result<()> {
119 let state = {
120 let mut inner = self.inner.lock().expect("Timeout estimator lock poisoned.");
121 inner.build_state()
122 };
123 if let Some(state) = state {
124 storage.store(&state)?;
125 }
126 Ok(())
127 }
128}
129
130impl tor_proto::client::circuit::TimeoutEstimator for Estimator {
131 fn circuit_build_timeout(&self, length: usize) -> Duration {
132 let (timeout, _abandon) = self
133 .inner
134 .lock()
135 .expect("poisoned lock")
136 .timeouts(&Action::BuildCircuit { length });
137
138 timeout
139 }
140}
141
142fn estimator_from_storage(
147 storage: &TimeoutStateHandle,
148) -> (bool, Box<dyn TimeoutEstimator + Send + 'static>) {
149 let state = match storage.load() {
150 Ok(Some(v)) => v,
151 Ok(None) => ParetoTimeoutState::default(),
152 Err(e) => {
153 warn_report!(e, "Unable to load timeout state");
154 return (true, Box::new(ReadonlyTimeoutEstimator::new()));
155 }
156 };
157
158 if storage.can_store() {
159 (false, Box::new(ParetoTimeoutEstimator::from_state(state)))
161 } else {
162 (true, Box::new(ReadonlyTimeoutEstimator::from_state(&state)))
163 }
164}
165
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use tor_persist::StateMgr;

    /// Exercise the lifecycle of an owning estimator and a read-only one
    /// that share the same persisted state.
    ///
    /// Covered: creating both estimators, learning and saving, reloading the
    /// read-only estimator, and upgrading it to owning once the first
    /// manager is gone.
    #[test]
    fn load_estimator() {
        let params = NetParameters::default();

        // The first manager takes the lock, so its handle is writable.
        let storage = tor_persist::TestingStateMgr::new();
        assert!(storage.try_lock().unwrap().held());
        let handle = storage.clone().create_handle("paretorama");

        let est = Estimator::from_storage(&handle);
        assert!(est.learning_timeouts());
        est.save_state(&handle).unwrap();

        // A second manager over the same state fails to take the lock,
        // so it gets a read-only estimator.
        let storage2 = storage.new_manager();
        assert!(!storage2.try_lock().unwrap().held());
        let handle2 = storage2.clone().create_handle("paretorama");

        let est2 = Estimator::from_storage(&handle2);
        assert!(!est2.learning_timeouts());

        est.update_params(&params);
        est2.update_params(&params);

        // Before any observations, both report 60s for a 3-hop build.
        let act = Action::BuildCircuit { length: 3 };
        assert_eq!(
            est.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );
        assert_eq!(
            est2.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );

        // Feed observations to both. Only the owning estimator is expected
        // to change its answer (checked below).
        for _ in 0..500 {
            est.note_hop_completed(2, Duration::from_secs(7), true);
            est.note_hop_completed(2, Duration::from_secs(2), true);
            est2.note_hop_completed(2, Duration::from_secs(4), true);
        }
        assert!(!est.learning_timeouts());

        est.save_state(&handle).unwrap();
        let to_1 = est.timeouts(&act);
        assert_ne!(to_1, (Duration::from_secs(60), Duration::from_secs(60)));
        assert_eq!(
            est2.timeouts(&act),
            (Duration::from_secs(60), Duration::from_secs(60))
        );

        // After reloading, the read-only estimator reports the saved timeout
        // for both elements of the pair.
        est2.reload_readonly_from_storage(&handle2);
        let to_1_secs = to_1.0.as_secs_f64();
        let timeouts = est2.timeouts(&act);
        assert!((timeouts.0.as_secs_f64() - to_1_secs).abs() < 0.001);
        assert!((timeouts.1.as_secs_f64() - to_1_secs).abs() < 0.001);

        drop(est);
        drop(handle);
        drop(storage);

        // With the first manager gone, the second can take the lock and
        // upgrade to an owning estimator that keeps the learned timeout.
        assert!(storage2.try_lock().unwrap().held());
        est2.upgrade_to_owning_storage(&handle2);
        let to_2 = est2.timeouts(&act);
        assert!(to_2.0 > to_1.0 - Duration::from_secs(1));
        assert!(to_2.0 < to_1.0 + Duration::from_secs(1));

        // The upgraded estimator learns again: faster completions lower the
        // timeout.
        for _ in 0..200 {
            est2.note_hop_completed(2, Duration::from_secs(1), true);
        }
        let to_3 = est2.timeouts(&act);
        assert!(to_3.0 < to_2.0);
    }
}