// maybenot/counter.rs
//! Counters as part of a [`Machine`](crate::Machine).

3use rand_core::RngCore;
4use serde::{Deserialize, Serialize};
5
6use crate::{Error, dist};
7use std::fmt;
8
9use self::dist::Dist;
10
11/// The operation applied to one of a [`Machine`](crate::Machine)'s counters
12/// upon transition to a [`State`](crate::state::State).
13#[derive(Debug, Eq, Hash, PartialEq, Clone, Copy, Serialize, Deserialize)]
14pub enum Operation {
15    /// Increment the counter.
16    Increment,
17    /// Decrement the counter.
18    Decrement,
19    /// Replace the current value of the counter.
20    Set,
21}
22
23/// A specification of how one of a [`Machine`](crate::Machine)'s counters
24/// should be updated when transitioning to a [`State`](crate::state::State).
25/// Consists of an [`Operation`] to be applied to the counter with one of three
26/// values: by default, the value 1, unless a distribution is provided or the
27/// copy flag is set to true. If the copy flag is set to true, the counter will
28/// be updated with the value of the other counter *prior to transitioning to
29/// the state*. If a distribution is provided, the counter will be updated with
30/// a value sampled from the distribution.
31#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
32pub struct Counter {
33    /// The operation to apply to the counter upon a state transition. If the
34    /// distribution is not set and copy is false, the counter will be updated
35    /// by 1.
36    pub operation: Operation,
37    /// If set, sample the value to update the counter with from a
38    /// distribution.
39    pub dist: Option<Dist>,
40    /// If set, the counter will be updated by the other counter's value *prior
41    /// to transitioning to the state*. Supersedes the `dist` field.
42    pub copy: bool,
43}
44
45impl fmt::Display for Counter {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        write!(f, "{self:#?}")
48    }
49}
50
51impl Counter {
52    /// Create a new counter with an operation that modifies the counter with
53    /// value 1.
54    pub fn new(operation: Operation) -> Self {
55        Counter {
56            operation,
57            dist: None,
58            copy: false,
59        }
60    }
61
62    /// Create a new counter with an operation and a distribution to sample the
63    /// value from.
64    pub fn new_dist(operation: Operation, dist: Dist) -> Self {
65        Counter {
66            operation,
67            dist: Some(dist),
68            copy: false,
69        }
70    }
71
72    /// Create a new counter with an operation that copies the value of the
73    /// other counter *prior to transitioning to the state*.
74    pub fn new_copy(operation: Operation) -> Self {
75        Counter {
76            operation,
77            dist: None,
78            copy: true,
79        }
80    }
81
82    /// Sample a value to update the counter with.
83    pub fn sample_value<R: RngCore>(&self, rng: &mut R) -> u64 {
84        // Maximum safe f64 value that can be converted to u64 without overflow
85        const MAX_SAFE_F64_TO_U64: f64 = u64::MAX as f64;
86
87        match self.dist {
88            None => 1,
89            Some(dist) => {
90                let sampled = dist.sample(rng);
91                if !sampled.is_finite() || sampled < 0.0 {
92                    0
93                } else if sampled > MAX_SAFE_F64_TO_U64 {
94                    u64::MAX
95                } else {
96                    sampled as u64
97                }
98            }
99        }
100    }
101
102    // Validate the value dist.
103    pub fn validate(&self) -> Result<(), Error> {
104        if let Some(dist) = self.dist {
105            dist.validate()?;
106        }
107        Ok(())
108    }
109}
110
#[cfg(test)]
mod tests {
    use crate::{counter::*, dist::DistType};

    /// Wraps `kind` in a [`Dist`] with no offset (`start`) and no cap (`max`).
    fn bare(kind: DistType) -> Dist {
        Dist {
            dist: kind,
            start: 0.0,
            max: 0.0,
        }
    }

    /// An incrementing counter whose update amount is sampled from `kind`.
    fn incrementing(kind: DistType) -> Counter {
        Counter::new_dist(Operation::Increment, bare(kind))
    }

    #[test]
    fn validate_counter_update() {
        // A degenerate uniform range (low == high) is valid.
        let mut counter = incrementing(DistType::Uniform {
            low: 10.0,
            high: 10.0,
        });
        assert!(counter.validate().is_ok());

        // An inverted uniform range (low > high) is rejected.
        counter.dist = Some(bare(DistType::Uniform {
            low: 15.0,
            high: 5.0,
        }));
        assert!(counter.validate().is_err());

        // Without a distribution the counter is valid and samples 1.
        counter.dist = None;
        assert!(counter.validate().is_ok());
        assert_eq!(counter.sample_value(&mut rand::rng()), 1);

        // Enabling copy keeps the counter valid.
        counter.copy = true;
        assert!(counter.validate().is_ok());
    }

    #[test]
    fn sample_value_overflow_protection() {
        // Samples far beyond u64::MAX saturate to u64::MAX.
        let huge = incrementing(DistType::Uniform {
            low: f64::MAX,
            high: f64::MAX,
        });
        assert_eq!(huge.sample_value(&mut rand::rng()), u64::MAX);

        // Negative samples clamp to 0.
        let negative = incrementing(DistType::Uniform {
            low: -1000.0,
            high: -500.0,
        });
        assert_eq!(negative.sample_value(&mut rand::rng()), 0);

        // A NaN mean (invalid Normal dist case) is expected to yield a NaN
        // sample, which maps to 0.
        let nan = incrementing(DistType::Normal {
            mean: f64::NAN,
            stdev: 1.0,
        });
        assert_eq!(nan.sample_value(&mut rand::rng()), 0);
    }
}