1use std::cmp::{max, min};
4use std::collections::VecDeque;
5use std::sync::atomic::{AtomicBool, Ordering};
6use web_time_compat::{Duration, Instant};
7
8use super::params::RoundTripEstimatorParams;
9use super::{CongestionWindow, State};
10
11use thiserror::Error;
12use tor_error::{ErrorKind, HasKind};
13
14#[derive(Error, Debug, Clone)]
16#[non_exhaustive]
17pub(crate) enum Error {
18 #[error("Informed of a SENDME we weren't expecting")]
21 MismatchedEstimationCall,
22}
23
24impl HasKind for Error {
25 fn kind(&self) -> ErrorKind {
26 use Error as E;
27 match self {
28 E::MismatchedEstimationCall => ErrorKind::TorProtocolViolation,
29 }
30 }
31}
32
33#[derive(Debug)]
35#[allow(dead_code)]
36pub(crate) struct RoundtripTimeEstimator {
37 sendme_expected_from: VecDeque<Instant>,
45 last_rtt: Option<Duration>,
49 ewma_rtt: Option<Duration>,
53 min_rtt: Option<Duration>,
57 max_rtt: Option<Duration>,
61 params: RoundTripEstimatorParams,
63 clock_stalled: AtomicBool,
66}
67
68#[allow(dead_code)]
69impl RoundtripTimeEstimator {
70 pub(crate) fn new(params: &RoundTripEstimatorParams) -> Self {
73 Self {
74 sendme_expected_from: Default::default(),
75 last_rtt: None,
76 ewma_rtt: None,
77 min_rtt: None,
78 max_rtt: None,
79 params: params.clone(),
80 clock_stalled: AtomicBool::default(),
81 }
82 }
83
84 pub(crate) fn is_ready(&self) -> bool {
86 !self.clock_stalled() && self.last_rtt.is_some()
87 }
88
89 pub(crate) fn clock_stalled(&self) -> bool {
91 self.clock_stalled.load(Ordering::SeqCst)
92 }
93
94 pub(crate) fn ewma_rtt_usec(&self) -> Option<u32> {
96 self.ewma_rtt
97 .map(|rtt| u32::try_from(rtt.as_micros()).ok().unwrap_or(u32::MAX))
98 }
99
100 pub(crate) fn min_rtt_usec(&self) -> Option<u32> {
102 self.min_rtt
103 .map(|rtt| u32::try_from(rtt.as_micros()).ok().unwrap_or(u32::MAX))
104 }
105
106 pub(crate) fn max_rtt_usec(&self) -> Option<u32> {
108 self.max_rtt
109 .map(|rtt| u32::try_from(rtt.as_micros()).ok().unwrap_or(u32::MAX))
110 }
111
112 pub(crate) fn expect_sendme(&mut self, now: Instant) {
115 self.sendme_expected_from.push_back(now);
116 }
117
118 fn can_crosscheck_with_current_estimate(&self, in_slow_start: bool) -> bool {
124 !in_slow_start && self.ewma_rtt.is_some()
128 }
129
130 fn is_clock_stalled(&self, raw_rtt: Duration, in_slow_start: bool) -> bool {
133 if raw_rtt.is_zero() {
134 self.clock_stalled.store(true, Ordering::SeqCst);
136 true
137 } else if self.can_crosscheck_with_current_estimate(in_slow_start) {
138 let ewma_rtt = self
139 .ewma_rtt
140 .expect("ewma_rtt was not checked by can_crosscheck_with_current_estimate?!");
141
142 const DELTA_DISCREPANCY_RATIO_MAX: u32 = 5000;
146 if raw_rtt > ewma_rtt * DELTA_DISCREPANCY_RATIO_MAX {
148 true
155 } else if ewma_rtt > raw_rtt * DELTA_DISCREPANCY_RATIO_MAX {
156 self.clock_stalled.load(Ordering::SeqCst)
159 } else {
160 self.clock_stalled.store(false, Ordering::SeqCst);
162 false
163 }
164 } else {
165 false
167 }
168 }
169
170 pub(crate) fn update(
184 &mut self,
185 now: Instant,
186 state: &State,
187 cwnd: &CongestionWindow,
188 ) -> Result<(), Error> {
189 let data_sent_at = self
190 .sendme_expected_from
191 .pop_front()
192 .ok_or(Error::MismatchedEstimationCall)?;
193 let raw_rtt = now.saturating_duration_since(data_sent_at);
194
195 if self.is_clock_stalled(raw_rtt, state.in_slow_start()) {
196 return Ok(());
197 }
198
199 self.max_rtt = self.max_rtt.max(Some(raw_rtt));
200 self.last_rtt = Some(raw_rtt);
201
202 let ewma_n = u64::from(if state.in_slow_start() {
204 self.params.ewma_ss_max()
205 } else {
206 min(
207 (cwnd.update_rate(state) * (self.params.ewma_cwnd_pct().as_percent())) / 100,
208 self.params.ewma_max(),
209 )
210 });
211 let ewma_n = max(ewma_n, 2);
212
213 let raw_rtt_usec = raw_rtt.as_micros() as u64;
215 let prev_ewma_rtt_usec = self.ewma_rtt.map(|rtt| rtt.as_micros() as u64);
216
217 let new_ewma_rtt_usec = match prev_ewma_rtt_usec {
225 None => raw_rtt_usec,
226 Some(prev_ewma_rtt_usec) => {
227 ((raw_rtt_usec * 2) + ((ewma_n - 1) * prev_ewma_rtt_usec)) / (ewma_n + 1)
228 }
229 };
230 let ewma_rtt = Duration::from_micros(new_ewma_rtt_usec);
231 self.ewma_rtt = Some(ewma_rtt);
232
233 let Some(min_rtt) = self.min_rtt else {
234 self.min_rtt = self.ewma_rtt;
235 return Ok(());
236 };
237
238 if cwnd.get() == cwnd.min() && !state.in_slow_start() {
239 let max = max(ewma_rtt, min_rtt).as_micros() as u64;
241 let min = min(ewma_rtt, min_rtt).as_micros() as u64;
242 let rtt_reset_pct = u64::from(self.params.rtt_reset_pct().as_percent());
243 let min_rtt = Duration::from_micros(
244 (rtt_reset_pct * max / 100) + (100 - rtt_reset_pct) * min / 100,
245 );
246
247 self.min_rtt = Some(min_rtt);
248 } else if self.ewma_rtt < self.min_rtt {
249 self.min_rtt = self.ewma_rtt;
250 }
251
252 Ok(())
253 }
254}
255
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use web_time_compat::{Duration, Instant, InstantExt};

    use crate::congestion::test_utils::{new_cwnd, new_rtt_estimator};

    use super::*;

    /// One row of the RTT test-vector table: the inputs fed to the estimator and the values
    /// it is expected to hold afterwards.
    #[derive(Debug)]
    struct RttTestSample {
        sent_usec_in: u64,
        sendme_received_usec_in: u64,
        cwnd_in: u32,
        ss_in: bool,
        last_rtt_usec_out: u64,
        ewma_rtt_usec_out: u64,
        min_rtt_usec_out: u64,
    }

    impl From<[u64; 7]> for RttTestSample {
        fn from(row: [u64; 7]) -> Self {
            let [sent, received, cwnd, slow_start, last_rtt, ewma_rtt, min_rtt] = row;
            Self {
                sent_usec_in: sent,
                sendme_received_usec_in: received,
                cwnd_in: cwnd as u32,
                ss_in: slow_start == 1,
                last_rtt_usec_out: last_rtt,
                ewma_rtt_usec_out: ewma_rtt,
                min_rtt_usec_out: min_rtt,
            }
        }
    }

    impl RttTestSample {
        /// Feed this sample into `estimator` (timestamps are offsets from `start`) and check
        /// the resulting `last_rtt`, `ewma_rtt` and `min_rtt`.
        fn test(&self, estimator: &mut RoundtripTimeEstimator, start: Instant) {
            let usec = Duration::from_micros;

            let mut cwnd = new_cwnd();
            cwnd.set(self.cwnd_in);
            let state = if !self.ss_in {
                State::Steady
            } else {
                State::SlowStart
            };

            estimator.expect_sendme(start + usec(self.sent_usec_in));
            estimator
                .update(start + usec(self.sendme_received_usec_in), &state, &cwnd)
                .expect("Error on RTT update");

            assert_eq!(estimator.last_rtt, Some(usec(self.last_rtt_usec_out)));
            assert_eq!(estimator.ewma_rtt, Some(usec(self.ewma_rtt_usec_out)));
            assert_eq!(estimator.min_rtt, Some(usec(self.min_rtt_usec_out)));
        }
    }

    #[test]
    fn test_vectors() {
        let mut estimator = new_rtt_estimator();
        let start = Instant::get();
        // Columns: sent_usec, sendme_received_usec, cwnd, in_slow_start (1 = yes),
        //          expected last_rtt_usec, expected ewma_rtt_usec, expected min_rtt_usec.
        let rows: [[u64; 7]; 9] = [
            [100000, 200000, 124, 1, 100000, 100000, 100000],
            [200000, 300000, 124, 1, 100000, 100000, 100000],
            [350000, 500000, 124, 1, 150000, 133333, 100000],
            [500000, 550000, 124, 1, 50000, 77777, 77777],
            [600000, 700000, 124, 1, 100000, 92592, 77777],
            [700000, 750000, 124, 1, 50000, 64197, 64197],
            [750000, 875000, 124, 0, 125000, 104732, 104732],
            [875000, 900000, 124, 0, 25000, 51577, 104732],
            [900000, 950000, 200, 0, 50000, 50525, 50525],
        ];
        for row in rows {
            let sample = RttTestSample::from(row);
            eprintln!("Testing vector: {:?}", sample);
            sample.test(&mut estimator, start);
        }
    }
}