1mod addrpolicy;
20mod portpolicy;
21
22use std::str::FromStr;
23use std::{collections::BTreeSet, fmt::Display};
24use thiserror::Error;
25
26pub use addrpolicy::{AddrPolicy, AddrPortPattern};
27pub use portpolicy::PortPolicy;
28
29use crate::NormalItemArgument;
30use crate::parse2::{ArgumentError, ArgumentStream, ItemArgumentParseable};
31
/// An error encountered while parsing a policy item: a port, a port range,
/// an address, a mask, or a whole policy.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PolicyError {
    /// A port number could not be parsed as a `u16`.
    #[error("Invalid port")]
    InvalidPort,
    /// The port numbers parsed, but did not form a legal range: the low end
    /// was zero, or the low end was greater than the high end.
    #[error("Invalid port range")]
    InvalidRange,
    /// An address could not be parsed.
    ///
    /// NOTE(review): not produced in this file; presumably raised by the
    /// address-policy parser — confirm.
    #[error("Invalid address")]
    InvalidAddress,
    /// Presumably: a network mask was supplied together with a `*` address
    /// pattern. Not produced in this file — confirm against the
    /// address-policy parser.
    #[error("mask with star")]
    MaskWithStar,
    /// Presumably: a network mask could not be parsed or was out of range.
    /// Not produced in this file — confirm against the address-policy parser.
    #[error("invalid mask")]
    InvalidMask,
    /// A policy was malformed as a whole. In this file it is returned when
    /// port ranges are listed out of order or overlap one another.
    #[error("Invalid policy")]
    InvalidPolicy,
}
55
/// An inclusive range of port numbers, `lo..=hi`.
///
/// The constructors in this file only produce ranges with `1 <= lo <= hi`.
/// A range is written either as a single port such as `80` (when
/// `lo == hi`) or as `LO-HI` such as `8000-9000`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(clippy::exhaustive_structs)]
pub struct PortRange {
    /// Lowest port in the range (inclusive).
    lo: u16,
    /// Highest port in the range (inclusive).
    hi: u16,
}
78
79impl PortRange {
80 fn new_unchecked(lo: u16, hi: u16) -> Self {
83 assert!(lo != 0);
84 assert!(lo <= hi);
85 PortRange { lo, hi }
86 }
87 pub fn new_all() -> Self {
89 PortRange::new_unchecked(1, 65535)
90 }
91 pub fn new(lo: u16, hi: u16) -> Option<Self> {
97 if lo != 0 && lo <= hi {
98 Some(PortRange { lo, hi })
99 } else {
100 None
101 }
102 }
103 pub fn contains(&self, port: u16) -> bool {
105 self.lo <= port && port <= self.hi
106 }
107 pub fn is_all(&self) -> bool {
109 self.lo == 1 && self.hi == 65535
110 }
111
112 fn compare_to_port(&self, port: u16) -> std::cmp::Ordering {
118 use std::cmp::Ordering::*;
119 if port < self.lo {
120 Greater
121 } else if port <= self.hi {
122 Equal
123 } else {
124 Less
125 }
126 }
127}
128
129impl Display for PortRange {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 if self.lo == self.hi {
135 write!(f, "{}", self.lo)
136 } else {
137 write!(f, "{}-{}", self.lo, self.hi)
138 }
139 }
140}
141
142impl FromStr for PortRange {
143 type Err = PolicyError;
144 fn from_str(s: &str) -> Result<Self, PolicyError> {
145 let idx = s.find('-');
146 let (lo, hi) = if let Some(pos) = idx {
148 (
150 s[..pos]
151 .parse::<u16>()
152 .map_err(|_| PolicyError::InvalidPort)?,
153 s[pos + 1..]
154 .parse::<u16>()
155 .map_err(|_| PolicyError::InvalidPort)?,
156 )
157 } else {
158 let v = s.parse::<u16>().map_err(|_| PolicyError::InvalidPort)?;
160 (v, v)
161 };
162 PortRange::new(lo, hi).ok_or(PolicyError::InvalidRange)
163 }
164}
165
// NOTE(review): empty marker impl; presumably this lets `PortRange` be used
// as a netdoc item argument via its `FromStr` impl — confirm against
// `crate::NormalItemArgument`.
impl NormalItemArgument for PortRange {}
167
/// A list of [`PortRange`]s kept in ascending order.
///
/// `push_ordered` rejects entries that overlap or precede the last stored
/// range, and merges entries that begin right after it. The stored ranges
/// therefore stay sorted and never touch one another.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
struct PortRanges(Vec<PortRange>);
179
180impl PortRanges {
181 fn new() -> Self {
183 Self(Vec::new())
184 }
185
186 fn is_empty(&self) -> bool {
188 self.0.is_empty()
189 }
190
191 fn push_ordered(&mut self, item: PortRange) -> Result<(), PolicyError> {
197 if let Some(prev) = self.0.last() {
198 if prev.hi >= item.lo {
201 return Err(PolicyError::InvalidPolicy);
202 } else if prev.hi == item.lo - 1 {
203 let r = PortRange::new_unchecked(prev.lo, item.hi);
205 self.0.pop();
206 self.0.push(r);
207 return Ok(());
208 }
209 }
210
211 self.0.push(item);
212 Ok(())
213 }
214
215 fn contains(&self, port: u16) -> bool {
221 debug_assert!(self.0.is_sorted_by(|a, b| a.lo < b.lo));
222 self.0
223 .binary_search_by(|range| range.compare_to_port(port))
224 .is_ok()
225 }
226
227 fn invert(&mut self) {
231 let mut prev_hi = 0;
232 let mut new_allowed = Vec::new();
233 for entry in &self.0 {
234 if entry.lo > prev_hi + 1 {
237 new_allowed.push(PortRange::new_unchecked(prev_hi + 1, entry.lo - 1));
238 }
239 prev_hi = entry.hi;
240 }
241 if prev_hi < 65535 {
242 new_allowed.push(PortRange::new_unchecked(prev_hi + 1, 65535));
243 }
244 self.0 = new_allowed;
245 }
246
247 fn iter(&self) -> impl Iterator<Item = &PortRange> {
249 self.0.iter()
250 }
251}
252
253impl FromIterator<u16> for PortRanges {
254 fn from_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
255 let ports = iter.into_iter().collect::<BTreeSet<_>>();
257 let mut ports = ports.into_iter().peekable();
258
259 let mut out = Self::new();
260 let mut current_min = None;
261 while let Some(port) = ports.next() {
262 if current_min.is_none() {
263 current_min = Some(port);
264 }
265 if let Some(next_port) = ports.peek().copied() {
266 if next_port != port + 1 {
271 let _ = out.push_ordered(PortRange::new_unchecked(
272 current_min.expect("Don't have min port number"),
273 port,
274 ));
275 current_min = None;
276 }
277 } else {
278 let _ = out.push_ordered(PortRange::new_unchecked(
279 current_min.expect("Don't have min port number"),
280 port,
281 ));
282 }
283 }
284
285 out
286 }
287}
288
289impl FromStr for PortRanges {
295 type Err = PolicyError;
296
297 fn from_str(s: &str) -> Result<Self, Self::Err> {
298 let mut ranges = Self::new();
301 for range in s.split(',') {
302 ranges.push_ordered(range.parse()?)?;
303 }
304 Ok(ranges)
305 }
306}
307
308impl ItemArgumentParseable for PortRanges {
309 fn from_args<'s>(args: &mut ArgumentStream<'s>) -> Result<Self, ArgumentError> {
312 args.next()
313 .map(Self::from_str)
314 .unwrap_or(Ok(Self::new()))
315 .map_err(|_| ArgumentError::Invalid)
316 }
317}
318
/// The kind of a policy rule: `accept` or `reject`.
///
/// Displayed and parsed using the lowercase variant name.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display, derive_more::FromStr)]
#[display(rename_all = "lowercase")]
#[from_str(rename_all = "lowercase")]
#[allow(clippy::exhaustive_enums)]
pub enum RuleKind {
    /// An accepting rule; written as `accept`.
    Accept,
    /// A rejecting rule; written as `reject`.
    Reject,
}
331
// NOTE(review): empty marker impl; presumably this lets `RuleKind` be used
// as a netdoc item argument via its derived `FromStr` — confirm against
// `crate::NormalItemArgument`.
impl NormalItemArgument for RuleKind {}
333
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use crate::Result;
    use crate::parse2::{self, ParseInput};

    #[test]
    fn parse_portrange() -> Result<()> {
        assert_eq!(
            "1-100".parse::<PortRange>()?,
            PortRange::new(1, 100).unwrap()
        );
        assert_eq!(
            "01-100".parse::<PortRange>()?,
            PortRange::new(1, 100).unwrap()
        );
        assert_eq!("1-65535".parse::<PortRange>()?, PortRange::new_all());
        assert_eq!(
            "10-30".parse::<PortRange>()?,
            PortRange::new(10, 30).unwrap()
        );
        assert_eq!(
            "9001".parse::<PortRange>()?,
            PortRange::new(9001, 9001).unwrap()
        );
        assert_eq!(
            "9001-9001".parse::<PortRange>()?,
            PortRange::new(9001, 9001).unwrap()
        );

        assert!("hello".parse::<PortRange>().is_err());
        assert!("0".parse::<PortRange>().is_err());
        assert!("65536".parse::<PortRange>().is_err());
        assert!("65537".parse::<PortRange>().is_err());
        assert!("1-2-3".parse::<PortRange>().is_err());
        assert!("10-5".parse::<PortRange>().is_err());
        assert!("1-".parse::<PortRange>().is_err());
        assert!("-2".parse::<PortRange>().is_err());
        assert!("-".parse::<PortRange>().is_err());
        assert!("*".parse::<PortRange>().is_err());
        Ok(())
    }

    #[test]
    fn pr_manip() {
        assert!(PortRange::new_all().is_all());
        assert!(!PortRange::new(2, 65535).unwrap().is_all());

        assert!(PortRange::new_all().contains(1));
        assert!(PortRange::new_all().contains(65535));
        assert!(PortRange::new_all().contains(7777));

        assert!(PortRange::new(20, 30).unwrap().contains(20));
        assert!(PortRange::new(20, 30).unwrap().contains(25));
        assert!(PortRange::new(20, 30).unwrap().contains(30));
        assert!(!PortRange::new(20, 30).unwrap().contains(19));
        assert!(!PortRange::new(20, 30).unwrap().contains(31));

        use std::cmp::Ordering::*;
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(7), Greater);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(20), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(25), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(30), Equal);
        assert_eq!(PortRange::new(20, 30).unwrap().compare_to_port(100), Less);
    }

    #[test]
    fn pr_fmt() {
        fn chk(a: u16, b: u16, s: &str) {
            let pr = PortRange::new(a, b).unwrap();
            assert_eq!(format!("{}", pr), s);
        }

        chk(1, 65535, "1-65535");
        chk(10, 20, "10-20");
        chk(20, 20, "20");
    }

    #[test]
    fn port_ranges() {
        const INPUT: &str = "22,80,443,8000-9000,9002";
        let ranges = PortRanges::from_str(INPUT).unwrap();
        assert_eq!(
            ranges.0,
            [
                PortRange::new(22, 22).unwrap(),
                PortRange::new(80, 80).unwrap(),
                PortRange::new(443, 443).unwrap(),
                PortRange::new(8000, 9000).unwrap(),
                PortRange::new(9002, 9002).unwrap(),
            ]
        );
        assert!(ranges.contains(22));
        assert!(ranges.contains(80));
        assert!(ranges.contains(443));
        assert!(ranges.contains(8000));
        assert!(ranges.contains(8500));
        assert!(ranges.contains(9000));
        assert!(!ranges.contains(9001));
        assert!(ranges.contains(9002));

        let mut ranges_inverse = ranges.clone();
        ranges_inverse.invert();
        assert_eq!(
            ranges_inverse.0,
            [
                PortRange::new(1, 21).unwrap(),
                PortRange::new(23, 79).unwrap(),
                PortRange::new(81, 442).unwrap(),
                PortRange::new(444, 7999).unwrap(),
                PortRange::new(9001, 9001).unwrap(),
                PortRange::new(9003, 65535).unwrap(),
            ]
        );

        #[derive(derive_deftly::Deftly)]
        #[derive_deftly(NetdocParseable)]
        struct Dummy {
            #[deftly(netdoc(single_arg))]
            dummy: PortRanges,
        }
        let ranges2 =
            parse2::parse_netdoc::<Dummy>(&ParseInput::new(&format!("dummy {INPUT}\n"), ""))
                .unwrap();
        assert_eq!(ranges, ranges2.dummy);
    }

    /// Cover merging, rejection, and inversion edge cases that the tests
    /// above do not exercise.
    #[test]
    fn port_ranges_edge_cases() {
        // Adjacent entries are merged as they are parsed.
        let merged = PortRanges::from_str("1-10,11-20,22").unwrap();
        assert_eq!(
            merged.0,
            [PortRange::new(1, 20).unwrap(), PortRange::new(22, 22).unwrap()]
        );
        assert_eq!(merged.iter().count(), 2);

        // Out-of-order, overlapping, and duplicate entries are rejected.
        for bad in ["80,22", "1-10,10-20", "22,22", "1-100,50"] {
            assert_eq!(
                PortRanges::from_str(bad),
                Err(PolicyError::InvalidPolicy),
                "{bad}"
            );
        }
        // Errors from individual entries are passed through unchanged.
        assert_eq!(PortRanges::from_str("22,"), Err(PolicyError::InvalidPort));
        assert_eq!(PortRanges::from_str("0"), Err(PolicyError::InvalidRange));

        // Inverting the empty list yields every port, and inverting again
        // yields the empty list.
        let mut ranges = PortRanges::new();
        assert!(ranges.is_empty());
        assert!(!ranges.contains(1));
        ranges.invert();
        assert_eq!(ranges.0, [PortRange::new_all()]);
        assert!(ranges.contains(1) && ranges.contains(65535));
        ranges.invert();
        assert!(ranges.is_empty());

        // Ranges touching both ends of the port space.
        let mut edges = PortRanges::from_str("1,65535").unwrap();
        edges.invert();
        assert_eq!(edges.0, [PortRange::new(2, 65534).unwrap()]);

        // Collecting ports sorts, deduplicates, and coalesces them.
        let collected: PortRanges = [443_u16, 82, 80, 22, 81, 80].into_iter().collect();
        assert_eq!(
            collected.0,
            [
                PortRange::new(22, 22).unwrap(),
                PortRange::new(80, 82).unwrap(),
                PortRange::new(443, 443).unwrap(),
            ]
        );
    }
}