1use crate::time::{Duration, Instant};
12use crate::{Framework, Machine, TriggerAction, TriggerEvent};
13use rand_core::RngCore;
14use std::ops::Sub;
15use std::time::Instant as StdInstant;
16
17pub struct RateLimitedFramework<M, R, T = StdInstant>
47where
48 M: AsRef<[Machine]>,
49 R: RngCore,
50 T: Instant,
51 T::Duration: Sub<Output = T::Duration>,
52{
53 framework: Framework<M, R, T>,
54 prev: f64,
56 current: f64,
58 tick: T,
60}
61
62impl<M, R, T> RateLimitedFramework<M, R, T>
63where
64 M: AsRef<[Machine]>,
65 R: RngCore,
66 T: Instant,
67 T::Duration: Sub<Output = T::Duration>,
68{
69 pub fn new(framework: Framework<M, R, T>) -> Self {
81 let tick = framework.current_time;
82
83 Self {
84 framework,
85 prev: 0.0,
86 current: 0.0,
87 tick,
88 }
89 }
90
91 pub fn trigger_events(
112 &mut self,
113 events: &[TriggerEvent],
114 max_actions_per_second: f64,
115 current_time: T,
116 ) -> impl Iterator<Item = &TriggerAction<T>> {
117 let window_1s = Duration::from_micros(1_000_000);
118
119 #[allow(unused_must_use)]
124 self.framework.trigger_events(events, current_time);
125
126 let delta = current_time.saturating_duration_since(self.tick);
127 if delta < window_1s {
129 let rate = (self.prev * (window_1s - delta).div_duration_f64(window_1s)) + self.current;
133 if rate >= max_actions_per_second {
134 self.framework.actions.fill(None);
136 }
137 } else {
138 if delta.div_duration_f64(window_1s) < 2.0 {
139 self.prev = self.current;
141 } else {
142 self.prev = 0.0;
144 }
145 self.tick = current_time;
147 self.current = 0.0;
148 }
149
150 self.current += self
151 .framework
152 .actions
153 .iter()
154 .filter(|a| a.is_some())
155 .count() as f64;
156 self.framework
157 .actions
158 .iter()
159 .filter_map(|action| action.as_ref())
160 }
161
162 pub fn framework(&self) -> &Framework<M, R, T> {
166 &self.framework
167 }
168
169 pub fn framework_mut(&mut self) -> &mut Framework<M, R, T> {
174 &mut self.framework
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::action::Action;
    use crate::dist::{Dist, DistType};
    use crate::event::Event;
    use crate::state::{State, Trans};
    use crate::{Framework, Machine, MachineId, TriggerEvent};
    use enum_map::enum_map;
    use std::time::Instant as StdInstant;

    /// Concrete framework type built by `create_test_framework`.
    type TestFramework = Framework<Vec<Machine>, rand::rngs::ThreadRng, StdInstant>;

    /// Builds a framework with one machine whose only state sends padding with
    /// a zero timeout and transitions back to itself on `PaddingSent`.
    ///
    /// The counter assertions below rely on each `PaddingSent` event
    /// scheduling exactly one action.
    fn create_test_framework() -> TestFramework {
        let mut state = State::new(enum_map! {
            Event::PaddingSent => vec![Trans(0, 1.0)],
            _ => vec![],
        });
        state.action = Some(Action::SendPadding {
            bypass: false,
            replace: false,
            timeout: Dist {
                dist: DistType::Uniform {
                    low: 0.0,
                    high: 0.0,
                },
                start: 0.0,
                max: 0.0,
            },
            limit: None,
        });

        let m = Machine::new(1_000_000, 0.0, 0, 0.0, vec![state]).unwrap();
        Framework::new(vec![m], 0.0, 0.0, StdInstant::now(), rand::rng()).unwrap()
    }

    /// A single `PaddingSent` event for the test framework's only machine.
    fn padding_sent() -> [TriggerEvent; 1] {
        [TriggerEvent::PaddingSent {
            machine: MachineId::from_raw(0),
        }]
    }

    #[test]
    fn test_new() {
        let rate_limited = RateLimitedFramework::new(create_test_framework());

        assert_eq!(rate_limited.prev, 0.0);
        assert_eq!(rate_limited.current, 0.0);
        assert_eq!(rate_limited.tick, rate_limited.framework().current_time);
    }

    #[test]
    fn test_framework_accessors() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());

        // Both accessors must expose the very same wrapped framework.
        let shared: *const TestFramework = rate_limited.framework();
        let exclusive: *const TestFramework = rate_limited.framework_mut();
        assert!(std::ptr::eq(shared, exclusive));
    }

    #[test]
    fn test_rate_limiting_under_limit() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());
        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = StdInstant::now();

        let _actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_rate_limiting_over_limit() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());
        let events = padding_sent();
        let max_rate = 0.5;
        let current_time = StdInstant::now();

        rate_limited.current = 1.0;

        let _actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_sliding_window_within_current_window() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());

        rate_limited.prev = 2.0;
        rate_limited.current = 1.0;

        let events = padding_sent();
        // Half of the previous window still overlaps: 2.0 * 0.5 + 1.0 = 2.0,
        // which reaches the limit, so the action must be dropped.
        let max_rate = 2.0;
        let current_time = rate_limited.tick + std::time::Duration::from_millis(500);

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert!(actions.is_empty());
        assert_eq!(rate_limited.prev, 2.0);
        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_sliding_window_next_window() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());

        rate_limited.current = 5.0;
        let original_tick = rate_limited.tick;

        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = rate_limited.tick + std::time::Duration::from_millis(1500);

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert!(!actions.is_empty());
        assert_eq!(rate_limited.prev, 5.0);
        assert_eq!(rate_limited.current, 1.0);
        assert!(rate_limited.tick > original_tick);
    }

    #[test]
    fn test_sliding_window_long_duration_reset() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());

        rate_limited.current = 5.0;

        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = rate_limited.tick + std::time::Duration::from_secs(3);

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert!(!actions.is_empty());
        assert_eq!(rate_limited.prev, 0.0);
        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_multiple_events_increment_current() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());
        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = StdInstant::now();

        for _ in 0..3 {
            rate_limited
                .trigger_events(&events, max_rate, current_time)
                .count();
        }

        assert_eq!(rate_limited.current, 3.0);
    }

    #[test]
    fn test_actions_returned_when_under_rate_limit() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());
        let events = padding_sent();
        let max_rate = 10.0;
        let current_time = StdInstant::now();

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert!(!actions.is_empty());
        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_actions_blocked_when_over_rate_limit() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());

        rate_limited.current = 2.0;

        let events = padding_sent();
        let max_rate = 1.0;
        let current_time = StdInstant::now();

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert!(actions.is_empty());
        assert_eq!(rate_limited.current, 2.0);
    }

    #[test]
    fn test_rate_limiting_with_sliding_window_calculation() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());

        rate_limited.prev = 3.0;
        rate_limited.current = 1.0;

        let events = padding_sent();
        // 3.0 * 0.75 + 1.0 = 3.25 >= 2.5, so the action is dropped.
        let max_rate = 2.5;
        let current_time = rate_limited.tick + std::time::Duration::from_millis(250);

        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();

        assert!(actions.is_empty());
        assert_eq!(rate_limited.current, 1.0);
    }

    #[test]
    fn test_repeated_triggers_with_rate_limit_5() {
        let mut rate_limited = RateLimitedFramework::new(create_test_framework());
        let events = padding_sent();
        let max_rate = 5.0;
        let current_time = StdInstant::now();

        // Estimates 0..=4 are below the limit, so the first five calls pass.
        for i in 1..=5 {
            let actions: Vec<_> = rate_limited
                .trigger_events(&events, max_rate, current_time)
                .collect();
            assert!(!actions.is_empty(), "Expected actions on iteration {}", i);
            assert_eq!(rate_limited.current, i as f64);
        }

        // The estimate is now 5.0, which reaches the limit.
        let actions: Vec<_> = rate_limited
            .trigger_events(&events, max_rate, current_time)
            .collect();
        assert!(actions.is_empty(), "Expected no actions when over limit");
        assert_eq!(rate_limited.current, 5.0);
    }
}