1use digest::Digest;
5use tor_bytes::Reader;
6use tor_cell::chancell::{
7 AnyChanCell, ChanCell, ChanCmd, ChanMsg, codec,
8 msg::{self, AnyChanMsg},
9};
10use tor_error::internal;
11use tor_llcrypto as ll;
12
13use asynchronous_codec as futures_codec;
14use bytes::BytesMut;
15
16use crate::{channel::msg::LinkVersion, util::err::Error as ChanError};
17
18use super::{ChannelType, msg::MessageFilter};
19
/// A finalized SHA-256 digest of a channel's cell log: the bytes sent or
/// received by the `New` and `Handshake` cell handlers.
pub(crate) type AuthLogDigest = [u8; 32];

/// Newtype around an [`AuthLogDigest`] for the CLOG value.
///
/// NOTE(review): presumably tor-spec's CLOG (digest of the cells sent by the
/// channel initiator) — confirm against the AUTHENTICATE code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct ClogDigest(AuthLogDigest);

/// Newtype around an [`AuthLogDigest`] for the SLOG value.
///
/// NOTE(review): presumably tor-spec's SLOG (digest of the cells sent by the
/// channel responder) — confirm against the AUTHENTICATE code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct SlogDigest(AuthLogDigest);
28
29impl ClogDigest {
30 pub(crate) fn new(digest: AuthLogDigest) -> Self {
32 Self(digest)
33 }
34}
35
36impl SlogDigest {
37 pub(crate) fn new(digest: AuthLogDigest) -> Self {
39 Self(digest)
40 }
41}
42
43impl AsRef<[u8]> for ClogDigest {
44 fn as_ref(&self) -> &[u8] {
45 &self.0
46 }
47}
48impl AsRef<[u8]> for SlogDigest {
49 fn as_ref(&self) -> &[u8] {
50 &self.0
51 }
52}
53
/// Cell codec for a channel. It changes behaviour as the channel moves through
/// its lifecycle.
///
/// A handler starts in [`New`](Self::New) (see the `From<ChannelType>` impl).
/// [`set_link_version`](Self::set_link_version) moves it to
/// [`Handshake`](Self::Handshake), and [`set_open`](Self::set_open) then moves
/// it to [`Open`](Self::Open).
pub(crate) enum ChannelCellHandler {
    /// No link protocol version negotiated yet: only VERSIONS cells can be
    /// encoded or decoded.
    New(NewChannelHandler),
    /// Link version negotiated. Cells go through a handshake-stage
    /// `MessageFilter`, and the send/receive logs (if any) keep accumulating.
    Handshake(HandshakeChannelHandler),
    /// Handshake finished. Cells go through an open-stage `MessageFilter`.
    Open(OpenChannelHandler),
}
71
72impl From<super::ChannelType> for ChannelCellHandler {
75 fn from(ty: ChannelType) -> Self {
76 Self::New(ty.into())
77 }
78}
79
80impl ChannelCellHandler {
81 pub(crate) fn channel_type(&self) -> ChannelType {
83 match self {
84 Self::New(h) => h.channel_type,
85 Self::Handshake(h) => h.channel_type(),
86 Self::Open(h) => h.channel_type(),
87 }
88 }
89
90 pub(crate) fn set_link_version(&mut self, link_version: u16) -> Result<(), ChanError> {
96 let Self::New(new_handler) = self else {
97 return Err(ChanError::Bug(internal!(
98 "Setting link protocol without a new handler",
99 )));
100 };
101 *self = Self::Handshake(new_handler.next_handler(link_version.try_into()?));
102 Ok(())
103 }
104
105 pub(crate) fn set_open(&mut self) -> Result<(), ChanError> {
109 let Self::Handshake(handler) = self else {
110 return Err(ChanError::Bug(internal!(
111 "Setting open without a handshake handler"
112 )));
113 };
114 *self = Self::Open(handler.next_handler());
115 Ok(())
116 }
117
118 pub(crate) fn set_authenticated(&mut self) -> Result<(), ChanError> {
123 let Self::Handshake(handler) = self else {
124 return Err(ChanError::Bug(internal!(
125 "Setting authenticated without a handshake handler"
126 )));
127 };
128 handler.set_authenticated();
129 Ok(())
130 }
131
132 pub(crate) fn take_send_log_digest(&mut self) -> Result<AuthLogDigest, ChanError> {
141 if let Self::Handshake(handler) = self {
142 handler
143 .take_send_log_digest()
144 .ok_or(ChanError::Bug(internal!(
145 "No send log digest on channel, or already taken"
146 )))
147 } else {
148 Err(ChanError::Bug(internal!(
149 "Getting send log digest without a handshake handler"
150 )))
151 }
152 }
153
154 pub(crate) fn take_recv_log_digest(&mut self) -> Result<AuthLogDigest, ChanError> {
163 if let Self::Handshake(handler) = self {
164 handler
165 .take_recv_log_digest()
166 .ok_or(ChanError::Bug(internal!(
167 "No recv log digest on channel, or already taken"
168 )))
169 } else {
170 Err(ChanError::Bug(internal!(
171 "Getting recv log digest without a handshake handler"
172 )))
173 }
174 }
175}
176
177impl futures_codec::Decoder for ChannelCellHandler {
196 type Item = AnyChanCell;
197 type Error = ChanError;
198
199 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
200 match self {
201 Self::New(c) => c
202 .decode(src)
203 .map(|opt| opt.map(|msg| ChanCell::new(None, msg.into()))),
204 Self::Handshake(c) => c.decode(src),
205 Self::Open(c) => c.decode(src),
206 }
207 }
208}
209
210impl futures_codec::Encoder for ChannelCellHandler {
211 type Item<'a> = AnyChanCell;
212 type Error = ChanError;
213
214 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
215 match self {
216 Self::New(c) => {
217 let AnyChanMsg::Versions(versions) = item.into_circid_and_msg().1 else {
220 return Err(Self::Error::HandshakeProto(
221 "Non VERSIONS cell for new handler".into(),
222 ));
223 };
224 c.encode(versions, dst)
225 }
226 Self::Handshake(c) => c.encode(item, dst),
227 Self::Open(c) => c.encode(item, dst),
228 }
229 }
230}
231
/// Cell handler used before a link protocol version is negotiated.
///
/// It only handles VERSIONS cells, whose header carries a 2-byte CircID.
pub(crate) struct NewChannelHandler {
    /// The type of channel this handler was created for.
    channel_type: ChannelType,
    /// Running SHA-256 digest of the bytes this handler encodes, if this
    /// channel type keeps a send log (relay types do, client type does not).
    send_log: Option<ll::d::Sha256>,
    /// Running SHA-256 digest of the bytes this handler decodes, if this
    /// channel type keeps a receive log.
    recv_log: Option<ll::d::Sha256>,
}
248
impl NewChannelHandler {
    /// Build the handshake-stage handler that follows this one, for the
    /// negotiated `link_version`.
    ///
    /// The send/receive logs are moved out of `self`, which is left without
    /// logs.
    fn next_handler(&mut self, link_version: LinkVersion) -> HandshakeChannelHandler {
        HandshakeChannelHandler::new(self, link_version)
    }
}
255
256impl From<ChannelType> for NewChannelHandler {
257 fn from(channel_type: ChannelType) -> Self {
258 match channel_type {
259 ChannelType::ClientInitiator => Self {
260 channel_type,
261 send_log: None,
262 recv_log: None,
263 },
264 ChannelType::RelayInitiator | ChannelType::RelayResponder { .. } => Self {
267 channel_type,
268 send_log: Some(ll::d::Sha256::new()),
269 recv_log: Some(ll::d::Sha256::new()),
270 },
271 }
272 }
273}
274
275impl futures_codec::Decoder for NewChannelHandler {
276 type Item = msg::Versions;
277 type Error = ChanError;
278
279 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
280 const HEADER_SIZE: usize = 5;
288
289 if src.len() < HEADER_SIZE {
293 return Ok(None);
294 }
295
296 let circ_id = u16::from_be_bytes([src[0], src[1]]);
299 if circ_id != 0 {
300 return Err(Self::Error::HandshakeProto(
301 "Invalid CircID in variable cell".into(),
302 ));
303 }
304
305 let cmd = ChanCmd::from(src[2]);
308 if cmd != ChanCmd::VERSIONS {
309 return Err(Self::Error::HandshakeProto(format!(
310 "Invalid command {cmd} variable cell, expected a VERSIONS."
311 )));
312 }
313
314 let body_len = u16::from_be_bytes([src[3], src[4]]) as usize;
317
318 if body_len % 2 == 1 {
323 return Err(Self::Error::HandshakeProto(
324 "VERSIONS cell body length is odd. Rejecting.".into(),
325 ));
326 }
327
328 let wanted_bytes = HEADER_SIZE + body_len;
330 if src.len() < wanted_bytes {
331 return Ok(None);
336 }
337 let mut data = src.split_to(wanted_bytes);
339
340 if let Some(recv_log) = self.recv_log.as_mut() {
344 recv_log.update(&data);
345 }
346
347 let body = data.split_off(HEADER_SIZE).freeze();
349 let mut reader = Reader::from_bytes(&body);
350
351 let cell = msg::Versions::decode_from_reader(cmd, &mut reader)
353 .map_err(|e| Self::Error::from_bytes_err(e, "new cell handler"))?;
354 Ok(Some(cell))
355 }
356}
357
358impl futures_codec::Encoder for NewChannelHandler {
359 type Item<'a> = msg::Versions;
360 type Error = ChanError;
361
362 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
363 let encoded_bytes = item
364 .encode_for_handshake()
365 .map_err(|e| Self::Error::from_bytes_enc(e, "new cell handler"))?;
366 if let Some(send_log) = self.send_log.as_mut() {
368 send_log.update(&encoded_bytes);
369 }
370 dst.extend_from_slice(&encoded_bytes);
372 Ok(())
373 }
374}
375
/// Cell handler used after the link version is negotiated and before the
/// channel is open.
pub(crate) struct HandshakeChannelHandler {
    /// Message filter built for the handshake stage. Cells are encoded and
    /// decoded through it, and it also records the channel type.
    filter: MessageFilter,
    /// Channel codec for the negotiated link protocol version.
    inner: codec::ChannelCodec,
    /// Running SHA-256 digest of the bytes this handler encodes, if logging.
    /// Carried over from the `New` handler.
    send_log: Option<ll::d::Sha256>,
    /// Running SHA-256 digest of the bytes this handler decodes, if logging.
    /// Carried over from the `New` handler.
    recv_log: Option<ll::d::Sha256>,
}
392
393impl HandshakeChannelHandler {
394 fn new(new_handler: &mut NewChannelHandler, link_version: LinkVersion) -> Self {
396 Self {
397 filter: MessageFilter::new(
398 link_version,
399 new_handler.channel_type,
400 super::msg::MessageStage::Handshake,
401 ),
402 send_log: new_handler.send_log.take(),
403 recv_log: new_handler.recv_log.take(),
404 inner: codec::ChannelCodec::new(link_version.value()),
405 }
406 }
407
408 fn finalize_log(log: Option<ll::d::Sha256>) -> Option<[u8; 32]> {
411 log.map(|sha256| sha256.finalize().into())
412 }
413
414 fn next_handler(&mut self) -> OpenChannelHandler {
416 OpenChannelHandler::new(
417 self.inner
418 .link_version()
419 .try_into()
420 .expect("Channel Codec with unknown link version"),
421 self.channel_type(),
422 )
423 }
424
425 pub(crate) fn take_send_log_digest(&mut self) -> Option<AuthLogDigest> {
433 Self::finalize_log(self.send_log.take())
434 }
435
436 pub(crate) fn take_recv_log_digest(&mut self) -> Option<AuthLogDigest> {
444 Self::finalize_log(self.recv_log.take())
445 }
446
447 pub(crate) fn channel_type(&self) -> ChannelType {
449 self.filter.channel_type()
450 }
451
452 pub(crate) fn set_authenticated(&mut self) {
454 self.filter.channel_type_mut().set_authenticated();
455 }
456}
457
458impl futures_codec::Encoder for HandshakeChannelHandler {
459 type Item<'a> = AnyChanCell;
460 type Error = ChanError;
461
462 fn encode(
463 &mut self,
464 item: Self::Item<'_>,
465 dst: &mut BytesMut,
466 ) -> std::result::Result<(), Self::Error> {
467 let before_dst_len = dst.len();
468 self.filter.encode_cell(item, &mut self.inner, dst)?;
469 let after_dst_len = dst.len();
470 if let Some(send_log) = self.send_log.as_mut() {
471 send_log.update(&dst[before_dst_len..after_dst_len]);
474 }
475 Ok(())
476 }
477}
478
479impl futures_codec::Decoder for HandshakeChannelHandler {
480 type Item = AnyChanCell;
481 type Error = ChanError;
482
483 fn decode(
484 &mut self,
485 src: &mut BytesMut,
486 ) -> std::result::Result<Option<Self::Item>, Self::Error> {
487 let orig = src.clone(); let cell = self.filter.decode_cell(&mut self.inner, src)?;
489 if let Some(recv_log) = self.recv_log.as_mut() {
490 let n_used = orig.len() - src.len();
491 recv_log.update(&orig[..n_used]);
492 }
493 Ok(cell)
494 }
495}
496
/// Cell handler used once the channel handshake is finished.
pub(crate) struct OpenChannelHandler {
    /// Message filter built for the open stage. Cells are encoded and decoded
    /// through it, and it also records the channel type.
    filter: MessageFilter,
    /// Channel codec for the negotiated link protocol version.
    inner: codec::ChannelCodec,
}
504
505impl OpenChannelHandler {
506 fn new(link_version: LinkVersion, channel_type: ChannelType) -> Self {
508 Self {
509 inner: codec::ChannelCodec::new(link_version.value()),
510 filter: MessageFilter::new(link_version, channel_type, super::msg::MessageStage::Open),
511 }
512 }
513
514 fn channel_type(&self) -> ChannelType {
516 self.filter.channel_type()
517 }
518}
519
520impl futures_codec::Encoder for OpenChannelHandler {
521 type Item<'a> = AnyChanCell;
522 type Error = ChanError;
523
524 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
525 self.filter.encode_cell(item, &mut self.inner, dst)
526 }
527}
528
529impl futures_codec::Decoder for OpenChannelHandler {
530 type Item = AnyChanCell;
531 type Error = ChanError;
532
533 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
534 self.filter.decode_cell(&mut self.inner, src)
535 }
536}
537
#[cfg(test)]
pub(crate) mod test {
    #![allow(clippy::unwrap_used)]
    use bytes::BytesMut;
    use digest::Digest;
    use futures::io::{AsyncRead, AsyncWrite, Cursor, Result};
    use futures::sink::SinkExt;
    use futures::stream::StreamExt;
    use futures::task::{Context, Poll};
    use hex_literal::hex;
    use std::pin::Pin;

