Skip to main content

maybenot/
rate_limited_framework.rs

1//! Rate-limited wrapper for the Maybenot framework.
2//!
3//! This module provides a [`RateLimitedFramework`] that wraps the core [`Framework`]
4//! to add rate limiting capabilities. The rate limiter uses a sliding window algorithm
5//! to limit the number of actions returned per second, helping prevent abuse and
6//! excessive resource consumption.
7//!
8//! The sliding window algorithm is based on the approach described in Cloudflare's
9//! blog post: <https://blog.cloudflare.com/counting-things-a-lot-of-different-things/>
10
11use crate::time::{Duration, Instant};
12use crate::{Framework, Machine, TriggerAction, TriggerEvent};
13use rand_core::RngCore;
14use std::ops::Sub;
15use std::time::Instant as StdInstant;
16
17/// A rate-limited wrapper around the Maybenot framework.
18///
19/// This struct wraps a [`Framework`] and applies rate limiting to the actions
20/// returned by [`trigger_events`](Self::trigger_events). It uses a sliding window
21/// algorithm to track the rate of events and blocks actions when the rate exceeds
22/// the specified limit.
23///
24/// The rate limiter tracks events across a 1-second sliding window, using the
25/// previous window's count and the current window's count to calculate the
26/// effective rate. This approach prevents burst traffic from overwhelming the
27/// system while allowing sustained traffic up to the limit.
28///
29/// # Example
30/// ```
31/// use maybenot::{Framework, RateLimitedFramework, Machine, TriggerEvent};
32/// use std::time::Instant;
33///
34/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
35/// let machines = vec![];
36/// let framework = Framework::new(machines, 0.0, 0.0, Instant::now(), rand::rng())?;
37/// let mut rate_limited = RateLimitedFramework::new(framework);
38///
39/// let events = [TriggerEvent::NormalSent];
40/// let actions: Vec<_> = rate_limited
41///     .trigger_events(&events, 10.0, Instant::now())
42///     .collect();
43/// # Ok(())
44/// # }
45/// ```
46pub struct RateLimitedFramework<M, R, T = StdInstant>
47where
48    M: AsRef<[Machine]>,
49    R: RngCore,
50    T: Instant,
51    T::Duration: Sub<Output = T::Duration>,
52{
53    framework: Framework<M, R, T>,
54    /// Count of events in the previous 1-second window
55    prev: f64,
56    /// Count of events in the current 1-second window
57    current: f64,
58    /// Start time of the current 1-second window
59    tick: T,
60}
61
62impl<M, R, T> RateLimitedFramework<M, R, T>
63where
64    M: AsRef<[Machine]>,
65    R: RngCore,
66    T: Instant,
67    T::Duration: Sub<Output = T::Duration>,
68{
69    /// Creates a new rate-limited framework wrapper.
70    ///
71    /// Initializes the rate limiter with zero counts for both the previous and
72    /// current windows, and sets the current window start time to the framework's
73    /// current time.
74    ///
75    /// # Arguments
76    /// * `framework` - The underlying framework to wrap with rate limiting
77    ///
78    /// # Returns
79    /// A new `RateLimitedFramework` instance
80    pub fn new(framework: Framework<M, R, T>) -> Self {
81        let tick = framework.current_time;
82
83        Self {
84            framework,
85            prev: 0.0,
86            current: 0.0,
87            tick,
88        }
89    }
90
91    /// Triggers events in the framework with rate limiting applied.
92    ///
93    /// This method forwards events to the underlying framework and applies rate
94    /// limiting to the returned actions. It uses a sliding window algorithm to
95    /// track the rate of events over time and blocks actions when the rate exceeds
96    /// the specified limit.
97    ///
98    /// The sliding window calculation uses the formula from Cloudflare's approach:
99    /// `rate = (prev * (1s - elapsed) / 1s) + current`
100    ///
101    /// Events are always processed (allowing machines to transition states), but
102    /// actions may be dropped if the rate limit is exceeded.
103    ///
104    /// # Arguments
105    /// * `events` - The events to trigger in the framework
106    /// * `max_actions_per_second` - Maximum allowed actions per second
107    /// * `current_time` - The current time for rate window calculations
108    ///
109    /// # Returns
110    /// An iterator over the allowed actions (may be empty if rate limited)
111    pub fn trigger_events(
112        &mut self,
113        events: &[TriggerEvent],
114        max_actions_per_second: f64,
115        current_time: T,
116    ) -> impl Iterator<Item = &TriggerAction<T>> {
117        let window_1s = Duration::from_micros(1_000_000);
118
119        // We always trigger events since that can cause machines to transition,
120        // we just rate limit the returned actions. If the user of the framework
121        // cannot keep up, they're supposed to first start batching their
122        // events, then tail drop old events worst-case.
123        #[allow(unused_must_use)]
124        self.framework.trigger_events(events, current_time);
125
126        let delta = current_time.saturating_duration_since(self.tick);
127        // are we in the current potentially busy window?
128        if delta < window_1s {
129            // simple sliding window like cloudflare uses/used,
130            // https://blog.cloudflare.com/counting-things-a-lot-of-different-things/
131            // , assuming previous hits were uniformly distributed
132            let rate = (self.prev * (window_1s - delta).div_duration_f64(window_1s)) + self.current;
133            if rate >= max_actions_per_second {
134                // over rate, fill actions with None (dropping all actions)
135                self.framework.actions.fill(None);
136            }
137        } else {
138            if delta.div_duration_f64(window_1s) < 2.0 {
139                // we are still in the next window, save previous count
140                self.prev = self.current;
141            } else {
142                // long duration since last trigger, reset previous count
143                self.prev = 0.0;
144            }
145            // start new window
146            self.tick = current_time;
147            self.current = 0.0;
148        }
149
150        self.current += self
151            .framework
152            .actions
153            .iter()
154            .filter(|a| a.is_some())
155            .count() as f64;
156        self.framework
157            .actions
158            .iter()
159            .filter_map(|action| action.as_ref())
160    }
161
162    /// Returns a reference to the underlying framework.
163    ///
164    /// This provides read-only access to the wrapped framework instance.
165    pub fn framework(&self) -> &Framework<M, R, T> {
166        &self.framework
167    }
168
169    /// Returns a mutable reference to the underlying framework.
170    ///
171    /// This provides mutable access to the wrapped framework instance for
172    /// advanced use cases that need to modify the framework state directly.
173    pub fn framework_mut(&mut self) -> &mut Framework<M, R, T> {
174        &mut self.framework
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::action::Action;
    use crate::dist::{Dist, DistType};
    use crate::event::Event;
    use crate::state::{State, Trans};
    use crate::{Framework, Machine, MachineId, TriggerEvent};
    use enum_map::enum_map;
    use std::time::Instant as StdInstant;

    /// Builds a framework running one machine with a single state. That state
    /// has a `SendPadding` action with a zero timeout and a `PaddingSent`
    /// transition `Trans(0, 1.0)` back to state 0.
    fn create_test_framework() -> Framework<Vec<Machine>, rand::rngs::ThreadRng, StdInstant> {
        let mut state = State::new(enum_map! {
            Event::PaddingSent => vec![Trans(0, 1.0)],
            _ => vec![],
        });
        state.action = Some(Action::SendPadding {
            bypass: false,
            replace: false,
            timeout: Dist {
                dist: DistType::Uniform {
                    low: 0.0,
                    high: 0.0,
                },
                start: 0.0,
                max: 0.0,
            },
            limit: None,
        });

        let m = Machine::new(1_000_000, 0.0, 0, 0.0, vec![state]).unwrap();
        Framework::new(vec![m], 0.0, 0.0, StdInstant::now(), rand::rng()).unwrap()
    }

    /// A single `PaddingSent` event for machine 0, the input used by the tests.
    fn padding_sent() -> [TriggerEvent; 1] {
        [TriggerEvent::PaddingSent {
            machine: MachineId::from_raw(0),
        }]
    }

    #[test]
    fn test_new() {
        let framework = create_test_framework();
        let rate_limited = RateLimitedFramework::new(framework);

        assert_eq!(rate_limited.prev, 0.0);
        assert_eq!(rate_limited.current, 0.0);
    }

    #[test]
    fn test_framework_accessors() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        // Both accessors must expose the same wrapped framework instance.
        let shared: *const _ = rate_limited.framework();
        let exclusive: *const _ = rate_limited.framework_mut();
        assert!(std::ptr::eq(shared, exclusive));
    }

    #[test]
    fn test_rate_limiting_under_limit() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = StdInstant::now();

        let _actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_rate_limiting_over_limit() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        let events = padding_sent();
        let max_rate = 0.5;
        let current_time = StdInstant::now();

        rate_limited.current = 1.0;

        let _actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_sliding_window_within_current_window() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        rate_limited.prev = 2.0;
        rate_limited.current = 1.0;

        let events = padding_sent();
        let max_rate = 2.0;
        let current_time = rate_limited.tick + std::time::Duration::from_millis(500);

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        // rate = 2.0 * (1s - 0.5s) / 1s + 1.0 = 2.0, which reaches the limit.
        assert!(actions.is_empty());
        assert_eq!(rate_limited.current, 1.0);
        assert_eq!(rate_limited.prev, 2.0);
    }

    #[test]
    fn test_sliding_window_next_window() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        rate_limited.current = 5.0;
        let original_tick = rate_limited.tick;

        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = rate_limited.tick + std::time::Duration::from_millis(1500);

        let _actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert_eq!(rate_limited.prev, 5.0);
        assert_eq!(rate_limited.current, 1.0);
        assert!(rate_limited.tick > original_tick);
        assert_eq!(rate_limited.tick, current_time);
    }

    #[test]
    fn test_sliding_window_long_duration_reset() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        rate_limited.current = 5.0;

        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = rate_limited.tick + std::time::Duration::from_secs(3);

        let _actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert_eq!(rate_limited.prev, 0.0);
        assert_eq!(rate_limited.current, 1.0);
        assert_eq!(rate_limited.tick, current_time);
    }

    #[test]
    fn test_multiple_events_increment_current() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = StdInstant::now();

        for _ in 0..3 {
            let _ = rate_limited
                .trigger_events(&events, max_rate, current_time)
                .count();
        }

        assert_eq!(rate_limited.current, 3.0);
    }

    #[test]
    fn test_actions_returned_when_under_rate_limit() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = StdInstant::now();

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert!(!actions.is_empty());
        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_actions_blocked_when_over_rate_limit() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        rate_limited.current = 2.0;

        let events = padding_sent();
        let max_rate = 1.0;
        let current_time = StdInstant::now();

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert!(actions.is_empty());
        assert_eq!(rate_limited.current, 2.0);
    }

    #[test]
    fn test_rate_limiting_with_sliding_window_calculation() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        rate_limited.prev = 3.0;
        rate_limited.current = 1.0;

        let events = padding_sent();
        let max_rate = 2.5;
        let current_time = rate_limited.tick + std::time::Duration::from_millis(250);

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        // rate = 3.0 * 0.75 + 1.0 = 3.25 >= 2.5
        assert!(actions.is_empty());
        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_repeated_triggers_with_rate_limit_5() {
        let framework = create_test_framework();
        let mut rate_limited = RateLimitedFramework::new(framework);

        let events = padding_sent();
        let max_rate = 5.0;
        let current_time = StdInstant::now();

        for i in 1..=5 {
            let actions: Vec<_> = rate_limited
                .trigger_events(&events, max_rate, current_time)
                .collect();
            assert!(!actions.is_empty(), "Expected actions on iteration {}", i);
            assert_eq!(rate_limited.current, i as f64);
        }

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();
        assert!(actions.is_empty(), "Expected no actions when over limit");
        assert_eq!(rate_limited.current, 5.0);
    }
}