Skip to main content

tor_netdoc/types/policy/
addrpolicy.rs

1//! Implements address policies, based on a series of accept/reject
2//! rules.
3
4use std::fmt::Display;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::str::FromStr;
7
8use crate::NormalItemArgument;
9use crate::parse2::{
10    ErrorProblem as EP, ItemArgumentParseable, KeywordRef, NetdocParseableFields, UnparsedItem,
11};
12
13use super::{PolicyError, PortRange, RuleKind};
14
15/// A sequence of rules that are applied to an address:port until one
16/// matches.
17///
18/// Each rule is of the form "accept PATTERN" or "reject PATTERN",
19/// where every pattern describes a set of addresses and ports.
20/// Address sets are given as a prefix of 0-128 bits that the address
21/// must have; port sets are given as a low-bound and high-bound that
22/// the target port might lie between.
23///
24/// Relays use this type for defining their own policies, and for
25/// publishing their IPv4 policies.  Clients instead use
26/// [super::portpolicy::PortPolicy] objects to view a summary of the
27/// relays' declared policies.
28///
29/// An example IPv4 policy might be:
30///
31/// ```ignore
32///  reject *:25
33///  reject 127.0.0.0/8:*
34///  reject 192.168.0.0/16:*
35///  accept *:80
36///  accept *:443
37///  accept *:9000-65535
38///  reject *:*
39/// ```
40#[derive(Clone, Debug, Default, PartialEq, Eq)]
41pub struct AddrPolicy {
42    /// A list of rules to apply to find out whether an address is
43    /// contained by this policy.
44    ///
45    /// The rules apply in order; the first one to match determines
46    /// whether the address is accepted or rejected.
47    rules: Vec<AddrPolicyRule>,
48}
49
50impl AddrPolicy {
51    /// Apply this policy to an address:port combination
52    ///
53    /// We do this by applying each rule in sequence, until one
54    /// matches.
55    ///
56    /// Returns None if no rule matches.
57    pub fn allows(&self, addr: &IpAddr, port: u16) -> Option<RuleKind> {
58        self.rules
59            .iter()
60            .find(|rule| rule.pattern.matches(addr, port))
61            .map(|AddrPolicyRule { kind, .. }| *kind)
62    }
63
64    /// As allows, but accept a SocketAddr.
65    pub fn allows_sockaddr(&self, addr: &SocketAddr) -> Option<RuleKind> {
66        self.allows(&addr.ip(), addr.port())
67    }
68
69    /// Create a new AddrPolicy that matches nothing.
70    pub fn new() -> Self {
71        AddrPolicy::default()
72    }
73
74    /// Add a new rule to this policy.
75    ///
76    /// The newly added rule is applied _after_ all previous rules.
77    /// It matches all addresses and ports covered by AddrPortPattern.
78    ///
79    /// If accept is true, the rule is to accept addresses that match;
80    /// if accept is false, the rule rejects such addresses.
81    pub fn push(&mut self, kind: RuleKind, pattern: AddrPortPattern) {
82        self.rules.push(AddrPolicyRule { kind, pattern });
83    }
84}
85
86impl NetdocParseableFields for AddrPolicy {
87    type Accumulator = AddrPolicy;
88
89    fn is_item_keyword(kw: KeywordRef<'_>) -> bool {
90        matches!(kw.as_str(), "accept" | "reject")
91    }
92
93    fn accumulate_item(acc: &mut Self::Accumulator, mut item: UnparsedItem<'_>) -> Result<(), EP> {
94        // We must use `FromStr`, not argument parsing, because
95        // RuleKind is the keyword and not an argument.
96        let rule = RuleKind::from_str(item.keyword().as_str())
97            .map_err(|_| EP::Internal("accept/reject not a RuleKind?"))?;
98        let args = item.args_mut();
99        let pattern =
100            AddrPortPattern::from_args(args).map_err(args.error_handler("accept/reject"))?;
101        acc.push(rule, pattern);
102        Ok(())
103    }
104
105    fn finish(acc: Self::Accumulator) -> Result<Self, EP> {
106        Ok(acc)
107    }
108}
109
110/// A single rule in an address policy.
111///
112/// Contains a pattern and what to do with things that match it.
113#[derive(Clone, Debug, PartialEq, Eq)]
114struct AddrPolicyRule {
115    /// What do we do with items that match the pattern?
116    kind: RuleKind,
117    /// What pattern are we trying to match?
118    pattern: AddrPortPattern,
119}
120
121/*
122impl Display for AddrPolicyRule {
123    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124        let cmd = match self.kind {
125            RuleKind::Accept => "accept",
126            RuleKind::Reject => "reject",
127        };
128        write!(f, "{} {}", cmd, self.pattern)
129    }
130}
131*/
132
133/// A pattern that may or may not match an address and port.
134///
135/// Each AddrPortPattern has an IP pattern, which matches a set of
136/// addresses by prefix, and a port pattern, which matches a range of
137/// ports.
138///
139/// # Example
140///
141/// ```
142/// use tor_netdoc::types::policy::AddrPortPattern;
143/// use std::net::{IpAddr,Ipv4Addr};
144/// let localhost = IpAddr::V4(Ipv4Addr::new(127,3,4,5));
145/// let not_localhost = IpAddr::V4(Ipv4Addr::new(192,0,2,16));
146/// let pat: AddrPortPattern = "127.0.0.0/8:*".parse().unwrap();
147///
148/// assert!(pat.matches(&localhost, 22));
149/// assert!(! pat.matches(&not_localhost, 22));
150/// ```
151#[derive(
152    Clone, Debug, Eq, PartialEq, serde_with::SerializeDisplay, serde_with::DeserializeFromStr,
153)]
154pub struct AddrPortPattern {
155    /// A pattern to match somewhere between zero and all IP addresses.
156    pattern: IpPattern,
157    /// A pattern to match a range of ports.
158    ports: PortRange,
159}
160
161impl AddrPortPattern {
162    /// Return an AddrPortPattern matching all targets.
163    pub fn new_all() -> Self {
164        Self {
165            pattern: IpPattern::Star,
166            ports: PortRange::new_all(),
167        }
168    }
169
170    /// Return true iff this pattern matches a given address and port.
171    pub fn matches(&self, addr: &IpAddr, port: u16) -> bool {
172        self.pattern.matches(addr) && self.ports.contains(port)
173    }
174    /// As matches, but accept a SocketAddr.
175    pub fn matches_sockaddr(&self, addr: &SocketAddr) -> bool {
176        self.matches(&addr.ip(), addr.port())
177    }
178}
179
180impl Display for AddrPortPattern {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        if self.ports.is_all() {
183            write!(f, "{}:*", self.pattern)
184        } else {
185            write!(f, "{}:{}", self.pattern, self.ports)
186        }
187    }
188}
189
190impl FromStr for AddrPortPattern {
191    type Err = PolicyError;
192    fn from_str(s: &str) -> Result<Self, PolicyError> {
193        let last_colon = s.rfind(':').ok_or(PolicyError::InvalidPolicy)?;
194        let pattern: IpPattern = s[..last_colon].parse()?;
195        let ports_s = &s[last_colon + 1..];
196        let ports: PortRange = if ports_s == "*" {
197            PortRange::new_all()
198        } else {
199            ports_s.parse()?
200        };
201
202        Ok(AddrPortPattern { pattern, ports })
203    }
204}
205
206impl NormalItemArgument for AddrPortPattern {}
207
/// The set of IP addresses covered by a pattern.
//
// TODO(nickm): Display and FromStr currently cannot tell V4Star, V6Star and
// Star apart: all three are written as `*`, and `*` always parses as Star.
// If we ever need a spelling for "every IPv4 address" other than
// "0.0.0.0/0", this must be revisited.  (C tor accepts '*', '*4' and '*6'.)
#[derive(Clone, Debug, Eq, PartialEq)]
enum IpPattern {
    /// Every address, of either family.
    Star,
    /// Every IPv4 address.
    V4Star,
    /// Every IPv6 address.
    V6Star,
    /// The IPv4 addresses sharing the given number of leading bits with
    /// the given address.
    V4(Ipv4Addr, u8),
    /// The IPv6 addresses sharing the given number of leading bits with
    /// the given address.
    V6(Ipv6Addr, u8),
}
227
228impl IpPattern {
229    /// Construct an IpPattern that matches the first `mask` bits of `addr`.
230    fn from_addr_and_mask(addr: IpAddr, mask: u8) -> Result<Self, PolicyError> {
231        match (addr, mask) {
232            (IpAddr::V4(_), 0) => Ok(IpPattern::V4Star),
233            (IpAddr::V6(_), 0) => Ok(IpPattern::V6Star),
234            (IpAddr::V4(a), m) if m <= 32 => Ok(IpPattern::V4(a, m)),
235            (IpAddr::V6(a), m) if m <= 128 => Ok(IpPattern::V6(a, m)),
236            (_, _) => Err(PolicyError::InvalidMask),
237        }
238    }
239    /// Return true iff `addr` is matched by this pattern.
240    fn matches(&self, addr: &IpAddr) -> bool {
241        match (self, addr) {
242            (IpPattern::Star, _) => true,
243            (IpPattern::V4Star, IpAddr::V4(_)) => true,
244            (IpPattern::V6Star, IpAddr::V6(_)) => true,
245            (IpPattern::V4(pat, mask), IpAddr::V4(addr)) => {
246                let p1 = u32::from_be_bytes(pat.octets());
247                let p2 = u32::from_be_bytes(addr.octets());
248                let shift = 32 - mask;
249                (p1 >> shift) == (p2 >> shift)
250            }
251            (IpPattern::V6(pat, mask), IpAddr::V6(addr)) => {
252                let p1 = u128::from_be_bytes(pat.octets());
253                let p2 = u128::from_be_bytes(addr.octets());
254                let shift = 128 - mask;
255                (p1 >> shift) == (p2 >> shift)
256            }
257            (_, _) => false,
258        }
259    }
260}
261
262impl Display for IpPattern {
263    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264        use IpPattern::*;
265        match self {
266            Star | V4Star | V6Star => write!(f, "*"),
267            V4(a, 32) => write!(f, "{}", a),
268            V4(a, m) => write!(f, "{}/{}", a, m),
269            V6(a, 128) => write!(f, "[{}]", a),
270            V6(a, m) => write!(f, "[{}]/{}", a, m),
271        }
272    }
273}
274
275/// Helper: try to parse a plain ipv4 address, or an IPv6 address
276/// wrapped in brackets.
277fn parse_addr(mut s: &str) -> Result<IpAddr, PolicyError> {
278    let bracketed = s.starts_with('[') && s.ends_with(']');
279    if bracketed {
280        s = &s[1..s.len() - 1];
281    }
282    let addr: IpAddr = s.parse().map_err(|_| PolicyError::InvalidAddress)?;
283    if addr.is_ipv6() != bracketed {
284        return Err(PolicyError::InvalidAddress);
285    }
286    Ok(addr)
287}
288
289impl FromStr for IpPattern {
290    type Err = PolicyError;
291    fn from_str(s: &str) -> Result<Self, PolicyError> {
292        let (ip_s, mask_s) = match s.find('/') {
293            Some(slash_idx) => (&s[..slash_idx], Some(&s[slash_idx + 1..])),
294            None => (s, None),
295        };
296        match (ip_s, mask_s) {
297            ("*", Some(_)) => Err(PolicyError::MaskWithStar),
298            ("*", None) => Ok(IpPattern::Star),
299            (s, Some(m)) => {
300                let a: IpAddr = parse_addr(s)?;
301                let m: u8 = m.parse().map_err(|_| PolicyError::InvalidMask)?;
302                IpPattern::from_addr_and_mask(a, m)
303            }
304            (s, None) => {
305                let a: IpAddr = parse_addr(s)?;
306                let m = if a.is_ipv4() { 32 } else { 128 };
307                IpPattern::from_addr_and_mask(a, m)
308            }
309        }
310    }
311}
312
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
    use super::*;

    // Parsing a pattern and displaying it again yields a canonical form:
    // a /32 (IPv4) or /128 (IPv6) mask is omitted, and an IPv4 /0 prefix is
    // printed as `*`.
    // NOTE(review): `*` reparses as IpPattern::Star, which also matches IPv6,
    // so the /0 case does not round-trip to an equal pattern (see the TODO on
    // IpPattern).
    #[test]
    fn test_roundtrip_rules() {
        // Parse `inp` and require that it displays exactly as `outp`.
        fn check(inp: &str, outp: &str) {
            let policy = inp.parse::<AddrPortPattern>().unwrap();
            assert_eq!(format!("{}", policy), outp);
        }

        check("127.0.0.2/32:77-10000", "127.0.0.2:77-10000");
        check("127.0.0.2/32:*", "127.0.0.2:*");
        check("127.0.0.0/16:9-100", "127.0.0.0/16:9-100");
        check("127.0.0.0/0:443", "*:443");
        check("*:443", "*:443");
        check("[::1]:443", "[::1]:443");
        check("[ffaa::]/16:80", "[ffaa::]/16:80");
        check("[ffaa::77]/128:80", "[ffaa::77]:80");
    }

    // Inputs that must fail to parse: a non-address, a port range whose low
    // bound is above its high bound, an over-long IPv4 mask, a bracketed IPv4
    // address, and an over-long IPv6 mask.
    #[test]
    fn test_bad_rules() {
        fn check(s: &str) {
            assert!(s.parse::<AddrPortPattern>().is_err());
        }

        check("marzipan:80");
        check("1.2.3.4:90-80");
        check("1.2.3.4/100:8888");
        check("[1.2.3.4]/16:80");
        check("[::1]/130:8888");
    }

    // Each pattern must match every socket address in `yes` and none of the
    // socket addresses in `no`.
    #[test]
    fn test_rule_matches() {
        fn check(addr: &str, yes: &[&str], no: &[&str]) {
            use std::net::SocketAddr;
            let policy = addr.parse::<AddrPortPattern>().unwrap();
            for s in yes {
                let sa = s.parse::<SocketAddr>().unwrap();
                assert!(policy.matches_sockaddr(&sa));
            }
            for s in no {
                let sa = s.parse::<SocketAddr>().unwrap();
                assert!(!policy.matches_sockaddr(&sa));
            }
        }

        check(
            "1.2.3.4/16:80",
            &["1.2.3.4:80", "1.2.44.55:80"],
            &["9.9.9.9:80", "1.3.3.4:80", "1.2.3.4:81"],
        );
        check(
            "*:443-8000",
            &["1.2.3.4:443", "[::1]:500"],
            &["9.0.0.0:80", "[::1]:80"],
        );
        check(
            "[face::]/8:80",
            &["[fab0::7]:80"],
            &["[dd00::]:80", "[face::7]:443"],
        );

        // A /0 prefix covers only its own address family.
        check("0.0.0.0/0:*", &["127.0.0.1:80"], &["[f00b::]:80"]);
        check("[::]/0:*", &["[f00b::]:80"], &["127.0.0.1:80"]);
    }

    // Rules are applied in order and the first match wins; a target that no
    // rule covers yields None.
    #[test]
    fn test_policy_matches() -> Result<(), PolicyError> {
        let mut policy = AddrPolicy::default();
        policy.push(RuleKind::Accept, "*:443".parse()?);
        policy.push(RuleKind::Accept, "[::1]:80".parse()?);
        policy.push(RuleKind::Reject, "*:80".parse()?);

        let policy = policy; // drop mut
        assert_eq!(
            policy.allows_sockaddr(&"[::6]:443".parse().unwrap()),
            Some(RuleKind::Accept)
        );
        assert_eq!(
            policy.allows_sockaddr(&"127.0.0.1:443".parse().unwrap()),
            Some(RuleKind::Accept)
        );
        assert_eq!(
            policy.allows_sockaddr(&"[::1]:80".parse().unwrap()),
            Some(RuleKind::Accept)
        );
        assert_eq!(
            policy.allows_sockaddr(&"[::2]:80".parse().unwrap()),
            Some(RuleKind::Reject)
        );
        assert_eq!(
            policy.allows_sockaddr(&"127.0.0.1:80".parse().unwrap()),
            Some(RuleKind::Reject)
        );
        assert_eq!(
            policy.allows_sockaddr(&"127.0.0.1:66".parse().unwrap()),
            None
        );
        Ok(())
    }

    // AddrPortPattern (de)serializes through its Display/FromStr form.  The
    // address of a prefix is kept as written ("127.0.0.1/8", not
    // "127.0.0.0/8").
    #[test]
    fn serde() {
        #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq)]
        struct X {
            p1: AddrPortPattern,
            p2: AddrPortPattern,
        }

        let x = X {
            p1: "127.0.0.1/8:9-10".parse().unwrap(),
            p2: "*:80".parse().unwrap(),
        };

        let encoded = serde_json::to_string(&x).unwrap();
        let expected = r#"{"p1":"127.0.0.1/8:9-10","p2":"*:80"}"#;
        let x2: X = serde_json::from_str(&encoded).unwrap();
        let x3: X = serde_json::from_str(expected).unwrap();
        assert_eq!(&x2, &x3);
        assert_eq!(&x2, &x);
    }

    // Parse accept/reject items out of a small document with the parse2
    // machinery, flattening them into an AddrPolicy field.  The leading
    // `intro` item is discarded through `Ignored`.
    // NOTE(review): presumably the derived parser needs a distinct first
    // item before the flattened fields — confirm against parse2.
    #[test]
    fn parse2() {
        use crate::{
            parse2::{self, ParseInput},
            types::Ignored,
        };
        use derive_deftly::Deftly;

        const RULES: &str = "\
        intro\n\
        reject *:25\n\
        reject 127.0.0.0/8:*\n\
        reject 192.168.0.0/16:*\n\
        accept *:80\n\
        accept *:443\n\
        accept *:9000-65535\n\
        reject *:*\n";

        #[derive(Deftly)]
        #[derive_deftly(NetdocParseable)]
        struct Wrapper {
            #[allow(dead_code)]
            intro: Ignored,
            #[deftly(netdoc(flatten))]
            ipv4_policy: AddrPolicy,
        }

        let wrapper = parse2::parse_netdoc::<Wrapper>(&ParseInput::new(RULES, "")).unwrap();
        let ap = wrapper.ipv4_policy;

        assert_eq!(
            ap.allows_sockaddr(&"1.1.1.1:80".parse().unwrap()),
            Some(RuleKind::Accept)
        );
        assert_eq!(
            ap.allows_sockaddr(&"1.1.1.1:443".parse().unwrap()),
            Some(RuleKind::Accept)
        );
        assert_eq!(
            ap.allows_sockaddr(&"1.1.1.1:9005".parse().unwrap()),
            Some(RuleKind::Accept)
        );

        assert_eq!(
            ap.allows_sockaddr(&"1.1.1.1:25".parse().unwrap()),
            Some(RuleKind::Reject)
        );
        assert_eq!(
            ap.allows_sockaddr(&"127.0.0.1:80".parse().unwrap()),
            Some(RuleKind::Reject)
        );
        // Port 70 matches no accept rule, so it falls through to the final
        // `reject *:*`.
        assert_eq!(
            ap.allows_sockaddr(&"1.1.1.1:70".parse().unwrap()),
            Some(RuleKind::Reject)
        );
    }
}