Skip to main content

tor_proto/util/token_bucket/
writer.rs

1//! An [`AsyncWrite`] rate limiter.
2
3use std::future::Future;
4use std::num::NonZero;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use web_time_compat::{Duration, Instant};
8
9use futures::AsyncWrite;
10use futures::io::Error;
11use sync_wrapper::SyncFuture;
12use tor_rtcompat::SleepProvider;
13
14use super::bucket::{NeverEnoughTokensError, TokenBucket, TokenBucketConfig};
15
/// A rate-limited async [writer](AsyncWrite).
///
/// This can be used as a wrapper around an existing [`AsyncWrite`] writer.
///
/// Every byte handed to the inner writer consumes one token from a [`TokenBucket`]. When the
/// bucket is empty, a write returns [`Poll::Pending`] and waits for a sleep created by the
/// [`SleepProvider`]. If the bucket can never provide the tokens (for example a rate or maximum
/// of 0), no sleep is registered at all, so the write waits until something else wakes the task.
#[derive(educe::Educe)]
#[educe(Debug)]
#[pin_project::pin_project]
pub(crate) struct RateLimitedWriter<W: AsyncWrite, P: SleepProvider> {
    /// The token bucket.
    ///
    /// Tokens are claimed per byte before writing to `inner`. Tokens for bytes that `inner` did
    /// not accept (a pending or failed write, or a short write) are given back.
    bucket: TokenBucket<Instant>,
    /// The sleep provider, for getting the current time and creating new sleep futures.
    ///
    /// While we use [`Instant`] for the time, we should always get the time from this
    /// [`SleepProvider`].
    /// For example, use [`SleepProvider::now()`],
    /// not [`Instant::now()`](std::time::Instant::now) or
    /// [`InstantExt::get`](web_time_compat::InstantExt::get).
    #[educe(Debug(ignore))]
    sleep_provider: P,
    /// See [`RateLimitedWriterConfig::wake_when_bytes_available`].
    wake_when_bytes_available: NonZero<u64>,
    /// The inner writer.
    #[educe(Debug(ignore))]
    #[pin]
    inner: W,
    /// The sleep future created when [`AsyncWrite::poll_write()`] has to wait for tokens.
    ///
    /// It is kept here so that it stays alive while the task waits for it to complete.
    /// Presumably, dropping it would cancel the pending wakeup (TODO confirm for each
    /// `SleepProvider`). Each new wait replaces any previous sleep. A wait that should last
    /// "forever" sets this to `None`. It is also `None` before the first wait.
    #[educe(Debug(ignore))]
    #[pin]
    sleep_fut: Option<SyncFuture<P::SleepFuture>>,
}
45
46impl<W, P> RateLimitedWriter<W, P>
47where
48    W: AsyncWrite,
49    P: SleepProvider,
50{
51    /// Create a new [`RateLimitedWriter`].
52    // We take the rate and bucket max directly rather than a `TokenBucket` to ensure that the token
53    // bucket only ever uses times from `sleep_provider`.
54    pub(crate) fn new(writer: W, config: &RateLimitedWriterConfig, sleep_provider: P) -> Self {
55        let bucket_config = TokenBucketConfig {
56            rate: config.rate,
57            bucket_max: config.burst,
58        };
59        Self::from_token_bucket(
60            writer,
61            TokenBucket::new(&bucket_config, sleep_provider.now()),
62            config.wake_when_bytes_available,
63            sleep_provider,
64        )
65    }
66
67    /// Create a new [`RateLimitedWriter`] from a [`TokenBucket`].
68    ///
69    /// The token bucket must have only been used with times created by `sleep_provider`.
70    #[cfg_attr(test, visibility::make(pub(super)))]
71    fn from_token_bucket(
72        writer: W,
73        bucket: TokenBucket<Instant>,
74        wake_when_bytes_available: NonZero<u64>,
75        sleep_provider: P,
76    ) -> Self {
77        Self {
78            bucket,
79            sleep_provider,
80            wake_when_bytes_available,
81            inner: writer,
82            sleep_fut: None,
83        }
84    }
85
86    /// Access the inner [`AsyncWrite`] writer.
87    pub(crate) fn inner(&self) -> &W {
88        &self.inner
89    }
90
91    /// Adjust the refill rate and burst.
92    ///
93    /// A rate and/or burst of 0 is allowed.
94    pub(crate) fn adjust(
95        self: &mut Pin<&mut Self>,
96        now: Instant,
97        config: &RateLimitedWriterConfig,
98    ) {
99        let self_ = self.as_mut().project();
100
101        // destructuring allows us to make sure we aren't forgetting to handle any fields
102        let RateLimitedWriterConfig {
103            rate,
104            burst,
105            wake_when_bytes_available,
106        } = *config;
107
108        let bucket_config = TokenBucketConfig {
109            rate,
110            bucket_max: burst,
111        };
112
113        self_.bucket.adjust(now, &bucket_config);
114        *self_.wake_when_bytes_available = wake_when_bytes_available;
115    }
116
117    /// The sleep provider.
118    ///
119    /// We don't want this to be generally accessible, only to other token bucket-related modules
120    /// like [`DynamicRateLimitedWriter`](super::dynamic_writer::DynamicRateLimitedWriter).
121    pub(super) fn sleep_provider(&self) -> &P {
122        &self.sleep_provider
123    }
124
125    /// Configure this writer to sleep for `duration`.
126    ///
127    /// A `duration` of `None` is interpreted as "forever".
128    ///
129    /// It's considered a bug if asked to sleep for `Duration::ZERO` time.
130    fn register_sleep(
131        sleep_fut: &mut Pin<&mut Option<SyncFuture<P::SleepFuture>>>,
132        sleep_provider: &mut P,
133        cx: &mut Context<'_>,
134        duration: Option<Duration>,
135    ) -> Poll<()> {
136        match duration {
137            None => {
138                sleep_fut.as_mut().set(None);
139                Poll::Pending
140            }
141            Some(duration) => {
142                debug_assert_ne!(duration, Duration::ZERO, "asked to sleep for 0 time");
143                sleep_fut
144                    .as_mut()
145                    .set(Some(SyncFuture::new(sleep_provider.sleep(duration))));
146                sleep_fut
147                    .as_mut()
148                    .as_pin_mut()
149                    .expect("but we just set it to `Some`?!")
150                    .poll(cx)
151            }
152        }
153    }
154}
155
156impl<W, P> AsyncWrite for RateLimitedWriter<W, P>
157where
158    W: AsyncWrite,
159    P: SleepProvider,
160{
161    fn poll_write(
162        mut self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164        mut buf: &[u8],
165    ) -> Poll<Result<usize, Error>> {
166        let mut self_ = self.as_mut().project();
167
168        // this should be optimized to a no-op on at least x86-64
169        fn to_u64(x: usize) -> u64 {
170            x.try_into().expect("failed usize to u64 conversion")
171        }
172
173        // for an empty buffer, just defer to the inner writer's impl
174        if buf.is_empty() {
175            return self_.inner.poll_write(cx, buf);
176        }
177
178        let now = self_.sleep_provider.now();
179
180        // refill the bucket and attempt to claim all of the bytes
181        self_.bucket.refill(now);
182        let claim = self_.bucket.claim(to_u64(buf.len()));
183
184        let mut claim = match claim {
185            // claim was successful
186            Ok(x) => x,
187            // not enough tokens, so let's use a smaller buffer
188            Err(e) => {
189                let available = e.available_tokens();
190
191                // need to drop the old claim so that we can access the token bucket again
192                drop(claim);
193
194                // if no tokens in bucket, we must sleep
195                if available == 0 {
196                    // number of tokens we'll wait for
197                    let wake_at_tokens = to_u64(buf.len());
198
199                    // If the user wants to write X tokens, we don't necessarily want to sleep until
200                    // we have room for X tokens. We also don't want to wake every time that a
201                    // single byte can be written. We allow the user to configure this threshold
202                    // with `RateLimitedWriterConfig::wake_when_bytes_available`.
203                    let wake_at_tokens =
204                        std::cmp::min(wake_at_tokens, self_.wake_when_bytes_available.get());
205
206                    // max number of tokens the bucket can hold
207                    let bucket_max = self_.bucket.max();
208
209                    // how long to sleep for; `None` indicates to sleep forever
210                    let sleep_for = if bucket_max == 0 {
211                        // bucket can't hold any tokens, so sleep forever
212                        None
213                    } else {
214                        // if the bucket has a max of X tokens, we should never try to wait for >X
215                        // tokens
216                        let wake_at_tokens = std::cmp::min(wake_at_tokens, bucket_max);
217
218                        // if we asked for 0 tokens, we'd get a time of ~now, which is not what we
219                        // want
220                        debug_assert!(wake_at_tokens > 0);
221
222                        let wake_at = self_.bucket.tokens_available_at(wake_at_tokens);
223                        let sleep_for = wake_at.map(|x| x.saturating_duration_since(now));
224
225                        match sleep_for {
226                            Ok(x) => Some(x),
227                            Err(NeverEnoughTokensError::ExceedsMaxTokens) => {
228                                panic!(
229                                    "exceeds max tokens, but we took the max into account above"
230                                );
231                            }
232                            // we aren't refilling, so sleep forever
233                            Err(NeverEnoughTokensError::ZeroRate) => None,
234                            // too far in the future to be represented, so sleep forever
235                            Err(NeverEnoughTokensError::InstantNotRepresentable) => None,
236                        }
237                    };
238
239                    // configure the sleep future and poll it to register
240                    let poll = Self::register_sleep(
241                        &mut self_.sleep_fut,
242                        self_.sleep_provider,
243                        cx,
244                        sleep_for,
245                    );
246                    return match poll {
247                        // wait for the sleep to finish
248                        Poll::Pending => Poll::Pending,
249                        // The sleep is already ready?! A recursive call here isn't great, but
250                        // there's not much else we can do here. Hopefully this second `poll_write`
251                        // will succeed since we should now have enough tokens.
252                        Poll::Ready(()) => self.poll_write(cx, buf),
253                    };
254                }
255
256                /// Convert a `u64` to `usize`, saturating if size of `usize` is smaller than `u64`.
257                // This is a separate function to ensure we don't accidentally try to convert a
258                // signed integer into a `usize`, in which case `unwrap_or(MAX)` wouldn't make
259                // sense.
260                fn to_usize_saturating(x: u64) -> usize {
261                    x.try_into().unwrap_or(usize::MAX)
262                }
263
264                // There are tokens, so try to write as many as are available.
265                let available_usize = to_usize_saturating(available);
266                buf = &buf[0..available_usize];
267                self_.bucket.claim(to_u64(buf.len())).unwrap_or_else(|_| {
268                    panic!(
269                        "bucket has {available} tokens available, but can't claim {}?",
270                        buf.len(),
271                    )
272                })
273            }
274        };
275
276        let rv = self_.inner.poll_write(cx, buf);
277
278        match rv {
279            // no bytes were written, so discard the claim
280            Poll::Pending | Poll::Ready(Err(_)) => claim.discard(),
281            // `x` bytes were written, so only commit those tokens
282            Poll::Ready(Ok(x)) => {
283                if x <= buf.len() {
284                    claim
285                        .reduce(to_u64(x))
286                        .expect("can't commit fewer tokens?!");
287                    claim.commit();
288                } else {
289                    cfg_if::cfg_if! {
290                        if #[cfg(debug_assertions)] {
291                            panic!(
292                                "Writer is claiming it wrote more bytes {x} than we gave it {}",
293                                buf.len(),
294                            );
295                        } else {
296                            // the best we can do is to just claim the original amount
297                            claim.commit();
298                        }
299                    }
300                }
301            }
302        };
303
304        rv
305    }
306
307    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
308        self.project().inner.poll_flush(cx)
309    }
310
311    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
312        // some implementers of `AsyncWrite` (like `Vec`) don't do anything other than flush when
313        // closed and will continue to accept bytes even after being closed, so we must continue to
314        // apply rate limiting even after being closed
315        self.project().inner.poll_close(cx)
316    }
317}
318
/// Implementations of tokio's I/O traits.
///
/// They live in their own module so that the `cfg()` conditional only has to appear once,
/// instead of on every item.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use std::io::Result as IoResult;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    /// Every method wraps the pinned writer in a `Compat` adapter and forwards to it. The
    /// adapter in turn drives this type's `futures` [`AsyncWrite`] implementation, so rate
    /// limiting behaves the same whichever trait the caller uses.
    ///
    /// The trait methods are called by their full path. The `futures` `AsyncWrite` trait is also
    /// in scope here, so method-call syntax could be ambiguous.
    impl<W, P> TokioAsyncWrite for RateLimitedWriter<W, P>
    where
        W: AsyncWrite,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_write(Pin::new(&mut adapter), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_flush(Pin::new(&mut adapter), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            // presumably the adapter maps tokio's "shutdown" onto our `poll_close()`
            let mut adapter = self.compat_write();
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut adapter), cx)
        }
    }
}
352
/// The refill rate, burst, and wake-up threshold for a [`RateLimitedWriter`].
///
/// Used both to create a writer ([`RateLimitedWriter::new`]) and to reconfigure one
/// ([`RateLimitedWriter::adjust`]).
#[derive(Clone, Debug)]
pub(crate) struct RateLimitedWriterConfig {
    /// The refill rate in bytes/second.
    ///
    /// Passed to the token bucket as its refill rate. A rate of 0 is allowed; an empty bucket
    /// then never refills, and a waiting write registers no wakeup.
    pub(crate) rate: u64,
    /// The "burst" in bytes.
    ///
    /// Passed to the token bucket as its maximum (`bucket_max`). A burst of 0 is allowed; the
    /// bucket then can't hold any tokens, and a waiting write registers no wakeup.
    pub(crate) burst: u64,
    /// When a write has to wait for tokens, wake up once this many bytes can be written.
    ///
    /// Or in other words, wake when we can write this many bytes, even if the provided buffer is
    /// larger. The writer actually waits for the smallest of three values: this one, the length
    /// of the buffer being written, and the bucket's maximum.
    ///
    /// For example if a user attempts to write a large buffer, we usually don't want to block until
    /// the entire buffer can be written. We'd prefer several partial writes to a single large
    /// write. So instead of blocking until the entire buffer can be written, we only block until
    /// at most this many bytes are available.
    pub(crate) wake_when_bytes_available: NonZero<u64>,
}
371
#[cfg(test)]
mod test {
    //! Tests for [`RateLimitedWriter`], driven by mock time from `tor_rtmock`.
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::{AsyncWriteExt, FutureExt};
    use tor_rtcompat::SpawnExt;

