// tor_proto/util/token_bucket/writer.rs
//! An [`AsyncWrite`] rate limiter.

use std::future::Future;
use std::num::NonZero;
use std::pin::Pin;
use std::task::{Context, Poll};
use web_time_compat::{Duration, Instant};

use futures::AsyncWrite;
use futures::io::Error;
use sync_wrapper::SyncFuture;
use tor_rtcompat::SleepProvider;

use super::bucket::{NeverEnoughTokensError, TokenBucket, TokenBucketConfig};

16/// A rate-limited async [writer](AsyncWrite).
17///
18/// This can be used as a wrapper around an existing [`AsyncWrite`] writer.
19#[derive(educe::Educe)]
20#[educe(Debug)]
21#[pin_project::pin_project]
22pub(crate) struct RateLimitedWriter<W: AsyncWrite, P: SleepProvider> {
23 /// The token bucket.
24 bucket: TokenBucket<Instant>,
25 /// The sleep provider, for getting the current time and creating new sleep futures.
26 ///
27 /// While we use [`Instant`] for the time, we should always get the time from this
28 /// [`SleepProvider`].
29 /// For example, use [`SleepProvider::now()`],
30 /// not [`Instant::now()`](std::time::Instant::now) or
31 /// [`InstantExt::get`](web_time_compat::InstantExt::get).
32 #[educe(Debug(ignore))]
33 sleep_provider: P,
34 /// See [`RateLimitedWriterConfig::wake_when_bytes_available`].
35 wake_when_bytes_available: NonZero<u64>,
36 /// The inner writer.
37 #[educe(Debug(ignore))]
38 #[pin]
39 inner: W,
40 /// We need to store the sleep future if [`AsyncWrite::poll_write()`] blocks.
41 #[educe(Debug(ignore))]
42 #[pin]
43 sleep_fut: Option<SyncFuture<P::SleepFuture>>,
44}
45
46impl<W, P> RateLimitedWriter<W, P>
47where
48 W: AsyncWrite,
49 P: SleepProvider,
50{
51 /// Create a new [`RateLimitedWriter`].
52 // We take the rate and bucket max directly rather than a `TokenBucket` to ensure that the token
53 // bucket only ever uses times from `sleep_provider`.
54 pub(crate) fn new(writer: W, config: &RateLimitedWriterConfig, sleep_provider: P) -> Self {
55 let bucket_config = TokenBucketConfig {
56 rate: config.rate,
57 bucket_max: config.burst,
58 };
59 Self::from_token_bucket(
60 writer,
61 TokenBucket::new(&bucket_config, sleep_provider.now()),
62 config.wake_when_bytes_available,
63 sleep_provider,
64 )
65 }
66
67 /// Create a new [`RateLimitedWriter`] from a [`TokenBucket`].
68 ///
69 /// The token bucket must have only been used with times created by `sleep_provider`.
70 #[cfg_attr(test, visibility::make(pub(super)))]
71 fn from_token_bucket(
72 writer: W,
73 bucket: TokenBucket<Instant>,
74 wake_when_bytes_available: NonZero<u64>,
75 sleep_provider: P,
76 ) -> Self {
77 Self {
78 bucket,
79 sleep_provider,
80 wake_when_bytes_available,
81 inner: writer,
82 sleep_fut: None,
83 }
84 }
85
86 /// Access the inner [`AsyncWrite`] writer.
87 pub(crate) fn inner(&self) -> &W {
88 &self.inner
89 }
90
91 /// Adjust the refill rate and burst.
92 ///
93 /// A rate and/or burst of 0 is allowed.
94 pub(crate) fn adjust(
95 self: &mut Pin<&mut Self>,
96 now: Instant,
97 config: &RateLimitedWriterConfig,
98 ) {
99 let self_ = self.as_mut().project();
100
101 // destructuring allows us to make sure we aren't forgetting to handle any fields
102 let RateLimitedWriterConfig {
103 rate,
104 burst,
105 wake_when_bytes_available,
106 } = *config;
107
108 let bucket_config = TokenBucketConfig {
109 rate,
110 bucket_max: burst,
111 };
112
113 self_.bucket.adjust(now, &bucket_config);
114 *self_.wake_when_bytes_available = wake_when_bytes_available;
115 }
116
117 /// The sleep provider.
118 ///
119 /// We don't want this to be generally accessible, only to other token bucket-related modules
120 /// like [`DynamicRateLimitedWriter`](super::dynamic_writer::DynamicRateLimitedWriter).
121 pub(super) fn sleep_provider(&self) -> &P {
122 &self.sleep_provider
123 }
124
125 /// Configure this writer to sleep for `duration`.
126 ///
127 /// A `duration` of `None` is interpreted as "forever".
128 ///
129 /// It's considered a bug if asked to sleep for `Duration::ZERO` time.
130 fn register_sleep(
131 sleep_fut: &mut Pin<&mut Option<SyncFuture<P::SleepFuture>>>,
132 sleep_provider: &mut P,
133 cx: &mut Context<'_>,
134 duration: Option<Duration>,
135 ) -> Poll<()> {
136 match duration {
137 None => {
138 sleep_fut.as_mut().set(None);
139 Poll::Pending
140 }
141 Some(duration) => {
142 debug_assert_ne!(duration, Duration::ZERO, "asked to sleep for 0 time");
143 sleep_fut
144 .as_mut()
145 .set(Some(SyncFuture::new(sleep_provider.sleep(duration))));
146 sleep_fut
147 .as_mut()
148 .as_pin_mut()
149 .expect("but we just set it to `Some`?!")
150 .poll(cx)
151 }
152 }
153 }
154}
155
156impl<W, P> AsyncWrite for RateLimitedWriter<W, P>
157where
158 W: AsyncWrite,
159 P: SleepProvider,
160{
161 fn poll_write(
162 mut self: Pin<&mut Self>,
163 cx: &mut Context<'_>,
164 mut buf: &[u8],
165 ) -> Poll<Result<usize, Error>> {
166 let mut self_ = self.as_mut().project();
167
168 // this should be optimized to a no-op on at least x86-64
169 fn to_u64(x: usize) -> u64 {
170 x.try_into().expect("failed usize to u64 conversion")
171 }
172
173 // for an empty buffer, just defer to the inner writer's impl
174 if buf.is_empty() {
175 return self_.inner.poll_write(cx, buf);
176 }
177
178 let now = self_.sleep_provider.now();
179
180 // refill the bucket and attempt to claim all of the bytes
181 self_.bucket.refill(now);
182 let claim = self_.bucket.claim(to_u64(buf.len()));
183
184 let mut claim = match claim {
185 // claim was successful
186 Ok(x) => x,
187 // not enough tokens, so let's use a smaller buffer
188 Err(e) => {
189 let available = e.available_tokens();
190
191 // need to drop the old claim so that we can access the token bucket again
192 drop(claim);
193
194 // if no tokens in bucket, we must sleep
195 if available == 0 {
196 // number of tokens we'll wait for
197 let wake_at_tokens = to_u64(buf.len());
198
199 // If the user wants to write X tokens, we don't necessarily want to sleep until
200 // we have room for X tokens. We also don't want to wake every time that a
201 // single byte can be written. We allow the user to configure this threshold
202 // with `RateLimitedWriterConfig::wake_when_bytes_available`.
203 let wake_at_tokens =
204 std::cmp::min(wake_at_tokens, self_.wake_when_bytes_available.get());
205
206 // max number of tokens the bucket can hold
207 let bucket_max = self_.bucket.max();
208
209 // how long to sleep for; `None` indicates to sleep forever
210 let sleep_for = if bucket_max == 0 {
211 // bucket can't hold any tokens, so sleep forever
212 None
213 } else {
214 // if the bucket has a max of X tokens, we should never try to wait for >X
215 // tokens
216 let wake_at_tokens = std::cmp::min(wake_at_tokens, bucket_max);
217
218 // if we asked for 0 tokens, we'd get a time of ~now, which is not what we
219 // want
220 debug_assert!(wake_at_tokens > 0);
221
222 let wake_at = self_.bucket.tokens_available_at(wake_at_tokens);
223 let sleep_for = wake_at.map(|x| x.saturating_duration_since(now));
224
225 match sleep_for {
226 Ok(x) => Some(x),
227 Err(NeverEnoughTokensError::ExceedsMaxTokens) => {
228 panic!(
229 "exceeds max tokens, but we took the max into account above"
230 );
231 }
232 // we aren't refilling, so sleep forever
233 Err(NeverEnoughTokensError::ZeroRate) => None,
234 // too far in the future to be represented, so sleep forever
235 Err(NeverEnoughTokensError::InstantNotRepresentable) => None,
236 }
237 };
238
239 // configure the sleep future and poll it to register
240 let poll = Self::register_sleep(
241 &mut self_.sleep_fut,
242 self_.sleep_provider,
243 cx,
244 sleep_for,
245 );
246 return match poll {
247 // wait for the sleep to finish
248 Poll::Pending => Poll::Pending,
249 // The sleep is already ready?! A recursive call here isn't great, but
250 // there's not much else we can do here. Hopefully this second `poll_write`
251 // will succeed since we should now have enough tokens.
252 Poll::Ready(()) => self.poll_write(cx, buf),
253 };
254 }
255
256 /// Convert a `u64` to `usize`, saturating if size of `usize` is smaller than `u64`.
257 // This is a separate function to ensure we don't accidentally try to convert a
258 // signed integer into a `usize`, in which case `unwrap_or(MAX)` wouldn't make
259 // sense.
260 fn to_usize_saturating(x: u64) -> usize {
261 x.try_into().unwrap_or(usize::MAX)
262 }
263
264 // There are tokens, so try to write as many as are available.
265 let available_usize = to_usize_saturating(available);
266 buf = &buf[0..available_usize];
267 self_.bucket.claim(to_u64(buf.len())).unwrap_or_else(|_| {
268 panic!(
269 "bucket has {available} tokens available, but can't claim {}?",
270 buf.len(),
271 )
272 })
273 }
274 };
275
276 let rv = self_.inner.poll_write(cx, buf);
277
278 match rv {
279 // no bytes were written, so discard the claim
280 Poll::Pending | Poll::Ready(Err(_)) => claim.discard(),
281 // `x` bytes were written, so only commit those tokens
282 Poll::Ready(Ok(x)) => {
283 if x <= buf.len() {
284 claim
285 .reduce(to_u64(x))
286 .expect("can't commit fewer tokens?!");
287 claim.commit();
288 } else {
289 cfg_if::cfg_if! {
290 if #[cfg(debug_assertions)] {
291 panic!(
292 "Writer is claiming it wrote more bytes {x} than we gave it {}",
293 buf.len(),
294 );
295 } else {
296 // the best we can do is to just claim the original amount
297 claim.commit();
298 }
299 }
300 }
301 }
302 };
303
304 rv
305 }
306
307 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
308 self.project().inner.poll_flush(cx)
309 }
310
311 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
312 // some implementers of `AsyncWrite` (like `Vec`) don't do anything other than flush when
313 // closed and will continue to accept bytes even after being closed, so we must continue to
314 // apply rate limiting even after being closed
315 self.project().inner.poll_close(cx)
316 }
317}
318
/// A module to make it easier to implement tokio traits without putting `cfg()` conditionals
/// everywhere.
#[cfg(feature = "tokio")]
mod tokio_impl {
    use super::*;

