Skip to main content

maybenot/
state.rs

//! A state as part of a [`Machine`](crate::Machine). Contains an optional
//! [`Action`] and up to two [`Counter`] updates (for the machine's counters A
//! and B) to be executed upon transition to this state, and a vector of state
//! transitions for each possible [`Event`].
4
5use crate::constants::{EVENT_NUM, STATE_END, STATE_SIGNAL};
6use crate::{Error, action, counter, event};
7use enum_map::Enum;
8use enum_map::EnumMap;
9use rand::RngCore;
10use serde::Deserialize;
11use serde::Serialize;
12use std::collections::HashSet;
13use std::fmt;
14
15use self::action::Action;
16use self::counter::Counter;
17use self::event::Event;
18
19use enum_map::enum_map;
20
/// A state index and probability for a transition.
///
/// The first field is the index of the state to transition to: a 0-based
/// index into the machine's states (in the order they were added), or one of
/// the special indexes [`STATE_END`] and [`STATE_SIGNAL`]. The second field
/// is the probability of taking this transition; [`State::validate`] requires
/// it to be in (0.0, 1.0].
#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Trans(pub usize, pub f32);
24
25impl fmt::Display for Trans {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        if self.1 == 1.0 {
28            write!(f, "{}", self.0)
29        } else {
30            write!(f, "{} ({})", self.0, self.1)
31        }
32    }
33}
34
/// A state as part of a [`Machine`](crate::Machine).
///
/// Create one with [`State::new`], which sets only the transitions; the
/// public `action` and `counter` fields start as `None` and may be assigned
/// directly afterwards.
// NOTE(review): serde formats that encode struct fields positionally (such as
// bincode, used in this module's tests) make field order part of the
// serialized layout — presumably fields should not be reordered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct State {
    /// Take an action upon transitioning to this state.
    pub action: Option<Action>,
    /// On transition to this state, update the machine's two counters (A,B).
    pub counter: (Option<Counter>, Option<Counter>),
    /// For each possible [`Event`], a vector of state transitions.
    ///
    /// Indexed by the event's `to_usize()` value; `None` means the event has
    /// no transitions. [`State::new`] never stores `Some` of an empty vector,
    /// but the derived `Deserialize` performs no such check.
    transitions: [Option<Vec<Trans>>; EVENT_NUM],
}
45
46impl State {
47    /// Create a new [`State`] that transitions on the given [`Event`]s.
48    ///
49    /// Example:
50    /// ```
51    /// use maybenot::state::*;
52    /// use maybenot::event::*;
53    /// use enum_map::enum_map;
54    /// let state = State::new(enum_map! {
55    ///     Event::PaddingSent => vec![Trans(1, 1.0)],
56    ///     Event::CounterZero => vec![Trans(2, 1.0)],
57    ///     _ => vec![],
58    /// });
59    /// ```
60    /// This creates a state that transitions to state 1 on
61    /// [`Event::PaddingSent`] and to state 2 on [`Event::CounterZero`], both
62    /// with 100% probability. All other events will not cause a transition.
63    /// Note that state indexes are 0-based and determined by the order in which
64    /// states are added to the [`Machine`](crate::Machine).
65    pub fn new(t: EnumMap<Event, Vec<Trans>>) -> Self {
66        const ARRAY_NO_TRANS: Option<Vec<Trans>> = None;
67        let mut transitions = [ARRAY_NO_TRANS; EVENT_NUM];
68        for (event, vector) in t {
69            if !vector.is_empty() {
70                transitions[event.to_usize()] = Some(vector);
71            }
72        }
73
74        State {
75            transitions,
76            action: None,
77            counter: (None, None),
78        }
79    }
80
81    /// Validate that this state has acceptable transitions and that the
82    /// distributions, if set, are valid. Note that num_states is the number of
83    /// states in the machine, not the number of states in this state's
84    /// transitions. Called by [`Machine::new`](crate::machine::Machine::new).
85    pub fn validate(&self, num_states: usize) -> Result<(), Error> {
86        // validate transition probabilities
87        for (event, transitions) in self.transitions.iter().enumerate() {
88            let Some(transitions) = transitions else {
89                continue;
90            };
91            if self.transitions.is_empty() {
92                Err(Error::Machine(format!(
93                    "found empty transition vector for {}",
94                    &event
95                )))?;
96            }
97
98            let mut sum: f32 = 0.0;
99            let mut seen: HashSet<usize> = HashSet::new();
100
101            for t in transitions.iter() {
102                if t.0 >= num_states && t.0 != STATE_END && t.0 != STATE_SIGNAL {
103                    Err(Error::Machine(format!(
104                        "found out-of-bounds state index {}",
105                        t.0
106                    )))?;
107                }
108                if seen.contains(&t.0) {
109                    Err(Error::Machine(format!(
110                        "found duplicate state index {}",
111                        t.0
112                    )))?;
113                }
114                seen.insert(t.0);
115
116                if t.1 <= 0.0 || t.1 > 1.0 {
117                    Err(Error::Machine(format!(
118                        "found probability {}, has to be (0.0, 1.0]",
119                        t.1
120                    )))?;
121                }
122                sum += t.1;
123            }
124
125            if sum <= 0.0 || sum > 1.0 {
126                Err(Error::Machine(format!(
127                    "found invalid total probability vector {} for {}, must be (0.0, 1.0]",
128                    &sum, &event
129                )))?;
130            }
131        }
132
133        // validate distribution parameters
134        // check that required distributions are present
135        if let Some(action) = &self.action {
136            action.validate()?;
137        }
138        if let Some(counter) = &self.counter.0 {
139            counter.validate()?;
140        }
141        if let Some(counter) = &self.counter.1 {
142            counter.validate()?;
143        }
144
145        Ok(())
146    }
147
148    /// Sample a state to transition to given an [`Event`].
149    pub fn sample_state<R: RngCore>(&self, event: Event, rng: &mut R) -> Option<usize> {
150        use rand::Rng;
151        if let Some(vector) = &self.transitions[event.to_usize()] {
152            let mut sum = 0.0;
153            let r = rng.random_range(0.0..1.0);
154            for t in vector.iter() {
155                sum += t.1;
156                if r < sum {
157                    return Some(t.0);
158                }
159            }
160        }
161        None
162    }
163
164    /// Get the transitions for this state as an [`EnumMap`] of [`Event`] to
165    /// vectors of [`Trans`].
166    pub fn get_transitions(&self) -> EnumMap<Event, Vec<Trans>> {
167        let mut map = enum_map! {_ => vec![]};
168        for (event, vector) in self.transitions.iter().enumerate() {
169            if let Some(vector) = vector {
170                if vector.is_empty() {
171                    continue;
172                }
173                map[Event::from_usize(event)].clone_from(vector);
174            }
175        }
176
177        map
178    }
179}
180
181impl fmt::Display for State {
182    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183        if let Some(action) = self.action {
184            writeln!(f, "action: {action}")?;
185        } else {
186            writeln!(f, "action: None")?;
187        }
188        match self.counter {
189            (Some(counter_a), Some(counter_b)) => {
190                writeln!(f, "counter A: {counter_a}")?;
191                writeln!(f, "counter B: {counter_b}")?;
192            }
193            (Some(counter), None) => {
194                writeln!(f, "counter A: {counter}")?;
195            }
196            (None, Some(counter)) => {
197                writeln!(f, "counter B: {counter}")?;
198            }
199            _ => {
200                writeln!(f, "counter: None")?;
201            }
202        }
203
204        writeln!(f, "transitions: ")?;
205        for event in Event::iter() {
206            if let Some(vector) = &self.transitions[event.to_usize()] {
207                if vector.is_empty() {
208                    continue;
209                }
210                write!(f, "\t{event}:")?;
211                for trans in vector {
212                    write!(f, " {trans}")?;
213                    if trans != vector.last().unwrap() {
214                        write!(f, ",")?;
215                    }
216                }
217                writeln!(f)?;
218            }
219        }
220
221        Ok(())
222    }
223}
224
#[cfg(test)]
mod tests {
    use crate::counter::{Counter, Operation};
    use crate::dist::{Dist, DistType};
    use crate::event::Event;
    use crate::state::*;
    use enum_map::enum_map;

