1use crate::constants::{EVENT_NUM, STATE_END, STATE_SIGNAL};
6use crate::{Error, action, counter, event};
7use enum_map::Enum;
8use enum_map::EnumMap;
9use rand::RngCore;
10use serde::Deserialize;
11use serde::Serialize;
12use std::collections::HashSet;
13use std::fmt;
14
15use self::action::Action;
16use self::counter::Counter;
17use self::event::Event;
18
19use enum_map::enum_map;
20
/// A single transition: `.0` is the index of the target state and `.1` is the
/// probability of taking it.
///
/// `State::validate` accepts a target index below the machine's state count or
/// one of the special indices `STATE_END` / `STATE_SIGNAL`, and a probability
/// in `(0.0, 1.0]`.
#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Trans(pub usize, pub f32);
24
25impl fmt::Display for Trans {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 if self.1 == 1.0 {
28 write!(f, "{}", self.0)
29 } else {
30 write!(f, "{} ({})", self.0, self.1)
31 }
32 }
33}
34
/// A state of a machine: an optional action, up to two optional counters, and
/// for each event an optional vector of probabilistic transitions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct State {
    /// Optional action of this state; checked by `State::validate` when set.
    pub action: Option<Action>,
    /// Optional counters, shown as "counter A" (`.0`) and "counter B" (`.1`)
    /// by the `Display` impl; each is checked by `State::validate` when set.
    pub counter: (Option<Counter>, Option<Counter>),
    /// Transition vectors indexed by `Event::to_usize()`. `None` means the
    /// event has no transitions; `State::new` never stores an empty vector.
    transitions: [Option<Vec<Trans>>; EVENT_NUM],
}
45
46impl State {
47 pub fn new(t: EnumMap<Event, Vec<Trans>>) -> Self {
66 const ARRAY_NO_TRANS: Option<Vec<Trans>> = None;
67 let mut transitions = [ARRAY_NO_TRANS; EVENT_NUM];
68 for (event, vector) in t {
69 if !vector.is_empty() {
70 transitions[event.to_usize()] = Some(vector);
71 }
72 }
73
74 State {
75 transitions,
76 action: None,
77 counter: (None, None),
78 }
79 }
80
81 pub fn validate(&self, num_states: usize) -> Result<(), Error> {
86 for (event, transitions) in self.transitions.iter().enumerate() {
88 let Some(transitions) = transitions else {
89 continue;
90 };
91 if self.transitions.is_empty() {
92 Err(Error::Machine(format!(
93 "found empty transition vector for {}",
94 &event
95 )))?;
96 }
97
98 let mut sum: f32 = 0.0;
99 let mut seen: HashSet<usize> = HashSet::new();
100
101 for t in transitions.iter() {
102 if t.0 >= num_states && t.0 != STATE_END && t.0 != STATE_SIGNAL {
103 Err(Error::Machine(format!(
104 "found out-of-bounds state index {}",
105 t.0
106 )))?;
107 }
108 if seen.contains(&t.0) {
109 Err(Error::Machine(format!(
110 "found duplicate state index {}",
111 t.0
112 )))?;
113 }
114 seen.insert(t.0);
115
116 if t.1 <= 0.0 || t.1 > 1.0 {
117 Err(Error::Machine(format!(
118 "found probability {}, has to be (0.0, 1.0]",
119 t.1
120 )))?;
121 }
122 sum += t.1;
123 }
124
125 if sum <= 0.0 || sum > 1.0 {
126 Err(Error::Machine(format!(
127 "found invalid total probability vector {} for {}, must be (0.0, 1.0]",
128 &sum, &event
129 )))?;
130 }
131 }
132
133 if let Some(action) = &self.action {
136 action.validate()?;
137 }
138 if let Some(counter) = &self.counter.0 {
139 counter.validate()?;
140 }
141 if let Some(counter) = &self.counter.1 {
142 counter.validate()?;
143 }
144
145 Ok(())
146 }
147
148 pub fn sample_state<R: RngCore>(&self, event: Event, rng: &mut R) -> Option<usize> {
150 use rand::Rng;
151 if let Some(vector) = &self.transitions[event.to_usize()] {
152 let mut sum = 0.0;
153 let r = rng.random_range(0.0..1.0);
154 for t in vector.iter() {
155 sum += t.1;
156 if r < sum {
157 return Some(t.0);
158 }
159 }
160 }
161 None
162 }
163
164 pub fn get_transitions(&self) -> EnumMap<Event, Vec<Trans>> {
167 let mut map = enum_map! {_ => vec![]};
168 for (event, vector) in self.transitions.iter().enumerate() {
169 if let Some(vector) = vector {
170 if vector.is_empty() {
171 continue;
172 }
173 map[Event::from_usize(event)].clone_from(vector);
174 }
175 }
176
177 map
178 }
179}
180
181impl fmt::Display for State {
182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 if let Some(action) = self.action {
184 writeln!(f, "action: {action}")?;
185 } else {
186 writeln!(f, "action: None")?;
187 }
188 match self.counter {
189 (Some(counter_a), Some(counter_b)) => {
190 writeln!(f, "counter A: {counter_a}")?;
191 writeln!(f, "counter B: {counter_b}")?;
192 }
193 (Some(counter), None) => {
194 writeln!(f, "counter A: {counter}")?;
195 }
196 (None, Some(counter)) => {
197 writeln!(f, "counter B: {counter}")?;
198 }
199 _ => {
200 writeln!(f, "counter: None")?;
201 }
202 }
203
204 writeln!(f, "transitions: ")?;
205 for event in Event::iter() {
206 if let Some(vector) = &self.transitions[event.to_usize()] {
207 if vector.is_empty() {
208 continue;
209 }
210 write!(f, "\t{event}:")?;
211 for trans in vector {
212 write!(f, " {trans}")?;
213 if trans != vector.last().unwrap() {
214 write!(f, ",")?;
215 }
216 }
217 writeln!(f)?;
218 }
219 }
220
221 Ok(())
222 }
223}
224
#[cfg(test)]
mod tests {
    use crate::counter::{Counter, Operation};
    use crate::dist::{Dist, DistType};
    use crate::event::Event;
    use crate::state::*;
    use enum_map::enum_map;