    use tokio_crate::io::AsyncWrite as TokioAsyncWrite;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    use std::io::Result as IoResult;

    impl<W, P> TokioAsyncWrite for RateLimitedWriter<W, P>
    where
        W: AsyncWrite,
        P: SleepProvider,
    {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<IoResult<usize>> {
            // wrap the pinned writer in a (stateless) futures->tokio compat adapter and delegate
            TokioAsyncWrite::poll_write(Pin::new(&mut self.compat_write()), cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            TokioAsyncWrite::poll_flush(Pin::new(&mut self.compat_write()), cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
            TokioAsyncWrite::poll_shutdown(Pin::new(&mut self.compat_write()), cx)
        }
    }
}

/// The refill rate and burst for a [`RateLimitedWriter`].
#[derive(Clone, Debug)]
pub(crate) struct RateLimitedWriterConfig {
    /// The refill rate in bytes/second.
    pub(crate) rate: u64,
    /// The "burst" in bytes.
    pub(crate) burst: u64,
    /// When polled, block until at most this many bytes are available.
    ///
    /// Or in other words, wake when we can write this many bytes, even if the provided buffer is
    /// larger.
    ///
    /// For example if a user attempts to write a large buffer, we usually don't want to block until
    /// the entire buffer can be written. We'd prefer several partial writes to a single large
    /// write. So instead of blocking until the entire buffer can be written, we only block until
    /// at most this many bytes are available.
    pub(crate) wake_when_bytes_available: NonZero<u64>,
}

#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]

    use super::*;

    use futures::{AsyncWriteExt, FutureExt};
    use tor_rtcompat::SpawnExt;

    #[test]
    fn writer() {
        tor_rtmock::MockRuntime::test_with_various(|rt| async move {
            let start = rt.now();

            // increases 10 tokens/second (one every 100 ms)
            let config = TokenBucketConfig {
                rate: 10,
                bucket_max: 100,
            };
            let mut tb = TokenBucket::new(&config, start);
            // drain the bucket
            tb.drain(100).unwrap();

            let wake_when_bytes_available = NonZero::new(15).unwrap();

            let mut writer = Vec::new();
            let mut writer = RateLimitedWriter::from_token_bucket(
                &mut writer,
                tb,
                wake_when_bytes_available,
                rt.clone(),
            );

            // drive time forward from 0 to 20_000 ms in 50 ms intervals
            let rt_clone = rt.clone();
            rt.spawn(async move {
                for _ in 0..400 {
                    rt_clone.progress_until_stalled().await;
                    rt_clone.advance_by(Duration::from_millis(50)).await;
                }
            })
            .unwrap();

            // try writing 60 bytes, which sleeps until we can write at least 15 of them
            assert_eq!(15, writer.write(&[0; 60]).await.unwrap());
            assert_eq!(1500, rt.now().duration_since(start).as_millis());

            // wait 2 seconds
            rt.sleep(Duration::from_millis(2000)).await;

            // ensure that we can write immediately, and that we can write
            // 2000 ms / (100 ms/token) = 20 bytes
            assert_eq!(
                Some(20),
                writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
            );
        });
    }

    /// Test that writing to a token bucket which has a rate and/or max of 0 works as expected.
    #[test]
    fn rate_burst_zero() {
        let configs = [
            // non-zero rate, zero max
            TokenBucketConfig {
                rate: 10,
                bucket_max: 0,
            },
            // zero rate, non-zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 10,
            },
            // zero rate, zero max
            TokenBucketConfig {
                rate: 0,
                bucket_max: 0,
            },
        ];
        for config in configs {
            tor_rtmock::MockRuntime::test_with_various(|rt| {
                let config = config.clone();
                async move {
                    // an empty token bucket
                    let mut tb = TokenBucket::new(&config, rt.now());
                    tb.drain(tb.max()).unwrap();
                    assert!(tb.is_empty());

                    let wake_when_bytes_available = NonZero::new(2).unwrap();

                    let mut writer = Vec::new();
                    let mut writer = RateLimitedWriter::from_token_bucket(
                        &mut writer,
                        tb,
                        wake_when_bytes_available,
                        rt.clone(),
                    );

                    // drive time forward from 0 to 10_000 ms in 100 ms intervals
                    let rt_clone = rt.clone();
                    rt.spawn(async move {
                        for _ in 0..100 {
                            rt_clone.progress_until_stalled().await;
                            rt_clone.advance_by(Duration::from_millis(100)).await;
                        }
                    })
                    .unwrap();

                    // ensure that a write returns `Pending`
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );

                    // wait 5 seconds
                    rt.sleep(Duration::from_millis(5000)).await;

                    // ensure that a write still returns `Pending`
                    assert_eq!(
                        None,
                        writer.write(&[0; 60]).now_or_never().map(Result::unwrap),
                    );
                }
            });
        }
    }
}