    /// Builds a state whose only transitions are the given ones on
    /// [`Event::PaddingSent`]; every other event has none.
    fn padding_sent_state(transitions: Vec<Trans>) -> State {
        State::new(enum_map! {
            Event::PaddingSent => transitions.clone(),
            _ => vec![],
        })
    }

    /// Wraps a distribution kind with zero `start` and `max`.
    fn make_dist(kind: DistType) -> Dist {
        Dist {
            dist: kind,
            start: 0.0,
            max: 0.0,
        }
    }

    /// A non-bypassing, non-replacing padding action with no limit.
    fn send_padding(timeout: Dist) -> Action {
        Action::SendPadding {
            bypass: false,
            replace: false,
            timeout,
            limit: None,
        }
    }

    /// Asserts that validation succeeds, printing any error first.
    #[track_caller]
    fn assert_valid(state: &State, num_states: usize) {
        let result = state.validate(num_states);
        println!("{:?}", result.as_ref().err());
        assert!(result.is_ok());
    }

    /// Asserts that validation fails, printing the error first.
    #[track_caller]
    fn assert_invalid(state: &State, num_states: usize) {
        let result = state.validate(num_states);
        println!("{:?}", result.as_ref().err());
        assert!(result.is_err());
    }

    #[test]
    fn serialization() {
        // Ensure that sampling works after deserialization
        let original = padding_sent_state(vec![Trans(6, 1.0)]);

        let bytes = bincode::serialize(&original).unwrap();
        let restored: State = bincode::deserialize(&bytes).unwrap();

        let mut rng = rand::rng();
        assert_eq!(restored.sample_state(Event::PaddingSent, &mut rng), Some(6));
    }

