1use alloc::vec::Vec;
9use core::{
10 marker::PhantomData,
11 ops::{Deref, DerefMut},
12};
13
14use super::BinEncodable;
15use crate::error::{ProtoError, ProtoResult};
16
17mod private {
19 use alloc::vec::Vec;
20
21 use crate::{ProtoError, error::ProtoResult};
22
23 pub(super) struct MaximalBuf<'a> {
25 max_size: usize,
26 buffer: &'a mut Vec<u8>,
27 }
28
29 impl<'a> MaximalBuf<'a> {
30 pub(super) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
31 MaximalBuf {
32 max_size: max_size as usize,
33 buffer,
34 }
35 }
36
37 pub(super) fn set_max_size(&mut self, max: u16) {
39 self.max_size = max as usize;
40 }
41
42 pub(super) fn write(&mut self, offset: usize, data: &[u8]) -> ProtoResult<()> {
43 debug_assert!(offset <= self.buffer.len());
44 if offset + data.len() > self.max_size {
45 return Err(ProtoError::MaxBufferSizeExceeded(self.max_size));
46 }
47
48 if offset == self.buffer.len() {
49 self.buffer.extend(data);
50 return Ok(());
51 }
52
53 let end = offset + data.len();
54 if end > self.buffer.len() {
55 self.buffer.resize(end, 0);
56 }
57
58 self.buffer[offset..end].copy_from_slice(data);
59 Ok(())
60 }
61
62 pub(super) fn reserve(&mut self, offset: usize, len: usize) -> ProtoResult<()> {
63 let end = offset + len;
64 if end > self.max_size {
65 return Err(ProtoError::MaxBufferSizeExceeded(self.max_size));
66 }
67
68 self.buffer.resize(end, 0);
69 Ok(())
70 }
71
72 pub(super) fn truncate(&mut self, len: usize) {
74 self.buffer.truncate(len)
75 }
76
77 pub(super) fn len(&self) -> usize {
79 self.buffer.len()
80 }
81
82 pub(super) fn buffer(&'a self) -> &'a [u8] {
84 self.buffer as &'a [u8]
85 }
86
87 pub(super) fn into_bytes(self) -> &'a Vec<u8> {
89 self.buffer
90 }
91 }
92}
93
94pub struct BinEncoder<'a> {
96 offset: usize,
97 buffer: private::MaximalBuf<'a>,
98 name_pointers: Vec<(usize, Vec<u8>)>,
100 canonical_form: bool,
102 name_encoding: NameEncoding,
104 pub(crate) compressed_name_count: usize,
106}
107
108impl<'a> BinEncoder<'a> {
109 pub fn new(buf: &'a mut Vec<u8>) -> Self {
111 Self::with_offset(buf, 0)
112 }
113
114 pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32) -> Self {
124 if buf.capacity() < 512 {
125 let reserve = 512 - buf.capacity();
126 buf.reserve(reserve);
127 }
128
129 BinEncoder {
130 offset: offset as usize,
131 buffer: private::MaximalBuf::new(u16::MAX, buf),
133 name_pointers: Vec::new(),
134 canonical_form: false,
135 name_encoding: NameEncoding::Compressed,
136 compressed_name_count: 0,
137 }
138 }
139
140 pub fn set_max_size(&mut self, max: u16) {
147 self.buffer.set_max_size(max);
148 }
149
150 pub fn into_bytes(self) -> &'a Vec<u8> {
152 self.buffer.into_bytes()
153 }
154
155 pub fn len(&self) -> usize {
157 self.buffer.len()
158 }
159
160 pub fn is_empty(&self) -> bool {
162 self.buffer.buffer().is_empty()
163 }
164
165 pub fn offset(&self) -> usize {
167 self.offset
168 }
169
170 pub fn set_offset(&mut self, offset: usize) {
172 self.offset = offset;
173 }
174
175 pub fn set_canonical_form(&mut self, canonical_form: bool) {
177 self.canonical_form = canonical_form;
178 }
179
180 pub fn is_canonical_form(&self) -> bool {
182 self.canonical_form
183 }
184
185 pub fn set_name_encoding(&mut self, name_encoding: NameEncoding) {
187 self.name_encoding = name_encoding;
188 }
189
190 pub fn name_encoding(&self) -> NameEncoding {
192 self.name_encoding
193 }
194
195 pub fn with_name_encoding<'e>(
197 &'e mut self,
198 name_encoding: NameEncoding,
199 ) -> ModalEncoder<'a, 'e> {
200 let previous_name_encoding = self.name_encoding();
201
202 self.set_name_encoding(name_encoding);
203
204 ModalEncoder {
205 previous_name_encoding,
206 inner: self,
207 }
208 }
209
210 pub fn with_rdata_behavior<'e>(
218 &'e mut self,
219 rdata_encoding: RDataEncoding,
220 ) -> ModalEncoder<'a, 'e> {
221 let previous_name_encoding = self.name_encoding();
222
223 match (rdata_encoding, self.is_canonical_form()) {
224 (RDataEncoding::StandardRecord, true) | (RDataEncoding::Canonical, true) => {
225 self.set_name_encoding(NameEncoding::UncompressedLowercase)
226 }
227 (RDataEncoding::StandardRecord, false) => {}
228 (RDataEncoding::Canonical, false)
229 | (RDataEncoding::Other, true)
230 | (RDataEncoding::Other, false) => self.set_name_encoding(NameEncoding::Uncompressed),
231 }
232
233 ModalEncoder {
234 previous_name_encoding,
235 inner: self,
236 }
237 }
238
239 pub fn trim(&mut self) {
241 let offset = self.offset;
242 self.buffer.truncate(offset);
243 self.name_pointers.retain(|&(start, _)| start < offset);
244 }
245
246 pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
260 assert!(start < self.offset);
261 assert!(end <= self.buffer.len());
262 &self.buffer.buffer()[start..end]
263 }
264
265 pub fn store_label_pointer(&mut self, start: usize, end: usize) {
270 assert!(start <= (u16::MAX as usize));
271 assert!(end <= (u16::MAX as usize));
272 assert!(start <= end);
273 if self.offset < 0x3FFF_usize && self.name_pointers.len() < COMPRESSION_CANDIDATE_LIMIT {
274 self.name_pointers
275 .push((start, self.slice_of(start, end).to_vec())); }
277 }
278
279 pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
281 let search = self.slice_of(start, end);
282
283 for (match_start, matcher) in &self.name_pointers {
284 if matcher.as_slice() == search {
285 assert!(match_start <= &(u16::MAX as usize));
286 return Some(*match_start as u16);
287 }
288 }
289
290 None
291 }
292
293 pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
295 self.buffer.write(self.offset, &[b])?;
296 self.offset += 1;
297 Ok(())
298 }
299
300 pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
313 let char_bytes = char_data.as_ref();
314 if char_bytes.len() > 255 {
315 return Err(ProtoError::CharacterDataTooLong {
316 max: 255,
317 len: char_bytes.len(),
318 });
319 }
320
321 self.emit_character_data_unrestricted(char_data)
322 }
323
324 pub fn emit_character_data_unrestricted<S: AsRef<[u8]>>(&mut self, data: S) -> ProtoResult<()> {
329 let data = data.as_ref();
331 self.emit(data.len() as u8)?;
332 self.write_slice(data)
333 }
334
335 pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
337 self.emit(data)
338 }
339
340 pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
342 self.write_slice(&data.to_be_bytes())
343 }
344
345 pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
347 self.write_slice(&data.to_be_bytes())
348 }
349
350 pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
352 self.write_slice(&data.to_be_bytes())
353 }
354
355 fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
356 self.buffer.write(self.offset, data)?;
357 self.offset += data.len();
358 Ok(())
359 }
360
361 pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
363 self.write_slice(data)
364 }
365
366 pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
368 &mut self,
369 mut iter: I,
370 ) -> ProtoResult<usize> {
371 self.emit_iter(&mut iter)
372 }
373
374 pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
377 where
378 'e: 'r,
379 I: Iterator<Item = &'r &'e E>,
380 E: 'r + 'e + BinEncodable,
381 {
382 let mut iter = iter.cloned();
383 self.emit_iter(&mut iter)
384 }
385
386 pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
388 &mut self,
389 iter: &mut I,
390 ) -> ProtoResult<usize> {
391 let mut count = 0;
392 for i in iter {
393 let rollback = self.set_rollback();
394 if let Err(e) = i.emit(self) {
395 return Err(match &e {
396 ProtoError::MaxBufferSizeExceeded(_) => {
397 rollback.rollback(self);
398 ProtoError::NotAllRecordsWritten { count }
399 }
400 _ => e,
401 });
402 }
403
404 count += 1;
405 }
406 Ok(count)
407 }
408
409 pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
411 let index = self.offset;
412
413 self.buffer.reserve(self.offset, T::LEN)?;
415
416 self.offset += T::LEN;
418
419 Ok(Place {
420 start_index: index,
421 phantom: PhantomData,
422 })
423 }
424
425 pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
427 (self.offset - place.start_index) - T::LEN
428 }
429
430 pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
432 let current_index = self.offset;
434
435 assert!(place.start_index < current_index);
438 self.offset = place.start_index;
439
440 let emit_result = data.emit(self);
442
443 assert!((self.offset - place.start_index) == T::LEN);
446
447 self.offset = current_index;
449
450 emit_result
451 }
452
453 fn set_rollback(&self) -> Rollback {
454 Rollback {
455 offset: self.offset(),
456 pointers: self.name_pointers.len(),
457 }
458 }
459}
460
/// Upper bound on how many label sequences are remembered as compression targets.
/// Once it is reached, `BinEncoder::store_label_pointer` stops recording new ones.
const COMPRESSION_CANDIDATE_LIMIT: usize = 0x40;
467
/// Selects how the encoder writes names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NameEncoding {
    /// Names may be compressed by referring back to earlier label sequences.
    Compressed,
    /// Names are written out in full.
    Uncompressed,
    /// Names are written out in full, lowercased (the lowercasing itself is presumably done
    /// by the name-emitting code elsewhere in the crate).
    UncompressedLowercase,
}
478
/// Describes the kind of record data being written, which `BinEncoder::with_rdata_behavior`
/// uses (together with the canonical-form flag) to pick a name encoding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RDataEncoding {
    /// Keeps the current name encoding unless canonical form is on, in which case names are
    /// written uncompressed and lowercased.
    StandardRecord,
    /// Names are written uncompressed; they are also lowercased when canonical form is on.
    Canonical,
    /// Names are always written uncompressed, regardless of canonical form.
    Other,
}
506
507pub struct ModalEncoder<'a, 'e> {
511 previous_name_encoding: NameEncoding,
512 inner: &'e mut BinEncoder<'a>,
513}
514
515impl<'a> Deref for ModalEncoder<'a, '_> {
516 type Target = BinEncoder<'a>;
517
518 fn deref(&self) -> &Self::Target {
519 self.inner
520 }
521}
522
523impl DerefMut for ModalEncoder<'_, '_> {
524 fn deref_mut(&mut self) -> &mut Self::Target {
525 self.inner
526 }
527}
528
529impl Drop for ModalEncoder<'_, '_> {
530 fn drop(&mut self) {
531 self.inner.set_name_encoding(self.previous_name_encoding);
532 }
533}
534
535pub trait EncodedSize: BinEncodable {
539 const LEN: usize;
541}
542
543impl EncodedSize for u16 {
544 const LEN: usize = 2;
545}
546
/// A slot of `T::LEN` bytes reserved by `BinEncoder::place`, to be filled in later with
/// `Place::replace` (or `BinEncoder::emit_at`).
#[derive(Debug)]
#[must_use = "data must be written back to the place"]
pub struct Place<T: EncodedSize> {
    // Offset of the first reserved byte.
    start_index: usize,
    // Ties the slot to the type that will be written into it.
    phantom: PhantomData<T>,
}
554
555impl<T: EncodedSize> Place<T> {
556 pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
558 encoder.emit_at(self, data)
559 }
560}
561
/// Snapshot of an encoder's cursor and compression-table size. It is used to undo a
/// partially written item.
pub(crate) struct Rollback {
    /// Number of compression candidates at snapshot time.
    pointers: usize,
    /// Encoder offset at snapshot time.
    offset: usize,
}
567
568impl Rollback {
569 pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
570 let Self { offset, pointers } = self;
571 encoder.set_offset(offset);
572 encoder.name_pointers.truncate(pointers);
573 }
574}
575
#[cfg(test)]
mod tests {
    #[cfg(any(feature = "std", feature = "no-std-rand"))]
    use core::str::FromStr;

