Skip to main content

der/writer/
slice.rs

1//! Slice writer.
2
3use crate::{
4    Encode, EncodeValue, ErrorKind, Header, Length, Result, Tag, TagMode, TagNumber, Tagged,
5    Writer, asn1::*,
6};
7
8/// [`Writer`] which encodes DER into a mutable output byte slice.
9#[derive(Debug)]
10pub struct SliceWriter<'a> {
11    /// Buffer into which DER-encoded message is written
12    bytes: &'a mut [u8],
13
14    /// Has the encoding operation failed?
15    failed: bool,
16
17    /// Total number of bytes written to buffer so far
18    position: Length,
19}
20
21impl<'a> SliceWriter<'a> {
22    /// Create a new writer with the given byte slice as a backing buffer.
23    pub fn new(bytes: &'a mut [u8]) -> Self {
24        Self {
25            bytes,
26            failed: false,
27            position: Length::ZERO,
28        }
29    }
30
31    /// Encode a value which impls the [`Encode`] trait.
32    ///
33    /// # Errors
34    /// Returns an error if encoding failed.
35    pub fn encode<T: Encode>(&mut self, encodable: &T) -> Result<()> {
36        if self.is_failed() {
37            self.error(ErrorKind::Failed)?;
38        }
39
40        encodable.encode(self).map_err(|e| {
41            self.failed = true;
42            e.nested(self.position)
43        })
44    }
45
46    /// Return an error with the given [`ErrorKind`], annotating it with
47    /// context about where the error occurred.
48    ///
49    /// # Errors
50    /// This function is designed to generate errors.
51    pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
52        self.failed = true;
53        Err(kind.at(self.position))
54    }
55
56    /// Did the decoding operation fail due to an error?
57    #[must_use]
58    pub fn is_failed(&self) -> bool {
59        self.failed
60    }
61
62    /// Finish encoding to the buffer, returning a slice containing the data
63    /// written to the buffer.
64    ///
65    /// # Errors
66    /// If we're overlength, or writing already failed.
67    pub fn finish(self) -> Result<&'a [u8]> {
68        let position = self.position;
69
70        if self.is_failed() {
71            return Err(ErrorKind::Failed.at(position));
72        }
73
74        self.bytes
75            .get(..usize::try_from(position)?)
76            .ok_or_else(|| ErrorKind::Overlength.at(position))
77    }
78
79    /// Encode a `CONTEXT-SPECIFIC` field with the provided tag number and mode.
80    ///
81    /// # Errors
82    /// If an encoding error occurred.
83    pub fn context_specific<T>(
84        &mut self,
85        tag_number: TagNumber,
86        tag_mode: TagMode,
87        value: &T,
88    ) -> Result<()>
89    where
90        T: EncodeValue + Tagged,
91    {
92        ContextSpecificRef {
93            tag_number,
94            tag_mode,
95            value,
96        }
97        .encode(self)
98    }
99
100    /// Encode an ASN.1 `SEQUENCE` of the given length.
101    ///
102    /// Spawns a nested slice writer which is expected to be exactly the
103    /// specified length upon completion.
104    ///
105    /// # Errors
106    /// If an encoding error occurred.
107    pub fn sequence<F>(&mut self, length: Length, f: F) -> Result<()>
108    where
109        F: FnOnce(&mut SliceWriter<'_>) -> Result<()>,
110    {
111        Header::new(Tag::Sequence, length).encode(self)?;
112
113        let mut nested_writer = SliceWriter::new(self.reserve(length)?);
114        f(&mut nested_writer)?;
115
116        if nested_writer.finish()?.len() == usize::try_from(length)? {
117            Ok(())
118        } else {
119            self.error(ErrorKind::Length { tag: Tag::Sequence })
120        }
121    }
122
123    /// Reserve a portion of the internal buffer, updating the internal cursor
124    /// position and returning a mutable slice.
125    fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
126        if self.is_failed() {
127            return Err(ErrorKind::Failed.at(self.position));
128        }
129
130        let len = len
131            .try_into()
132            .or_else(|_| self.error(ErrorKind::Overflow))?;
133
134        let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
135        let slice = self
136            .bytes
137            .get_mut(self.position.try_into()?..end.try_into()?)
138            .ok_or_else(|| ErrorKind::Overlength.at(end))?;
139
140        self.position = end;
141        Ok(slice)
142    }
143}
144
145impl Writer for SliceWriter<'_> {
146    fn write(&mut self, slice: &[u8]) -> Result<()> {
147        self.reserve(slice.len())?.copy_from_slice(slice);
148        Ok(())
149    }
150}
151
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::SliceWriter;
    use crate::{Encode, ErrorKind, Length};

    /// Encoding into a zero-length buffer must report `Overlength` at the end
    /// of the first byte it tried to reserve.
    #[test]
    fn overlength_message() {
        let mut storage = [];
        let mut slice_writer = SliceWriter::new(&mut storage);

        let error = false
            .encode(&mut slice_writer)
            .expect_err("encoding into an empty buffer must fail");

        assert_eq!(error.kind(), ErrorKind::Overlength);
        assert_eq!(error.position(), Some(Length::ONE));
    }
}