1use std::fmt::Display;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
6use std::str::FromStr;
7
8use crate::NormalItemArgument;
9use crate::parse2::{
10 ErrorProblem as EP, ItemArgumentParseable, KeywordRef, NetdocParseableFields, UnparsedItem,
11};
12
13use super::{PolicyError, PortRange, RuleKind};
14
15#[derive(Clone, Debug, Default, PartialEq, Eq)]
41pub struct AddrPolicy {
42 rules: Vec<AddrPolicyRule>,
48}
49
50impl AddrPolicy {
51 pub fn allows(&self, addr: &IpAddr, port: u16) -> Option<RuleKind> {
58 self.rules
59 .iter()
60 .find(|rule| rule.pattern.matches(addr, port))
61 .map(|AddrPolicyRule { kind, .. }| *kind)
62 }
63
64 pub fn allows_sockaddr(&self, addr: &SocketAddr) -> Option<RuleKind> {
66 self.allows(&addr.ip(), addr.port())
67 }
68
69 pub fn new() -> Self {
71 AddrPolicy::default()
72 }
73
74 pub fn push(&mut self, kind: RuleKind, pattern: AddrPortPattern) {
82 self.rules.push(AddrPolicyRule { kind, pattern });
83 }
84}
85
86impl NetdocParseableFields for AddrPolicy {
87 type Accumulator = AddrPolicy;
88
89 fn is_item_keyword(kw: KeywordRef<'_>) -> bool {
90 matches!(kw.as_str(), "accept" | "reject")
91 }
92
93 fn accumulate_item(acc: &mut Self::Accumulator, mut item: UnparsedItem<'_>) -> Result<(), EP> {
94 let rule = RuleKind::from_str(item.keyword().as_str())
97 .map_err(|_| EP::Internal("accept/reject not a RuleKind?"))?;
98 let args = item.args_mut();
99 let pattern =
100 AddrPortPattern::from_args(args).map_err(args.error_handler("accept/reject"))?;
101 acc.push(rule, pattern);
102 Ok(())
103 }
104
105 fn finish(acc: Self::Accumulator) -> Result<Self, EP> {
106 Ok(acc)
107 }
108}
109
110#[derive(Clone, Debug, PartialEq, Eq)]
114struct AddrPolicyRule {
115 kind: RuleKind,
117 pattern: AddrPortPattern,
119}
120
121#[derive(
152 Clone, Debug, Eq, PartialEq, serde_with::SerializeDisplay, serde_with::DeserializeFromStr,
153)]
154pub struct AddrPortPattern {
155 pattern: IpPattern,
157 ports: PortRange,
159}
160
161impl AddrPortPattern {
162 pub fn new_all() -> Self {
164 Self {
165 pattern: IpPattern::Star,
166 ports: PortRange::new_all(),
167 }
168 }
169
170 pub fn matches(&self, addr: &IpAddr, port: u16) -> bool {
172 self.pattern.matches(addr) && self.ports.contains(port)
173 }
174 pub fn matches_sockaddr(&self, addr: &SocketAddr) -> bool {
176 self.matches(&addr.ip(), addr.port())
177 }
178}
179
180impl Display for AddrPortPattern {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 if self.ports.is_all() {
183 write!(f, "{}:*", self.pattern)
184 } else {
185 write!(f, "{}:{}", self.pattern, self.ports)
186 }
187 }
188}
189
190impl FromStr for AddrPortPattern {
191 type Err = PolicyError;
192 fn from_str(s: &str) -> Result<Self, PolicyError> {
193 let last_colon = s.rfind(':').ok_or(PolicyError::InvalidPolicy)?;
194 let pattern: IpPattern = s[..last_colon].parse()?;
195 let ports_s = &s[last_colon + 1..];
196 let ports: PortRange = if ports_s == "*" {
197 PortRange::new_all()
198 } else {
199 ports_s.parse()?
200 };
201
202 Ok(AddrPortPattern { pattern, ports })
203 }
204}
205
// NOTE(review): presumably this marker impl is what provides the
// `ItemArgumentParseable::from_args` used when parsing `accept`/`reject`
// items, by way of this type's `FromStr` impl — confirm against `parse2`.
impl NormalItemArgument for AddrPortPattern {}
207
/// A pattern matching a set of IP addresses.
#[derive(Debug, Clone, PartialEq, Eq)]
enum IpPattern {
    /// Matches every address, IPv4 or IPv6.
    Star,
    /// Matches every IPv4 address.
    V4Star,
    /// Matches every IPv6 address.
    V6Star,
    /// Matches IPv4 addresses whose leading bits (count given by the `u8`)
    /// equal those of the given address.
    V4(Ipv4Addr, u8),
    /// Matches IPv6 addresses whose leading bits (count given by the `u8`)
    /// equal those of the given address.
    V6(Ipv6Addr, u8),
}
227
228impl IpPattern {
229 fn from_addr_and_mask(addr: IpAddr, mask: u8) -> Result<Self, PolicyError> {
231 match (addr, mask) {
232 (IpAddr::V4(_), 0) => Ok(IpPattern::V4Star),
233 (IpAddr::V6(_), 0) => Ok(IpPattern::V6Star),
234 (IpAddr::V4(a), m) if m <= 32 => Ok(IpPattern::V4(a, m)),
235 (IpAddr::V6(a), m) if m <= 128 => Ok(IpPattern::V6(a, m)),
236 (_, _) => Err(PolicyError::InvalidMask),
237 }
238 }
239 fn matches(&self, addr: &IpAddr) -> bool {
241 match (self, addr) {
242 (IpPattern::Star, _) => true,
243 (IpPattern::V4Star, IpAddr::V4(_)) => true,
244 (IpPattern::V6Star, IpAddr::V6(_)) => true,
245 (IpPattern::V4(pat, mask), IpAddr::V4(addr)) => {
246 let p1 = u32::from_be_bytes(pat.octets());
247 let p2 = u32::from_be_bytes(addr.octets());
248 let shift = 32 - mask;
249 (p1 >> shift) == (p2 >> shift)
250 }
251 (IpPattern::V6(pat, mask), IpAddr::V6(addr)) => {
252 let p1 = u128::from_be_bytes(pat.octets());
253 let p2 = u128::from_be_bytes(addr.octets());
254 let shift = 128 - mask;
255 (p1 >> shift) == (p2 >> shift)
256 }
257 (_, _) => false,
258 }
259 }
260}
261
262impl Display for IpPattern {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 use IpPattern::*;
265 match self {
266 Star | V4Star | V6Star => write!(f, "*"),
267 V4(a, 32) => write!(f, "{}", a),
268 V4(a, m) => write!(f, "{}/{}", a, m),
269 V6(a, 128) => write!(f, "[{}]", a),
270 V6(a, m) => write!(f, "[{}]/{}", a, m),
271 }
272 }
273}
274
275fn parse_addr(mut s: &str) -> Result<IpAddr, PolicyError> {
278 let bracketed = s.starts_with('[') && s.ends_with(']');
279 if bracketed {
280 s = &s[1..s.len() - 1];
281 }
282 let addr: IpAddr = s.parse().map_err(|_| PolicyError::InvalidAddress)?;
283 if addr.is_ipv6() != bracketed {
284 return Err(PolicyError::InvalidAddress);
285 }
286 Ok(addr)
287}
288
289impl FromStr for IpPattern {
290 type Err = PolicyError;
291 fn from_str(s: &str) -> Result<Self, PolicyError> {
292 let (ip_s, mask_s) = match s.find('/') {
293 Some(slash_idx) => (&s[..slash_idx], Some(&s[slash_idx + 1..])),
294 None => (s, None),
295 };
296 match (ip_s, mask_s) {
297 ("*", Some(_)) => Err(PolicyError::MaskWithStar),
298 ("*", None) => Ok(IpPattern::Star),
299 (s, Some(m)) => {
300 let a: IpAddr = parse_addr(s)?;
301 let m: u8 = m.parse().map_err(|_| PolicyError::InvalidMask)?;
302 IpPattern::from_addr_and_mask(a, m)
303 }
304 (s, None) => {
305 let a: IpAddr = parse_addr(s)?;
306 let m = if a.is_ipv4() { 32 } else { 128 };
307 IpPattern::from_addr_and_mask(a, m)
308 }
309 }
310 }
311}
312
313#[cfg(test)]
314mod test {
315 #![allow(clippy::bool_assert_comparison)]
317 #![allow(clippy::clone_on_copy)]
318 #![allow(clippy::dbg_macro)]
319 #![allow(clippy::mixed_attributes_style)]
320 #![allow(clippy::print_stderr)]
321 #![allow(clippy::print_stdout)]
322 #![allow(clippy::single_char_pattern)]
323 #![allow(clippy::unwrap_used)]
324 #![allow(clippy::unchecked_time_subtraction)]
325 #![allow(clippy::useless_vec)]
326 #![allow(clippy::needless_pass_by_value)]
327 use super::*;
329
330 #[test]
331 fn test_roundtrip_rules() {
332 fn check(inp: &str, outp: &str) {
333 let policy = inp.parse::<AddrPortPattern>().unwrap();
334 assert_eq!(format!("{}", policy), outp);
335 }
336
337 check("127.0.0.2/32:77-10000", "127.0.0.2:77-10000");
338 check("127.0.0.2/32:*", "127.0.0.2:*");
339 check("127.0.0.0/16:9-100", "127.0.0.0/16:9-100");
340 check("127.0.0.0/0:443", "*:443");
341 check("*:443", "*:443");
342 check("[::1]:443", "[::1]:443");
343 check("[ffaa::]/16:80", "[ffaa::]/16:80");
344 check("[ffaa::77]/128:80", "[ffaa::77]:80");
345 }
346
347 #[test]
348 fn test_bad_rules() {
349 fn check(s: &str) {
350 assert!(s.parse::<AddrPortPattern>().is_err());
351 }
352
353 check("marzipan:80");
354 check("1.2.3.4:90-80");
355 check("1.2.3.4/100:8888");
356 check("[1.2.3.4]/16:80");
357 check("[::1]/130:8888");
358 }
359
360 #[test]
361 fn test_rule_matches() {
362 fn check(addr: &str, yes: &[&str], no: &[&str]) {
363 use std::net::SocketAddr;
364 let policy = addr.parse::<AddrPortPattern>().unwrap();
365 for s in yes {
366 let sa = s.parse::<SocketAddr>().unwrap();
367 assert!(policy.matches_sockaddr(&sa));
368 }
369 for s in no {
370 let sa = s.parse::<SocketAddr>().unwrap();
371 assert!(!policy.matches_sockaddr(&sa));
372 }
373 }
374
375 check(
376 "1.2.3.4/16:80",
377 &["1.2.3.4:80", "1.2.44.55:80"],
378 &["9.9.9.9:80", "1.3.3.4:80", "1.2.3.4:81"],
379 );
380 check(
381 "*:443-8000",
382 &["1.2.3.4:443", "[::1]:500"],
383 &["9.0.0.0:80", "[::1]:80"],
384 );
385 check(
386 "[face::]/8:80",
387 &["[fab0::7]:80"],
388 &["[dd00::]:80", "[face::7]:443"],
389 );
390
391 check("0.0.0.0/0:*", &["127.0.0.1:80"], &["[f00b::]:80"]);
392 check("[::]/0:*", &["[f00b::]:80"], &["127.0.0.1:80"]);
393 }
394
395 #[test]
396 fn test_policy_matches() -> Result<(), PolicyError> {
397 let mut policy = AddrPolicy::default();
398 policy.push(RuleKind::Accept, "*:443".parse()?);
399 policy.push(RuleKind::Accept, "[::1]:80".parse()?);
400 policy.push(RuleKind::Reject, "*:80".parse()?);
401
402 let policy = policy; assert_eq!(
404 policy.allows_sockaddr(&"[::6]:443".parse().unwrap()),
405 Some(RuleKind::Accept)
406 );
407 assert_eq!(
408 policy.allows_sockaddr(&"127.0.0.1:443".parse().unwrap()),
409 Some(RuleKind::Accept)
410 );
411 assert_eq!(
412 policy.allows_sockaddr(&"[::1]:80".parse().unwrap()),
413 Some(RuleKind::Accept)
414 );
415 assert_eq!(
416 policy.allows_sockaddr(&"[::2]:80".parse().unwrap()),
417 Some(RuleKind::Reject)
418 );
419 assert_eq!(
420 policy.allows_sockaddr(&"127.0.0.1:80".parse().unwrap()),
421 Some(RuleKind::Reject)
422 );
423 assert_eq!(
424 policy.allows_sockaddr(&"127.0.0.1:66".parse().unwrap()),
425 None
426 );
427 Ok(())
428 }
429
430 #[test]
431 fn serde() {
432 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq)]
433 struct X {
434 p1: AddrPortPattern,
435 p2: AddrPortPattern,
436 }
437
438 let x = X {
439 p1: "127.0.0.1/8:9-10".parse().unwrap(),
440 p2: "*:80".parse().unwrap(),
441 };
442
443 let encoded = serde_json::to_string(&x).unwrap();
444 let expected = r#"{"p1":"127.0.0.1/8:9-10","p2":"*:80"}"#;
445 let x2: X = serde_json::from_str(&encoded).unwrap();
446 let x3: X = serde_json::from_str(expected).unwrap();
447 assert_eq!(&x2, &x3);
448 assert_eq!(&x2, &x);
449 }
450
451 #[test]
452 fn parse2() {
453 use crate::{
454 parse2::{self, ParseInput},
455 types::Ignored,
456 };
457 use derive_deftly::Deftly;
458
459 const RULES: &str = "\
460 intro\n\
461 reject *:25\n\
462 reject 127.0.0.0/8:*\n\
463 reject 192.168.0.0/16:*\n\
464 accept *:80\n\
465 accept *:443\n\
466 accept *:9000-65535\n\
467 reject *:*\n";
468
469 #[derive(Deftly)]
470 #[derive_deftly(NetdocParseable)]
471 struct Wrapper {
472 #[allow(dead_code)]
473 intro: Ignored,
474 #[deftly(netdoc(flatten))]
475 ipv4_policy: AddrPolicy,
476 }
477
478 let wrapper = parse2::parse_netdoc::<Wrapper>(&ParseInput::new(RULES, "")).unwrap();
479 let ap = wrapper.ipv4_policy;
480
481 assert_eq!(
482 ap.allows_sockaddr(&"1.1.1.1:80".parse().unwrap()),
483 Some(RuleKind::Accept)
484 );
485 assert_eq!(
486 ap.allows_sockaddr(&"1.1.1.1:443".parse().unwrap()),
487 Some(RuleKind::Accept)
488 );
489 assert_eq!(
490 ap.allows_sockaddr(&"1.1.1.1:9005".parse().unwrap()),
491 Some(RuleKind::Accept)
492 );
493
494 assert_eq!(
495 ap.allows_sockaddr(&"1.1.1.1:25".parse().unwrap()),
496 Some(RuleKind::Reject)
497 );
498 assert_eq!(
499 ap.allows_sockaddr(&"127.0.0.1:80".parse().unwrap()),
500 Some(RuleKind::Reject)
501 );
502 assert_eq!(
503 ap.allows_sockaddr(&"1.1.1.1:70".parse().unwrap()),
504 Some(RuleKind::Reject)
505 );
506 }
507}