maybenot/counter.rs
1//! Counters as part of a [`Machine`](crate::Machine).
2
3use rand_core::RngCore;
4use serde::{Deserialize, Serialize};
5
6use crate::{Error, dist};
7use std::fmt;
8
9use self::dist::Dist;
10
11/// The operation applied to one of a [`Machine`](crate::Machine)'s counters
12/// upon transition to a [`State`](crate::state::State).
13#[derive(Debug, Eq, Hash, PartialEq, Clone, Copy, Serialize, Deserialize)]
14pub enum Operation {
15 /// Increment the counter.
16 Increment,
17 /// Decrement the counter.
18 Decrement,
19 /// Replace the current value of the counter.
20 Set,
21}
22
23/// A specification of how one of a [`Machine`](crate::Machine)'s counters
24/// should be updated when transitioning to a [`State`](crate::state::State).
25/// Consists of an [`Operation`] to be applied to the counter with one of three
26/// values: by default, the value 1, unless a distribution is provided or the
27/// copy flag is set to true. If the copy flag is set to true, the counter will
28/// be updated with the value of the other counter *prior to transitioning to
29/// the state*. If a distribution is provided, the counter will be updated with
30/// a value sampled from the distribution.
31#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
32pub struct Counter {
33 /// The operation to apply to the counter upon a state transition. If the
34 /// distribution is not set and copy is false, the counter will be updated
35 /// by 1.
36 pub operation: Operation,
37 /// If set, sample the value to update the counter with from a
38 /// distribution.
39 pub dist: Option<Dist>,
40 /// If set, the counter will be updated by the other counter's value *prior
41 /// to transitioning to the state*. Supersedes the `dist` field.
42 pub copy: bool,
43}
44
45impl fmt::Display for Counter {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 write!(f, "{self:#?}")
48 }
49}
50
51impl Counter {
52 /// Create a new counter with an operation that modifies the counter with
53 /// value 1.
54 pub fn new(operation: Operation) -> Self {
55 Counter {
56 operation,
57 dist: None,
58 copy: false,
59 }
60 }
61
62 /// Create a new counter with an operation and a distribution to sample the
63 /// value from.
64 pub fn new_dist(operation: Operation, dist: Dist) -> Self {
65 Counter {
66 operation,
67 dist: Some(dist),
68 copy: false,
69 }
70 }
71
72 /// Create a new counter with an operation that copies the value of the
73 /// other counter *prior to transitioning to the state*.
74 pub fn new_copy(operation: Operation) -> Self {
75 Counter {
76 operation,
77 dist: None,
78 copy: true,
79 }
80 }
81
82 /// Sample a value to update the counter with.
83 pub fn sample_value<R: RngCore>(&self, rng: &mut R) -> u64 {
84 // Maximum safe f64 value that can be converted to u64 without overflow
85 const MAX_SAFE_F64_TO_U64: f64 = u64::MAX as f64;
86
87 match self.dist {
88 None => 1,
89 Some(dist) => {
90 let sampled = dist.sample(rng);
91 if !sampled.is_finite() || sampled < 0.0 {
92 0
93 } else if sampled > MAX_SAFE_F64_TO_U64 {
94 u64::MAX
95 } else {
96 sampled as u64
97 }
98 }
99 }
100 }
101
102 // Validate the value dist.
103 pub fn validate(&self) -> Result<(), Error> {
104 if let Some(dist) = self.dist {
105 dist.validate()?;
106 }
107 Ok(())
108 }
109}
110
#[cfg(test)]
mod tests {
    use crate::{counter::*, dist::DistType};

    /// Wraps `dist` in a [`Dist`] with no start offset and no max cap.
    fn plain(dist: DistType) -> Dist {
        Dist {
            dist,
            start: 0.0,
            max: 0.0,
        }
    }

    #[test]
    fn validate_counter_update() {
        // A degenerate (low == high) uniform distribution is valid.
        let mut counter = Counter::new_dist(
            Operation::Increment,
            plain(DistType::Uniform {
                low: 10.0,
                high: 10.0,
            }),
        );
        assert!(counter.validate().is_ok());

        // NOTE low > high must be rejected.
        counter.dist = Some(plain(DistType::Uniform {
            low: 15.0,
            high: 5.0,
        }));
        assert!(counter.validate().is_err());

        // Without a distribution the counter is valid and samples 1.
        counter.dist = None;
        assert!(counter.validate().is_ok());
        assert_eq!(counter.sample_value(&mut rand::rng()), 1);

        // Setting the copy flag keeps the counter valid.
        counter.copy = true;
        assert!(counter.validate().is_ok());
    }

    #[test]
    fn sample_value_overflow_protection() {
        let mut rng = rand::rng();

        // (distribution, expected sampled counter value)
        let cases = [
            // Very large samples saturate at u64::MAX.
            (
                DistType::Uniform {
                    low: f64::MAX,
                    high: f64::MAX,
                },
                u64::MAX,
            ),
            // Negative samples clamp to 0.
            (
                DistType::Uniform {
                    low: -1000.0,
                    high: -500.0,
                },
                0,
            ),
            // A NaN mean produces NaN samples, which map to 0.
            (
                DistType::Normal {
                    mean: f64::NAN,
                    stdev: 1.0,
                },
                0,
            ),
        ];

        for (dist, expected) in cases {
            let counter = Counter::new_dist(Operation::Increment, plain(dist));
            assert_eq!(counter.sample_value(&mut rng), expected);
        }
    }
}