    /// Builds a state whose only transitions are `trans` on
    /// `Event::PaddingSent`; every other event gets an empty vector.
    fn padding_sent_state(trans: Vec<Trans>) -> State {
        State::new(enum_map! {
            Event::PaddingSent => trans.clone(),
            _ => vec![],
        })
    }

    /// Validates `state`, prints the error (if any) so failing assertions are
    /// easier to diagnose, and reports whether validation passed.
    fn validates(state: &State, num_states: usize) -> bool {
        let result = state.validate(num_states);
        println!("{:?}", result.as_ref().err());
        result.is_ok()
    }

    #[test]
    fn serialization() {
        let original = padding_sent_state(vec![Trans(6, 1.0)]);

        // Round-trip through bincode. The sole transition has probability 1.0,
        // so sampling must always land on state 6.
        let bytes = bincode::serialize(&original).unwrap();
        let restored: State = bincode::deserialize(&bytes).unwrap();

        assert_eq!(
            restored.sample_state(Event::PaddingSent, &mut rand::rng()),
            Some(6)
        );
    }

    #[test]
    fn validate_state_transitions() {
        let num_states = 2;

        // Every one of these transition vectors must be rejected.
        let rejected = [
            // target index out of bounds
            vec![Trans(num_states, 1.0)],
            // single probability above 1.0
            vec![Trans(0, 1.1)],
            // probabilities summing to more than 1.0
            vec![Trans(0, 0.5), Trans(1, 0.6)],
            // duplicate target index
            vec![Trans(0, 0.4), Trans(0, 0.6)],
        ];
        for trans in rejected {
            assert!(!validates(&padding_sent_state(trans), num_states));
        }

        // A total below 1.0 that mixes a regular state with STATE_END is fine.
        let accepted = padding_sent_state(vec![Trans(0, 0.4), Trans(STATE_END, 0.3)]);
        assert!(accepted.validate(num_states).is_ok());
    }

    #[test]
    fn validate_state_action() {
        let num_states = 1;
        let mut s = padding_sent_state(vec![Trans(0, 1.0)]);

        // A uniform timeout with low == high is accepted.
        s.action = Some(Action::SendPadding {
            bypass: false,
            replace: false,
            timeout: Dist {
                dist: DistType::Uniform {
                    low: 10.0,
                    high: 10.0,
                },
                start: 0.0,
                max: 0.0,
            },
            limit: None,
        });
        assert!(validates(&s, num_states));

        // An inverted uniform range (low > high) makes validation fail.
        s.action = Some(Action::SendPadding {
            bypass: false,
            replace: false,
            timeout: Dist {
                dist: DistType::Uniform {
                    low: 2.0,
                    high: 1.0,
                },
                start: 0.0,
                max: 0.0,
            },
            limit: None,
        });
        assert!(!validates(&s, num_states));
    }

    #[test]
    fn validate_state_counter() {
        let num_states = 1;
        let mut s = padding_sent_state(vec![Trans(0, 1.0)]);

        // A plain increment counter in slot A is accepted.
        s.counter = (Some(Counter::new(Operation::Increment)), None);
        assert!(validates(&s, num_states));

        // A counter in slot B with an inverted uniform range (low > high)
        // makes validation fail.
        s.counter = (
            None,
            Some(Counter::new_dist(
                Operation::Increment,
                Dist {
                    dist: DistType::Uniform {
                        low: 2.0,
                        high: 1.0,
                    },
                    start: 0.0,
                    max: 0.0,
                },
            )),
        );
        assert!(!validates(&s, num_states));
    }
}
384}