    /// Test two things:
    /// - A write to an empty bucket waits until `wake_when_bytes_available` tokens have
    ///   accumulated, not for the whole buffer.
    /// - A later write is limited to the tokens that have accumulated by then.
    #[test]
    fn writer() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let start = rt.now();

            // increases 10 tokens/second (one every 100 ms)
            let config = TokenBucketConfig {
                rate: 10,
                bucket_max: 100,
            };
            let mut tb = TokenBucket::new(&config, start);
            // drain the bucket
            tb.drain(100).unwrap();

            // wake once 15 bytes can be written, even when the buffer is larger
            let wake_when_bytes_available = NonZero::new(15).unwrap();

            let mut writer = Vec::new();
            let mut writer = RateLimitedWriter::from_token_bucket(
                &mut writer,
                tb,
                wake_when_bytes_available,
                rt.clone(),
            );

            // drive time forward from 0 to 20_000 ms in 50 ms intervals
            // (from a separate task, so that mock time keeps advancing while the writes below
            // are waiting)
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for _ in 0..400 {
                    rt_clone.progress_until_stalled().await;
                    rt_clone.advance_by(Duration::from_millis(50)).await;
                }
            })
            .unwrap();

            // try writing 60 bytes, which sleeps until we can write at least 15 of them
            // (15 tokens at 10 tokens/second = 1500 ms)
            assert_eq!(15, writer.write(&[0; 60]).await.unwrap());
            assert_eq!(1500, rt.now().duration_since(start).as_millis());

            // wait 2 seconds
            rt.sleep(Duration::from_millis(2000)).await;

            // ensure that we can write immediately, and that we can write
            // 2000 ms / (100 ms/token) = 20 bytes
            assert_eq!(
                Some(20),
                writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
            );
        });
    }

    /// Test that writing to a token bucket which has a rate and/or max of 0 works as expected.
    ///
    /// In each config, an empty bucket either can't hold tokens or never refills. So writes
    /// should stay pending, even after time passes.
    #[test]
    fn rate_burst_zero() {
        let configs = [
            // non-zero rate, zero max
            TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
            // zero rate, non-zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 10,
            },
            // zero rate, zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        ];
        for config in configs {
            tor_rtmock::MockRuntime::test_with_various(|rt| {
                // the closure may be called more than once, so give each run its own copy
                let config = config.clone();
                async move {
                    // an empty token bucket
                    let mut tb = TokenBucket::new(&config, rt.now());
                    tb.drain(tb.max()).unwrap();
                    assert!(tb.is_empty());

                    let wake_when_bytes_available = NonZero::new(2).unwrap();

                    let mut writer = Vec::new();
                    let mut writer = RateLimitedWriter::from_token_bucket(
                        &mut writer,
                        tb,
                        wake_when_bytes_available,
                        rt.clone(),
                    );

                    // drive time forward from 0 to 10_000 ms in 100 ms intervals
                    let rt_clone = rt.clone();
                    rt.spawn(async move {
                        for _ in 0..100 {
                            rt_clone.progress_until_stalled().await;
                            rt_clone.advance_by(Duration::from_millis(100)).await;
                        }
                    })
                    .unwrap();

                    // ensure that a write returns `Pending`
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );

                    // wait 5 seconds
                    rt.sleep(Duration::from_millis(5000)).await;

                    // ensure that a write still returns `Pending`
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );
                }
            });
        }
    }
}