Skip to main content

tor_netdoc/types/
policy.rs

1//! Exit policies: match patterns of addresses and/or ports.
2//!
//! Every Tor relay has a set of address:port combinations that it
4//! actually allows connections to.  The set, abstractly, is the
5//! relay's "exit policy".
6//!
7//! Address policies can be transmitted in two forms.  One is a "full
//! policy", which is a list of rules applied in order to decide
//! which addresses and ports are allowed.  We represent this with the
10//! AddrPolicy type.
11//!
//! In microdescriptors, and for IPv6 policies, policies are given
//! as just a list of ports for which _most_ addresses are permitted.
14//! We represent this kind of policy with the PortPolicy type.
15//!
16//! TODO: This module probably belongs in a crate of its own, with
17//! possibly only the parsing code in this crate.
18
19mod addrpolicy;
20mod portpolicy;
21
22use std::str::FromStr;
23use std::{collections::BTreeSet, fmt::Display};
24use thiserror::Error;
25
26pub use addrpolicy::{AddrPolicy, AddrPortPattern};
27pub use portpolicy::PortPolicy;
28
29use crate::NormalItemArgument;
30use crate::parse2::{ArgumentError, ArgumentStream, ItemArgumentParseable};
31
32/// Error from an unparsable or invalid policy.
33#[derive(Debug, Error, Clone, PartialEq, Eq)]
34#[non_exhaustive]
35pub enum PolicyError {
36    /// A port was not a number in the range 1..65535
37    #[error("Invalid port")]
38    InvalidPort,
39    /// A port range had its starting-point higher than its ending point.
40    #[error("Invalid port range")]
41    InvalidRange,
42    /// An address could not be interpreted.
43    #[error("Invalid address")]
44    InvalidAddress,
45    /// Tried to use a bitmask with the address "*".
46    #[error("mask with star")]
47    MaskWithStar,
48    /// A bit mask was out of range.
49    #[error("invalid mask")]
50    InvalidMask,
51    /// A policy could not be parsed for some other reason.
52    #[error("Invalid policy")]
53    InvalidPolicy,
54}
55
/// A `PortRange` is an inclusive run of consecutively numbered TCP or UDP
/// ports.
///
/// Every constructor rejects port 0 and reversed bounds, so a range is
/// never empty and never contains port 0.
///
/// # Example
/// ```
/// use tor_netdoc::types::policy::PortRange;
///
/// let r: PortRange = "22-8000".parse().unwrap();
/// assert!(r.contains(128));
/// assert!(r.contains(22));
/// assert!(r.contains(8000));
///
/// assert!(! r.contains(21));
/// assert!(! r.contains(8001));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(clippy::exhaustive_structs)]
pub struct PortRange {
    /// The lowest port in the range (always at least 1).
    lo: u16,
    /// The highest port in the range (never less than `lo`).
    hi: u16,
}

impl PortRange {
    /// Build the range `lo..=hi`, panicking if `lo` is 0 or exceeds `hi`.
    ///
    /// Internal helper for callers that have already established the
    /// invariants.
    fn new_unchecked(lo: u16, hi: u16) -> Self {
        assert!(lo != 0);
        assert!(lo <= hi);
        PortRange { lo, hi }
    }

    /// Return the range that covers every port, 1 through 65535.
    pub fn new_all() -> Self {
        PortRange::new_unchecked(1, 65535)
    }

    /// Construct a `PortRange` holding every port from `lo` to `hi`, inclusive.
    ///
    /// Yields `None` when `lo` is 0 or when `lo` is greater than `hi`.
    /// Because `hi` may not be below `lo`, this also rules out `hi == 0`.
    pub fn new(lo: u16, hi: u16) -> Option<Self> {
        let well_formed = lo != 0 && lo <= hi;
        well_formed.then_some(PortRange { lo, hi })
    }

    /// Report whether `port` lies within this range.
    pub fn contains(&self, port: u16) -> bool {
        (self.lo..=self.hi).contains(&port)
    }

    /// Report whether this range spans every port, 1 through 65535.
    pub fn is_all(&self) -> bool {
        (self.lo, self.hi) == (1, u16::MAX)
    }