    use super::*;
    use crate::{op::Message, serialize::binary::BinDecoder};
    #[cfg(any(feature = "std", feature = "no-std-rand"))]
    use crate::{
        op::Query,
        rr::Name,
        rr::{
            RData, Record, RecordType,
            rdata::{CNAME, SRV},
        },
    };

    /// Fails the current test unless `error` reports an exceeded maximum buffer size.
    fn expect_max_size_exceeded(error: ProtoError) {
        assert!(
            matches!(error, ProtoError::MaxBufferSizeExceeded(_)),
            "unexpected error: {error:?}"
        );
    }

    #[test]
    fn test_label_compression_regression() {
        let data = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        // Decoding and then re-encoding this message must not fail.
        let decoded = Message::from_vec(&data).unwrap();
        let _reencoded = decoded.to_bytes().unwrap();
    }

    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            // `encoder` goes out of scope here, releasing the borrow of `buf`.
        }

        assert_eq!(buf.len(), 4);

        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("cound not read u16").unverified();
        assert_eq!(written, 4);
    }

    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        for byte in 0..5 {
            encoder.emit(byte).expect("failed to write");
        }

        expect_max_size_exceeded(encoder.emit(5).unwrap_err());
    }

    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        expect_max_size_exceeded(encoder.emit(0).unwrap_err());
    }

    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        expect_max_size_exceeded(encoder.place::<u16>().unwrap_err());
    }

    #[cfg(any(feature = "std", feature = "no-std-rand"))]
    #[test]
    fn test_target_compression() {
        let name = |text: &str| Name::from_str(text).unwrap();
        let srv_to_compressme = || RData::SRV(SRV::new(0, 0, 0, name("www.compressme.com.")));

        let mut msg = Message::query();
        msg.add_query(Query::query(name("www.google.com."), RecordType::A))
            .add_answer(Record::from_rdata(
                name("www.google.com."),
                0,
                srv_to_compressme(),
            ))
            .add_additional(Record::from_rdata(
                name("www.google.com."),
                0,
                srv_to_compressme(),
            ))
            // The CNAME owner name repeats the SRV target above and should be compressed.
            .add_answer(Record::from_rdata(
                name("www.compressme.com."),
                0,
                RData::CNAME(CNAME(name("www.foo.com."))),
            ));

        let bytes = msg.to_vec().unwrap();
        assert_eq!(bytes.len(), 130);
        assert!(Message::from_vec(&bytes).is_ok());
    }
}