1use crate::{BytesRef, Decode, EncodingRules, Error, ErrorKind, Length, Reader};
4
5#[derive(Clone, Debug)]
7pub struct SliceReader<'a> {
8 bytes: &'a BytesRef,
10
11 encoding_rules: EncodingRules,
13
14 failed: bool,
16
17 position: Length,
19}
20
21impl<'a> SliceReader<'a> {
22 pub fn new(bytes: &'a [u8]) -> Result<Self, Error> {
27 Self::new_with_encoding_rules(bytes, EncodingRules::default())
28 }
29
30 pub fn new_with_encoding_rules(
35 bytes: &'a [u8],
36 encoding_rules: EncodingRules,
37 ) -> Result<Self, Error> {
38 Ok(Self {
39 bytes: BytesRef::new(bytes)?,
40 encoding_rules,
41 failed: false,
42 position: Length::ZERO,
43 })
44 }
45
46 pub fn error(&mut self, kind: ErrorKind) -> Error {
49 self.failed = true;
50 kind.at(self.position)
51 }
52
53 #[must_use]
55 pub fn is_failed(&self) -> bool {
56 self.failed
57 }
58
59 pub(crate) fn remaining(&self) -> Result<&'a [u8], Error> {
62 if self.is_failed() {
63 Err(ErrorKind::Failed.at(self.position))
64 } else {
65 self.bytes
66 .as_slice()
67 .get(self.position.try_into()?..)
68 .ok_or_else(|| Error::incomplete(self.input_len()))
69 }
70 }
71 pub(crate) fn new_nested_reader(&mut self, len: Length) -> Result<Self, Error> {
73 let prefix_len = (self.position + len)?;
74 let mut nested_reader = self.clone();
75 nested_reader.bytes = self.bytes.prefix(prefix_len)?;
76 Ok(nested_reader)
77 }
78}
79
80impl<'a> Reader<'a> for SliceReader<'a> {
81 const CAN_READ_SLICE: bool = true;
82
83 fn encoding_rules(&self) -> EncodingRules {
84 self.encoding_rules
85 }
86
87 fn input_len(&self) -> Length {
88 self.bytes.len()
89 }
90
91 fn position(&self) -> Length {
92 self.position
93 }
94
95 fn read_nested<T, F, E>(&mut self, len: Length, f: F) -> Result<T, E>
97 where
98 F: FnOnce(&mut Self) -> Result<T, E>,
99 E: From<Error>,
100 {
101 let mut nested_reader = self.new_nested_reader(len)?;
102 let ret = f(&mut nested_reader);
103 self.position = nested_reader.position;
104 self.failed = nested_reader.failed;
105
106 match ret {
107 Ok(value) => {
108 nested_reader.finish().inspect_err(|_e| {
109 self.failed = true;
110 })?;
111 Ok(value)
112 }
113 Err(err) => Err(err),
114 }
115 }
116
117 fn read_slice(&mut self, len: Length) -> Result<&'a [u8], Error> {
118 if self.is_failed() {
119 return Err(self.error(ErrorKind::Failed));
120 }
121
122 match self.remaining()?.get(..len.try_into()?) {
123 Some(result) => {
124 self.position = (self.position + len)?;
125 Ok(result)
126 }
127 None => Err(self.error(ErrorKind::Incomplete {
128 expected_len: (self.position + len)?,
129 actual_len: self.input_len(),
130 })),
131 }
132 }
133
134 fn decode<T: Decode<'a>>(&mut self) -> Result<T, T::Error> {
135 if self.is_failed() {
136 return Err(self.error(ErrorKind::Failed).into());
137 }
138
139 T::decode(self).inspect_err(|_| {
140 self.failed = true;
141 })
142 }
143
144 fn error(&mut self, kind: ErrorKind) -> Error {
145 self.failed = true;
146 kind.at(self.position)
147 }
148
149 fn finish(self) -> Result<(), Error> {
150 if self.is_failed() {
151 Err(ErrorKind::Failed.at(self.position))
152 } else if !self.is_finished() {
153 Err(ErrorKind::TrailingData {
154 decoded: self.position,
155 remaining: self.remaining_len(),
156 }
157 .at(self.position))
158 } else {
159 Ok(())
160 }
161 }
162
163 fn remaining_len(&self) -> Length {
164 debug_assert!(self.position <= self.input_len());
165 self.input_len().saturating_sub(self.position)
166 }
167}
168
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::SliceReader;
    use crate::{Decode, ErrorKind, Length, Reader};
    use hex_literal::hex;

    /// `INTEGER 42` (tag `02`, length `01`, value `2A`) followed by one
    /// trailing zero byte.
    const EXAMPLE_MSG: &[u8] = &hex!("02012A00");

    #[test]
    fn empty_message() {
        let mut reader = SliceReader::new(&[]).unwrap();
        let err = bool::decode(&mut reader).unwrap_err();
        assert_eq!(err.position(), Some(Length::ZERO));

        let ErrorKind::Incomplete {
            expected_len,
            actual_len,
        } = err.kind()
        else {
            panic!("unexpected error kind: {:?}", err.kind());
        };
        assert_eq!(actual_len, Length::from(0u8));
        assert_eq!(expected_len, Length::from(1u8));
    }

    #[test]
    fn invalid_field_length() {
        const MSG_LEN: usize = 2;

        let truncated = &EXAMPLE_MSG[..MSG_LEN];
        let mut reader = SliceReader::new(truncated).unwrap();
        let err = i8::decode(&mut reader).unwrap_err();
        assert_eq!(err.position(), Some(Length::from(2u8)));

        let ErrorKind::Incomplete {
            expected_len,
            actual_len,
        } = err.kind()
        else {
            panic!("unexpected error kind: {:?}", err.kind());
        };
        assert_eq!(actual_len, Length::try_from(MSG_LEN).unwrap());
        assert_eq!(expected_len, Length::try_from(MSG_LEN + 1).unwrap());
    }

    #[test]
    fn trailing_data() {
        let mut reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        let value = i8::decode(&mut reader).unwrap();
        assert_eq!(value, 42i8);

        let err = reader.finish().unwrap_err();
        assert_eq!(err.position(), Some(Length::from(3u8)));

        let expected = ErrorKind::TrailingData {
            decoded: Length::from(3u8),
            remaining: Length::from(1u8),
        };
        assert_eq!(err.kind(), expected);
    }
}
234}