    #[test]
    fn validate_state_transitions() {
        // assume a machine with two states
        let num_states = 2;

        // every one of these transition vectors must be rejected
        let rejected = [
            // out of bounds index
            vec![Trans(num_states, 1.0)],
            // one probability too high
            vec![Trans(0, 1.1)],
            // total probability too high
            vec![Trans(0, 0.5), Trans(1, 0.6)],
            // duplicate transitions
            vec![Trans(0, 0.4), Trans(0, 0.6)],
        ];
        for transitions in rejected {
            assert_invalid(&padding_sent_state(transitions), num_states);
        }

        // valid transitions should be allowed
        let accepted = padding_sent_state(vec![Trans(0, 0.4), Trans(STATE_END, 0.3)]);
        assert!(accepted.validate(num_states).is_ok());
    }

    #[test]
    fn validate_state_action() {
        // assume a machine with one state
        let num_states = 1;
        let mut s = padding_sent_state(vec![Trans(0, 1.0)]);

        // valid actions should be allowed
        s.action = Some(send_padding(make_dist(DistType::Uniform {
            low: 10.0,
            high: 10.0,
        })));
        assert_valid(&s, num_states);

        // invalid action in state
        s.action = Some(send_padding(make_dist(DistType::Uniform {
            low: 2.0, // NOTE low > high
            high: 1.0,
        })));
        assert_invalid(&s, num_states);
    }

    #[test]
    fn validate_state_counter() {
        // assume a machine with one state
        let num_states = 1;
        let mut s = padding_sent_state(vec![Trans(0, 1.0)]);

        // valid counter updates should be allowed
        s.counter = (Some(Counter::new(Operation::Increment)), None);
        assert_valid(&s, num_states);

        // invalid counter update in state
        let bad_dist = make_dist(DistType::Uniform {
            low: 2.0, // NOTE low > high
            high: 1.0,
        });
        s.counter = (
            None,
            Some(Counter::new_dist(Operation::Increment, bad_dist)),
        );
        assert_invalid(&s, num_states);
    }
}