Skip to main content

der/reader/
slice.rs

1//! Slice reader.
2
3use crate::{BytesRef, Decode, EncodingRules, Error, ErrorKind, Length, Reader};
4
/// [`Reader`] which consumes an input byte slice.
///
/// Reading advances `position` through `bytes`. Once any operation records a
/// failure, later reads, decodes and `finish` return [`ErrorKind::Failed`]
/// instead of touching the input.
///
/// `Clone` is relied upon to create nested readers that share this reader's
/// state but see a truncated view of the input.
#[derive(Clone, Debug)]
pub struct SliceReader<'a> {
    /// Byte slice being decoded.
    bytes: &'a BytesRef,

    /// Encoding rules to apply when decoding the input.
    encoding_rules: EncodingRules,

    /// Did the decoding operation fail?
    ///
    /// Checked at the start of reads/decodes to short-circuit with
    /// [`ErrorKind::Failed`].
    failed: bool,

    /// Position within the decoded slice: the byte offset into `bytes` of the
    /// next byte to be read.
    position: Length,
}
20
21impl<'a> SliceReader<'a> {
22    /// Create a new slice reader for the given byte slice.
23    ///
24    /// # Errors
25    /// If `bytes` is too long.
26    pub fn new(bytes: &'a [u8]) -> Result<Self, Error> {
27        Self::new_with_encoding_rules(bytes, EncodingRules::default())
28    }
29
30    /// Create a new slice reader with the given encoding rules.
31    ///
32    /// # Errors
33    /// If `bytes` is too long.
34    pub fn new_with_encoding_rules(
35        bytes: &'a [u8],
36        encoding_rules: EncodingRules,
37    ) -> Result<Self, Error> {
38        Ok(Self {
39            bytes: BytesRef::new(bytes)?,
40            encoding_rules,
41            failed: false,
42            position: Length::ZERO,
43        })
44    }
45
46    /// Return an error with the given [`ErrorKind`], annotating it with
47    /// context about where the error occurred.
48    pub fn error(&mut self, kind: ErrorKind) -> Error {
49        self.failed = true;
50        kind.at(self.position)
51    }
52
53    /// Did the decoding operation fail due to an error?
54    #[must_use]
55    pub fn is_failed(&self) -> bool {
56        self.failed
57    }
58
59    /// Obtain the remaining bytes in this slice reader from the current cursor
60    /// position.
61    pub(crate) fn remaining(&self) -> Result<&'a [u8], Error> {
62        if self.is_failed() {
63            Err(ErrorKind::Failed.at(self.position))
64        } else {
65            self.bytes
66                .as_slice()
67                .get(self.position.try_into()?..)
68                .ok_or_else(|| Error::incomplete(self.input_len()))
69        }
70    }
71    /// Creates new [`SliceReader`] without advancing current reader.
72    pub(crate) fn new_nested_reader(&mut self, len: Length) -> Result<Self, Error> {
73        let prefix_len = (self.position + len)?;
74        let mut nested_reader = self.clone();
75        nested_reader.bytes = self.bytes.prefix(prefix_len)?;
76        Ok(nested_reader)
77    }
78}
79
80impl<'a> Reader<'a> for SliceReader<'a> {
81    const CAN_READ_SLICE: bool = true;
82
83    fn encoding_rules(&self) -> EncodingRules {
84        self.encoding_rules
85    }
86
87    fn input_len(&self) -> Length {
88        self.bytes.len()
89    }
90
91    fn position(&self) -> Length {
92        self.position
93    }
94
95    /// Read nested data of the given length.
96    fn read_nested<T, F, E>(&mut self, len: Length, f: F) -> Result<T, E>
97    where
98        F: FnOnce(&mut Self) -> Result<T, E>,
99        E: From<Error>,
100    {
101        let mut nested_reader = self.new_nested_reader(len)?;
102        let ret = f(&mut nested_reader);
103        self.position = nested_reader.position;
104        self.failed = nested_reader.failed;
105
106        match ret {
107            Ok(value) => {
108                nested_reader.finish().inspect_err(|_e| {
109                    self.failed = true;
110                })?;
111                Ok(value)
112            }
113            Err(err) => Err(err),
114        }
115    }
116
117    fn read_slice(&mut self, len: Length) -> Result<&'a [u8], Error> {
118        if self.is_failed() {
119            return Err(self.error(ErrorKind::Failed));
120        }
121
122        match self.remaining()?.get(..len.try_into()?) {
123            Some(result) => {
124                self.position = (self.position + len)?;
125                Ok(result)
126            }
127            None => Err(self.error(ErrorKind::Incomplete {
128                expected_len: (self.position + len)?,
129                actual_len: self.input_len(),
130            })),
131        }
132    }
133
134    fn decode<T: Decode<'a>>(&mut self) -> Result<T, T::Error> {
135        if self.is_failed() {
136            return Err(self.error(ErrorKind::Failed).into());
137        }
138
139        T::decode(self).inspect_err(|_| {
140            self.failed = true;
141        })
142    }
143
144    fn error(&mut self, kind: ErrorKind) -> Error {
145        self.failed = true;
146        kind.at(self.position)
147    }
148
149    fn finish(self) -> Result<(), Error> {
150        if self.is_failed() {
151            Err(ErrorKind::Failed.at(self.position))
152        } else if !self.is_finished() {
153            Err(ErrorKind::TrailingData {
154                decoded: self.position,
155                remaining: self.remaining_len(),
156            }
157            .at(self.position))
158        } else {
159            Ok(())
160        }
161    }
162
163    fn remaining_len(&self) -> Length {
164        debug_assert!(self.position <= self.input_len());
165        self.input_len().saturating_sub(self.position)
166    }
167}
168
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::SliceReader;
    use crate::{Decode, ErrorKind, Length, Reader};
    use hex_literal::hex;

    /// DER `INTEGER 42` (`02 01 2A`) followed by one trailing `00` byte.
    const EXAMPLE_MSG: &[u8] = &hex!("02012A00");

    #[test]
    fn empty_message() {
        let mut reader = SliceReader::new(&[]).unwrap();
        let err = bool::decode(&mut reader).unwrap_err();
        assert_eq!(Some(Length::ZERO), err.position());

        let kind = err.kind();
        let ErrorKind::Incomplete {
            expected_len,
            actual_len,
        } = kind
        else {
            panic!("unexpected error kind: {:?}", kind);
        };
        assert_eq!(actual_len, 0u8.into());
        assert_eq!(expected_len, 1u8.into());
    }

    #[test]
    fn invalid_field_length() {
        // Keep only the tag and length bytes; the value byte is missing.
        const MSG_LEN: usize = 2;

        let truncated = &EXAMPLE_MSG[..MSG_LEN];
        let mut reader = SliceReader::new(truncated).unwrap();
        let err = i8::decode(&mut reader).unwrap_err();
        assert_eq!(Some(Length::from(2u8)), err.position());

        let kind = err.kind();
        let ErrorKind::Incomplete {
            expected_len,
            actual_len,
        } = kind
        else {
            panic!("unexpected error kind: {:?}", kind);
        };
        assert_eq!(actual_len, MSG_LEN.try_into().unwrap());
        assert_eq!(expected_len, (MSG_LEN + 1).try_into().unwrap());
    }

    #[test]
    fn trailing_data() {
        let mut reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        let decoded = i8::decode(&mut reader).unwrap();
        assert_eq!(42i8, decoded);

        let err = reader.finish().unwrap_err();
        assert_eq!(Some(Length::from(3u8)), err.position());

        let expected = ErrorKind::TrailingData {
            decoded: 3u8.into(),
            remaining: 1u8.into(),
        };
        assert_eq!(expected, err.kind());
    }
}