    use tor_bytes::Writer;
    use tor_llcrypto as ll;
    use tor_rtcompat::StreamOps;

    use crate::channel::msg::LinkVersion;
    use crate::channel::{ChannelType, new_frame};

    use super::{ChannelCellHandler, OpenChannelHandler, futures_codec};
    use tor_cell::chancell::{AnyChanCell, ChanCmd, ChanMsg, CircId, msg};

    /// In-memory duplex stream for tests.
    ///
    /// Reads come from a fixed input buffer. Writes go to a separate output
    /// buffer, which can be retrieved with [`MsgBuf::into_response`].
    pub(crate) struct MsgBuf {
        /// Bytes handed to whoever reads from this stream.
        inbuf: futures::io::Cursor<Vec<u8>>,
        /// Bytes written to this stream.
        outbuf: futures::io::Cursor<Vec<u8>>,
    }

    impl AsyncRead for MsgBuf {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize>> {
            Pin::new(&mut self.inbuf).poll_read(cx, buf)
        }
    }
    impl AsyncWrite for MsgBuf {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            Pin::new(&mut self.outbuf).poll_write(cx, buf)
        }
        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            Pin::new(&mut self.outbuf).poll_flush(cx)
        }
        fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
            Pin::new(&mut self.outbuf).poll_close(cx)
        }
    }

    impl StreamOps for MsgBuf {}

    impl MsgBuf {
        /// Create a stream whose reads return the bytes of `output`.
        pub(crate) fn new<T: Into<Vec<u8>>>(output: T) -> Self {
            let inbuf = Cursor::new(output.into());
            let outbuf = Cursor::new(Vec::new());
            MsgBuf { inbuf, outbuf }
        }

        /// Number of input bytes read so far.
        pub(crate) fn consumed(&self) -> usize {
            self.inbuf.position() as usize
        }

        /// True if every input byte has been read.
        pub(crate) fn all_consumed(&self) -> bool {
            self.inbuf.get_ref().len() == self.consumed()
        }

        /// Consume the stream and return everything written to it.
        pub(crate) fn into_response(self) -> Vec<u8> {
            self.outbuf.into_inner()
        }
    }

    /// Wrap `mbuf` in a frame whose codec is already in the `Open` state for a
    /// client channel at link version 5.
    fn new_client_open_frame(mbuf: MsgBuf) -> futures_codec::Framed<MsgBuf, ChannelCellHandler> {
        let open_handler = ChannelCellHandler::Open(OpenChannelHandler::new(
            LinkVersion::V5,
            ChannelType::ClientInitiator,
        ));
        futures_codec::Framed::new(mbuf, open_handler)
    }

    #[test]
    fn check_client_encoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mb = MsgBuf::new(&b""[..]);
            let mut framed = new_client_open_frame(mb);

            let destroycell = msg::Destroy::new(2.into());
            framed
                .send(AnyChanCell::new(CircId::new(7), destroycell.into()))
                .await
                .unwrap();

            framed.flush().await.unwrap();

            let data = framed.into_inner().into_response();

            // 4-byte CircID 7, command 0x04 (DESTROY), reason 2, then padding.
            assert_eq!(&data[0..10], &hex!("00000007 04 0200000000")[..]);
        });
    }

    #[test]
    fn check_client_decoding() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut dat = Vec::new();
            dat.extend_from_slice(&hex!("00000007 04 0200000000")[..]);
            // Zero-pad out to a full 514-byte fixed-length cell.
            dat.resize(514, 0);
            let mb = MsgBuf::new(&dat[..]);
            let mut framed = new_client_open_frame(mb);

            let destroy = framed.next().await.unwrap().unwrap();

            let circ_id = CircId::new(7);
            assert_eq!(destroy.circid(), circ_id);
            assert_eq!(destroy.msg().cmd(), ChanCmd::DESTROY);

            assert!(framed.into_inner().all_consumed());
        });
    }

    #[test]
    fn handler_transition() {
        let mut handler: ChannelCellHandler = ChannelType::ClientInitiator.into();
        assert!(matches!(handler, ChannelCellHandler::New(_)));

        let r = handler.set_link_version(5);
        assert!(r.is_ok());
        assert!(matches!(handler, ChannelCellHandler::Handshake(_)));

        let r = handler.set_open();
        assert!(r.is_ok());
        assert!(matches!(handler, ChannelCellHandler::Open(_)));
    }

    #[test]
    fn clog_digest() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut our_clog = ll::d::Sha256::new();
            let mbuf = MsgBuf::new(*b"");
            let mut frame = new_frame(mbuf, ChannelType::RelayInitiator);

            // VERSIONS: 2-byte CircID 0, command 7, body length 2, version 5.
            our_clog.update(hex!("0000 07 0002 0005"));
            let version_cell = AnyChanCell::new(
                None,
                msg::Versions::new(vec![5]).expect("Fail VERSIONS").into(),
            );
            frame.send(version_cell).await.unwrap();

            frame
                .codec_mut()
                .set_link_version(5)
                .expect("Fail link version set");

            // Empty CERTS at link version 5: 4-byte CircID 0, command 0x81,
            // body length 1, zero certificates.
            our_clog.update(hex!("0000 0000 81 0001 00"));
            let certs_cell = msg::Certs::new_empty();
            frame
                .send(AnyChanCell::new(None, certs_cell.into()))
                .await
                .unwrap();

            let clog_hash: [u8; 32] = our_clog.finalize().into();
            assert_eq!(frame.codec_mut().take_send_log_digest().unwrap(), clog_hash);
        });
    }

    #[test]
    fn slog_digest() {
        tor_rtcompat::test_with_all_runtimes!(|_rt| async move {
            let mut our_slog = ll::d::Sha256::new();

            // Peer's VERSIONS cell, as received in the New state.
            let mut data = BytesMut::new();
            data.extend_from_slice(
                msg::Versions::new(vec![5])
                    .unwrap()
                    .encode_for_handshake()
                    .expect("Fail VERSIONS encoding")
                    .as_slice(),
            );
            our_slog.update(&data);

            let mbuf = MsgBuf::new(data);
            let mut frame = new_frame(mbuf, ChannelType::RelayInitiator);

            let _ = frame.next().await.transpose().expect("Fail to get cell");
            frame
                .codec_mut()
                .set_link_version(5)
                .expect("Fail link version set");

            // Peer's AUTH_CHALLENGE as a variable-length cell at link version 5:
            // 4-byte CircID, command, 2-byte length, then the body.
            let mut data = BytesMut::new();
            data.write_u32(0);
            data.write_u8(ChanCmd::AUTH_CHALLENGE.into());
            // Body length 36: a 32-byte challenge, a 2-byte method count and
            // one 2-byte method (layout per tor-spec AUTH_CHALLENGE).
            data.write_u16(36);
            msg::AuthChallenge::new([42_u8; 32], vec![3])
                .encode_onto(&mut data)
                .expect("Fail AUTH_CHALLENGE encoding");
            our_slog.update(&data);

            // Swap in a fresh underlying stream that yields the second cell.
            *frame = MsgBuf::new(data);
            let _ = frame.next().await.transpose().expect("Fail to get cell");

            let slog_hash: [u8; 32] = our_slog.finalize().into();
            assert_eq!(frame.codec_mut().take_recv_log_digest().unwrap(), slog_hash);
        });
    }
}