Skip to main content

maybenot/
machine.rs

1//! A machine determines when to inject and/or block outgoing traffic. Consists
2//! of one or more [`State`] structs.
3
4use crate::constants::{MAX_DECOMPRESSED_SIZE, STATE_MAX, VERSION};
5use crate::{Error, state};
6use base64::prelude::*;
7use bincode::Options;
8use flate2::Compression;
9use flate2::read::ZlibDecoder;
10use flate2::write::ZlibEncoder;
11use serde::{Deserialize, Serialize};
12use sha256::digest;
13use std::fmt;
14use std::io::prelude::*;
15use std::str::FromStr;
16
17use self::state::State;
18
19/// A probabilistic state machine (Rabin automaton) consisting of one or more
20/// [`State`] that determine when to inject and/or block outgoing traffic.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Machine {
23    /// The number of padding packets the machine is allowed to generate as
24    /// actions before other limits apply.
25    pub allowed_padding_packets: u64,
26    /// The maximum fraction of padding packets to allow as actions.
27    pub max_padding_frac: f64,
28    /// The number of microseconds of blocking a machine is allowed to generate
29    /// as actions before other limits apply.
30    pub allowed_blocked_microsec: u64,
31    /// The maximum fraction of blocking (microseconds) to allow as actions.
32    pub max_blocking_frac: f64,
33    /// The states that make up the machine.
34    pub states: Vec<State>,
35}
36
37impl Machine {
38    /// Create a new [`Machine`] with the given limits and states. Returns an
39    /// error if the machine or any of its states are invalid.
40    pub fn new(
41        allowed_padding_packets: u64,
42        max_padding_frac: f64,
43        allowed_blocked_microsec: u64,
44        max_blocking_frac: f64,
45        states: Vec<State>,
46    ) -> Result<Self, Error> {
47        let machine = Machine {
48            allowed_padding_packets,
49            max_padding_frac,
50            allowed_blocked_microsec,
51            max_blocking_frac,
52            states,
53        };
54        machine.validate()?;
55
56        Ok(machine)
57    }
58
59    /// Get a unique and deterministic string that represents the machine. The
60    /// string is 32 characters long, hex-encoded.
61    pub fn name(&self) -> String {
62        let s = digest(self.serialize());
63        s[0..32].to_string()
64    }
65
66    pub fn serialize(&self) -> String {
67        let bincoder = bincode::DefaultOptions::new().with_limit(MAX_DECOMPRESSED_SIZE as u64);
68        let encoded = bincoder.serialize(&self).unwrap();
69        let mut e = ZlibEncoder::new(Vec::new(), Compression::best());
70        e.write_all(encoded.as_slice()).unwrap();
71        let s = BASE64_STANDARD.encode(e.finish().unwrap());
72        // version as first 2 characters, then base64 compressed bincoded
73        format!("{VERSION:02}{s}")
74    }
75
76    /// Validates that the machine is in a valid state (machines that are
77    /// mutated may get into an invalid state).
78    pub fn validate(&self) -> Result<(), Error> {
79        // sane limits
80        if self.max_padding_frac < 0.0 || self.max_padding_frac > 1.0 {
81            return Err(Error::Machine(format!(
82                "max_padding_frac has to be [0.0, 1.0], got {}",
83                self.max_padding_frac
84            )));
85        }
86        if self.max_blocking_frac < 0.0 || self.max_blocking_frac > 1.0 {
87            return Err(Error::Machine(format!(
88                "max_blocking_frac has to be [0.0, 1.0], got {}",
89                self.max_blocking_frac
90            )));
91        }
92
93        // sane number of states
94        let num_states = self.states.len();
95
96        if num_states == 0 {
97            Err(Error::Machine(
98                "a machine must have at least one state".to_string(),
99            ))?;
100        }
101        if num_states > STATE_MAX {
102            Err(Error::Machine(format!(
103                "too many states, max is {STATE_MAX}, found {num_states}"
104            )))?;
105        }
106
107        // validate all states
108        for state in self.states.iter() {
109            state
110                .validate(num_states)
111                .map_err(|e| Error::Machine(e.to_string()))?;
112        }
113
114        Ok(())
115    }
116}
117
118/// From a serialized string, attempt to create a machine.
119impl FromStr for Machine {
120    type Err = Error;
121
122    fn from_str(s: &str) -> Result<Self, Self::Err> {
123        // version as first 2 bytes (len() checks bytes)
124        if s.len() < 3 {
125            Err(Error::Machine("string too short".to_string()))?;
126        }
127        if !s.is_ascii() {
128            Err(Error::Machine("string is not ascii".to_string()))?;
129        }
130        let version = &s[0..2];
131        if version != format!("{VERSION:02}") {
132            Err(Error::Machine(format!(
133                "version mismatch, expected {VERSION}, got {version}"
134            )))?;
135        }
136        let s = &s[2..];
137
138        // base64 decoding has a fixed ratio of ~4:3
139        let compressed = BASE64_STANDARD.decode(s.as_bytes());
140        if compressed.is_err() {
141            Err(Error::Machine("base64 decoding failed".to_string()))?;
142        }
143        let compressed = compressed.unwrap();
144        // decompress, but scared of exceeding memory limits / zlib bombs
145        let mut decoder = ZlibDecoder::new(compressed.as_slice());
146        let mut buf = vec![0; MAX_DECOMPRESSED_SIZE];
147        let bytes_read = decoder
148            .read(&mut buf)
149            .map_err(|e| Error::Machine(e.to_string()))?;
150
151        // With binencode, note that "The size of the encoded object will be the
152        // same or smaller than the size that the object takes up in memory in a
153        // running Rust program".
154        let bincoder = bincode::DefaultOptions::new().with_limit(MAX_DECOMPRESSED_SIZE as u64);
155        let r = bincoder.deserialize(&buf[..bytes_read]);
156
157        // ensure that the machine is valid
158        let m: Machine = r.map_err(|e| Error::Machine(e.to_string()))?;
159        m.validate()?;
160        Ok(m)
161    }
162}
163
164impl fmt::Display for Machine {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        write!(
167            f,
168            "Machine {}\n\
169            - allowed_padding_packets: {}\n\
170            - max_padding_frac: {}\n\
171            - allowed_blocked_microsec: {}\n\
172            - max_blocking_frac: {}\n\
173            States:\n\
174            {}",
175            self.name(),
176            self.allowed_padding_packets,
177            self.max_padding_frac,
178            self.allowed_blocked_microsec,
179            self.max_blocking_frac,
180            self.states
181                .iter()
182                .map(|s| format!("{s}"))
183                .collect::<Vec<String>>()
184                .join("\n")
185        )
186    }
187}
188
#[cfg(test)]
mod tests {
    use crate::event::Event;
    use crate::machine::*;
    use crate::state::Trans;
    use enum_map::enum_map;