    /// Ordering callback for `binary_search_by`.
    ///
    /// It tells the search where this range sits relative to `port`.
    /// `Equal` means `port` is inside the range.  `Greater` means the range
    /// starts after `port`.  `Less` means the range ends before `port`.
    fn compare_to_port(&self, port: u16) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        if port < self.lo {
            Ordering::Greater
        } else if port > self.hi {
            Ordering::Less
        } else {
            Ordering::Equal
        }
    }
}

/// A single-port range is written as the bare port number (`80`).
/// A wider range is written as its two endpoints joined by a dash
/// (`1-1024`).
impl Display for PortRange {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match (self.lo, self.hi) {
            (lo, hi) if lo == hi => write!(f, "{lo}"),
            (lo, hi) => write!(f, "{lo}-{hi}"),
        }
    }
}
136        } else {
137            write!(f, "{}-{}", self.lo, self.hi)
138        }
139    }
140}
141
142impl FromStr for PortRange {
143    type Err = PolicyError;
144    fn from_str(s: &str) -> Result<Self, PolicyError> {
145        let idx = s.find('-');
146        // Find "lo" and "hi".
147        let (lo, hi) = if let Some(pos) = idx {
148            // This is a range; parse each part.
149            (
150                s[..pos]
151                    .parse::<u16>()
152                    .map_err(|_| PolicyError::InvalidPort)?,
153                s[pos + 1..]
154                    .parse::<u16>()
155                    .map_err(|_| PolicyError::InvalidPort)?,
156            )
157        } else {
158            // There was no hyphen, so try to parse this range as a singleton.
159            let v = s.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?;
160            (v, v)
161        };
162        PortRange::new(lo, hi).ok_or(PolicyError::InvalidRange)
163    }
164}
165
166impl NormalItemArgument for PortRange {}
167
168/// A collection of port ranges in a sorted order.
169///
170/// Please use this when storing multiple port ranges because it optimizies
171/// them storage wise.
172// TODO: We should rewrite most of this, the implementation has lots of
173// potential for off-by-one errors and such.
174#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
175// Invariant:
176//
177// The `PortRange`s are valid, nonoverlapping, non-abutting, and sorted.
178struct PortRanges(Vec<PortRange>);
179
180impl PortRanges {
181    /// Creates a new [`PortRanges`] collection with no elements in it.
182    fn new() -> Self {
183        Self(Vec::new())
184    }
185
186    /// Checks whether there are no ranges in this instance.
187    fn is_empty(&self) -> bool {
188        self.0.is_empty()
189    }
190
191    /// Adds a new range into this [`PortRanges`].
192    ///
193    /// The ranges must be valid, nonoverlapping, and pushed in a monotonically increasing order,
194    /// meaning that inserting `400-500,450-600` or `400-500,500-600` are
195    /// invalid, whereas `400-500,501-600` and `400-500,501-600` are.
196    fn push_ordered(&mut self, item: PortRange) -> Result<(), PolicyError> {
197        if let Some(prev) = self.0.last() {
198            // TODO SPEC: We don't enforce this in Tor, but we probably
199            // should.  See torspec#60.
200            if prev.hi >= item.lo {
201                return Err(PolicyError::InvalidPolicy);
202            } else if prev.hi == item.lo - 1 {
203                // We compress a-b,(b+1)-c into a-c.
204                let r = PortRange::new_unchecked(prev.lo, item.hi);
205                self.0.pop();
206                self.0.push(r);
207                return Ok(());
208            }
209        }
210
211        self.0.push(item);
212        Ok(())
213    }
214
215    /// Checks whether `port` is contained in a range.
216    ///
217    /// Whether this means if `port` is allowed or rejected depends on the
218    /// surroundings (such as which field this `PortRage` is in,
219    /// or an associated [`RuleKind`]).
220    fn contains(&self, port: u16) -> bool {
221        debug_assert!(self.0.is_sorted_by(|a, b| a.lo < b.lo));
222        self.0
223            .binary_search_by(|range| range.compare_to_port(port))
224            .is_ok()
225    }
226
227    /// Inverts a [`PortRanges`].
228    ///
229    /// For example, a [`PortRanges`] of `80-443` would become `1-79,444-65535`.
230    fn invert(&mut self) {
231        let mut prev_hi = 0;
232        let mut new_allowed = Vec::new();
233        for entry in &self.0 {
234            // ports prev_hi+1 through entry.lo-1 were rejected.  We should
235            // make them allowed.
236            if entry.lo > prev_hi + 1 {
237                new_allowed.push(PortRange::new_unchecked(prev_hi + 1, entry.lo - 1));
238            }
239            prev_hi = entry.hi;
240        }
241        if prev_hi < 65535 {
242            new_allowed.push(PortRange::new_unchecked(prev_hi + 1, 65535));
243        }
244        self.0 = new_allowed;
245    }
246
247    /// Returns an iterator for [`PortRanges`].
248    fn iter(&self) -> impl Iterator<Item = &PortRange> {
249        self.0.iter()
250    }
251}
252
253impl FromIterator<u16> for PortRanges {
254    fn from_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
255        // Collect all ports into a BTreeSet to have them sorted and deduped.
256        let ports = iter.into_iter().collect::<BTreeSet<_>>();
257        let mut ports = ports.into_iter().peekable();
258
259        let mut out = Self::new();
260        let mut current_min = None;
261        while let Some(port) = ports.next() {
262            if current_min.is_none() {
263                current_min = Some(port);
264            }
265            if let Some(next_port) = ports.peek().copied() {
266                // We do not have to worry about port == 65535, because then
267                // ports.peek() will be None, as each item in the BTreeSet is
268                // ordered and unique, implying that there won't be a successor
269                // to a port == 65535.
270                if next_port != port + 1 {
271                    let _ = out.push_ordered(PortRange::new_unchecked(
272                        current_min.expect("Don't have min port number"),
273                        port,
274                    ));
275                    current_min = None;
276                }
277            } else {
278                let _ = out.push_ordered(PortRange::new_unchecked(
279                    current_min.expect("Don't have min port number"),
280                    port,
281                ));
282            }
283        }
284
285        out
286    }
287}
288
289// There is deliberately no Display implementation for PortRanges because this
290// highly depends on the semantic wrapper around it.  For example, an empty
291// PortRanges may either be represented as `reject 1-65535` or `accept 1-65535`
292// depending on the context.
293
294impl FromStr for PortRanges {
295    type Err = PolicyError;
296
297    fn from_str(s: &str) -> Result<Self, Self::Err> {
298        // Pitfall: Do not use a clever iterator here because we need the result
299        // of .push() in order to avoid things such as `30-19`.
300        let mut ranges = Self::new();
301        for range in s.split(',') {
302            ranges.push_ordered(range.parse()?)?;
303        }
304        Ok(ranges)
305    }
306}
307
308impl ItemArgumentParseable for PortRanges {
309    /// [`PortRanges`] argument parser which is odd because port ranges are
310    /// syntactically a single argument although semantically multiple ones.
311    fn from_args<'s>(args: &mut ArgumentStream<'s>) -> Result<Self, ArgumentError> {
312        args.next()
313            .map(Self::from_str)
314            .unwrap_or(Ok(Self::new()))
315            .map_err(|_| ArgumentError::Invalid)
316    }
317}
318
319/// A kind of policy rule: either accepts or rejects addresses
320/// matching a pattern.
321#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display, derive_more::FromStr)]
322#[display(rename_all = "lowercase")]
323#[from_str(rename_all = "lowercase")]
324#[allow(clippy::exhaustive_enums)]
325pub enum RuleKind {
326    /// A rule that accepts matching address:port combinations.
327    Accept,
328    /// A rule that rejects matching address:port combinations.
329    Reject,
330}
331
332impl NormalItemArgument for RuleKind {}
333
334#[cfg(test)]
335mod test {
336    // @@ begin test lint list maintained by maint/add_warning @@
337    #![allow(clippy::bool_assert_comparison)]
338    #![allow(clippy::clone_on_copy)]
339    #![allow(clippy::dbg_macro)]
340    #![allow(clippy::mixed_attributes_style)]
341    #![allow(clippy::print_stderr)]
342    #![allow(clippy::print_stdout)]
343    #![allow(clippy::single_char_pattern)]
344    #![allow(clippy::unwrap_used)]
345    #![allow(clippy::unchecked_time_subtraction)]
346    #![allow(clippy::useless_vec)]
347    #![allow(clippy::needless_pass_by_value)]
348    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
349    use super::*;
350    use crate::Result;
351    use crate::parse2::{self, ParseInput};
352
353    #[test]
354    fn parse_portrange() -> Result<()> {
355        assert_eq!(
356            "1-100".parse::<PortRange>()?,
357            PortRange::new(1, 100).unwrap()
358        );
359        assert_eq!(
360            "01-100".parse::<PortRange>()?,
361            PortRange::new(1, 100).unwrap()
362        );
363        assert_eq!("1-65535".parse::<PortRange>()?, PortRange::new_all());
364        assert_eq!(
365            "10-30".parse::<PortRange>()?,
366            PortRange::new(10, 30).unwrap()
367        );
368        assert_eq!(
369            "9001".parse::<PortRange>()?,
370            PortRange::new(9001, 9001).unwrap()
371        );
372        assert_eq!(
373            "9001-9001".parse::<PortRange>()?,
374            PortRange::new(9001, 9001).unwrap()
375        );
376
377        assert!("hello".parse::<PortRange>().is_err());
378        assert!("0".parse::<PortRange>().is_err());
379        assert!("65536".parse::<PortRange>().is_err());
380        assert!("65537".parse::<PortRange>().is_err());
381        assert!("1-2-3".parse::<PortRange>().is_err());
382        assert!("10-5".parse::<PortRange>().is_err());
383        assert!("1-".parse::<PortRange>().is_err());
384        assert!("-2".parse::<PortRange>().is_err());
385        assert!("-".parse::<PortRange>().is_err());
386        assert!("*".parse::<PortRange>().is_err());
387        Ok(())
388    }
389
390    #[test]
391    fn pr_manip() {
392        assert!(PortRange::new_all().is_all());
393        assert!(!PortRange::new(2, 65535).unwrap().is_all());
394
395        assert!(PortRange::new_all().contains(1));
396        assert!(PortRange::new_all().contains(65535));
397        assert!(PortRange::new_all().contains(7777));
398
399        assert!(PortRange::new(20, 30).unwrap().contains(20));
400        assert!(PortRange::new(20, 30).unwrap().contains(25));
401        assert!(PortRange::new(20, 30).unwrap().contains(30));
402        assert!(!PortRange::new(20, 30).unwrap().contains(19));
403        assert!(!PortRange::new(20, 30).unwrap().contains(31));
404
405        use std::cmp::Ordering::*;
406        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(7), Greater);
407        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(20), Equal);
408        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(25), Equal);
409        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(30), Equal);
410        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(100), Less);
411    }
412
413    #[test]
414    fn pr_fmt() {
415        fn chk(a: u16, b: u16, s: &str) {
416            let pr = PortRange::new(a, b).unwrap();
417            assert_eq!(format!("{}", pr), s);
418        }
419
420        chk(1, 65535, "1-65535");
421        chk(10, 20, "10-20");
422        chk(20, 20, "20");
423    }
424
425    #[test]
426    fn port_ranges() {
427        const INPUT: &str = "22,80,443,8000-9000,9002";
428        let ranges = PortRanges::from_str(INPUT).unwrap();
429        assert_eq!(
430            ranges.0,
431            [
432                PortRange::new(22, 22).unwrap(),
433                PortRange::new(80, 80).unwrap(),
434                PortRange::new(443, 443).unwrap(),
435                PortRange::new(8000, 9000).unwrap(),
436                PortRange::new(9002, 9002).unwrap(),
437            ]
438        );
439        assert!(ranges.contains(22));
440        assert!(ranges.contains(80));
441        assert!(ranges.contains(443));
442        assert!(ranges.contains(8000));
443        assert!(ranges.contains(8500));
444        assert!(ranges.contains(9000));
445        assert!(!ranges.contains(9001));
446        assert!(ranges.contains(9002));
447
448        let mut ranges_inverse = ranges.clone();
449        ranges_inverse.invert();
450        assert_eq!(
451            ranges_inverse.0,
452            [
453                PortRange::new(1, 21).unwrap(),
454                PortRange::new(23, 79).unwrap(),
455                PortRange::new(81, 442).unwrap(),
456                PortRange::new(444, 7999).unwrap(),
457                PortRange::new(9001, 9001).unwrap(),
458                PortRange::new(9003, 65535).unwrap(),
459            ]
460        );
461
462        #[derive(derive_deftly::Deftly)]
463        #[derive_deftly(NetdocParseable)]
464        struct Dummy {
465            #[deftly(netdoc(single_arg))]
466            dummy: PortRanges,
467        }
468        let ranges2 =
469            parse2::parse_netdoc::<Dummy>(&ParseInput::new(&format!("dummy {INPUT}\n"), ""))
470                .unwrap();
471        assert_eq!(ranges, ranges2.dummy);
472    }
473}