    /// Asking the same machine for its name twice must give the same answer.
    #[test]
    fn machine_name_generation() {
        // single state that transitions to itself on padding
        let looping = State::new(enum_map! {
            Event::PaddingSent => vec![Trans(0, 1.0)],
            _ => vec![],
        });

        // machine
        let machine = Machine::new(1000, 1.0, 0, 0.0, vec![looping]).unwrap();

        // name generation should be deterministic
        assert_eq!(machine.name(), machine.name());
    }

    /// Fraction limits outside `[0.0, 1.0]` are rejected; values inside pass.
    #[test]
    fn validate_machine_limits() {
        let looping = State::new(enum_map! {
            Event::PaddingSent => vec![Trans(0, 1.0)],
            _ => vec![],
        });

        let mut machine = Machine::new(1000, 1.0, 0, 0.0, vec![looping]).unwrap();

        // max padding frac: below and above the allowed range
        for out_of_range in [-0.1, 1.1] {
            machine.max_padding_frac = out_of_range;
            let outcome = machine.validate();
            println!("{:?}", outcome.as_ref().err());
            assert!(outcome.is_err());
        }

        machine.max_padding_frac = 0.5;
        assert!(machine.validate().is_ok());

        // max blocking frac: below and above the allowed range
        for out_of_range in [-0.1, 1.1] {
            machine.max_blocking_frac = out_of_range;
            let outcome = machine.validate();
            println!("{:?}", outcome.as_ref().err());
            assert!(outcome.is_err());
        }

        machine.max_blocking_frac = 0.5;
        assert!(machine.validate().is_ok());
    }

    /// A machine without any state cannot be constructed.
    #[test]
    fn validate_machine_num_states() {
        // invalid machine lacking state
        let outcome = Machine::new(1000, 1.0, 0, 0.0, vec![]);

        println!("{:?}", outcome.as_ref().err());
        assert!(outcome.is_err());
    }

    /// State-level validation errors surface through `Machine::new`.
    #[test]
    fn validate_machine_states() {
        // out of bounds index: the only state is 0, but the transition targets 1
        let dangling = State::new(enum_map! {
            Event::PaddingSent => vec![Trans(1, 1.0)],
            _ => vec![],
        });
        // machine with broken state
        let outcome = Machine::new(1000, 1.0, 0, 0.0, vec![dangling]);
        println!("{:?}", outcome.as_ref().err());
        assert!(outcome.is_err());

        // valid states should be allowed
        let looping = State::new(enum_map! {
            Event::PaddingSent => vec![Trans(0, 0.8)],
            _ => vec![],
        });
        assert!(Machine::new(1000, 1.0, 0, 0.0, vec![looping]).is_ok());
    }
}