Skip to main content

tor_rtcompat/impls/
tokio.rs

//! Re-exports of the tokio runtime for use with arti.
//!
//! This module helps define a slim API around our async runtime so that we
//! can easily swap it out.

6/// Types used for networking (tokio implementation)
7pub(crate) mod net {
8    use crate::{impls, traits};
9    use async_trait::async_trait;
10    #[cfg(unix)]
11    use tor_general_addr::unix;
12
13    pub(crate) use tokio_crate::net::{
14        TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
15    };
16    #[cfg(unix)]
17    pub(crate) use tokio_crate::net::{
18        UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
19    };
20
21    use futures::io::{AsyncRead, AsyncWrite};
22    use paste::paste;
23    use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
24
25    use std::io::Result as IoResult;
26    use std::net::SocketAddr;
27    use std::pin::Pin;
28    use std::task::{Context, Poll};
29
30    /// Provide a set of network stream wrappers for a single stream type.
31    macro_rules! stream_impl {
32        {
33            $kind:ident,
34            $addr:ty,
35            $cvt_addr:ident
36        } => {paste!{
37            /// Wrapper for Tokio's
38            #[doc = stringify!($kind)]
39            /// streams,
40            /// that implements the standard
41            /// AsyncRead and AsyncWrite.
42            pub struct [<$kind Stream>] {
43                /// Underlying tokio_util::compat::Compat wrapper.
44                s: Compat<[<Tokio $kind Stream>]>,
45            }
46            impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
47                fn from(s: [<Tokio $kind Stream>]) ->  [<$kind Stream>] {
48                    let s = s.compat();
49                    [<$kind Stream>] { s }
50                }
51            }
52            impl AsyncRead for  [<$kind Stream>] {
53                fn poll_read(
54                    mut self: Pin<&mut Self>,
55                    cx: &mut Context<'_>,
56                    buf: &mut [u8],
57                ) -> Poll<IoResult<usize>> {
58                    Pin::new(&mut self.s).poll_read(cx, buf)
59                }
60            }
61            impl AsyncWrite for  [<$kind Stream>] {
62                fn poll_write(
63                    mut self: Pin<&mut Self>,
64                    cx: &mut Context<'_>,
65                    buf: &[u8],
66                ) -> Poll<IoResult<usize>> {
67                    Pin::new(&mut self.s).poll_write(cx, buf)
68                }
69                fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
70                    Pin::new(&mut self.s).poll_flush(cx)
71                }
72                fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
73                    Pin::new(&mut self.s).poll_close(cx)
74                }
75            }
76
77            /// Wrap a Tokio
78            #[doc = stringify!($kind)]
79            /// Listener to behave as a futures::io::TcpListener.
80            pub struct [<$kind Listener>] {
81                /// The underlying listener.
82                pub(super) lis: [<Tokio $kind Listener>],
83            }
84
85            /// Asynchronous stream that yields incoming connections from a
86            #[doc = stringify!($kind)]
87            /// Listener.
88            ///
89            /// This is analogous to async_std::net::Incoming.
90            pub struct [<Incoming $kind Streams>] {
91                /// Reference to the underlying listener.
92                pub(super) lis: [<Tokio $kind Listener>],
93            }
94
95            impl futures::stream::Stream for [<Incoming $kind Streams>] {
96                type Item = IoResult<([<$kind Stream>], $addr)>;
97
98                fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
99                    match self.lis.poll_accept(cx) {
100                        Poll::Ready(Ok((s, a))) => Poll::Ready(Some(Ok((s.into(), $cvt_addr(a)? )))),
101                        Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
102                        Poll::Pending => Poll::Pending,
103                    }
104                }
105            }
106            impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
107                type Stream = [<$kind Stream>];
108                type Incoming = [<Incoming $kind Streams>];
109                fn incoming(self) -> Self::Incoming {
110                    [<Incoming $kind Streams>] { lis: self.lis }
111                }
112                fn local_addr(&self) -> IoResult<$addr> {
113                    $cvt_addr(self.lis.local_addr()?)
114                }
115            }
116        }}
117    }
118
119    /// Try to convert a tokio `unix::SocketAddr` into a crate::SocketAddr.
120    ///
121    /// Frustratingly, this information is _right there_: Tokio's SocketAddr has a
122    /// std::unix::net::SocketAddr internally, but there appears to be no way to get it out.
123    #[cfg(unix)]
124    #[allow(clippy::needless_pass_by_value)]
125    fn try_cvt_tokio_unix_addr(
126        addr: tokio_crate::net::unix::SocketAddr,
127    ) -> IoResult<unix::SocketAddr> {
128        if addr.is_unnamed() {
129            crate::unix::new_unnamed_socketaddr()
130        } else if let Some(p) = addr.as_pathname() {
131            unix::SocketAddr::from_pathname(p)
132        } else {
133            Err(crate::unix::UnsupportedAfUnixAddressType.into())
134        }
135    }
136
137    /// Wrapper for (not) converting std::net::SocketAddr to itself.
138    #[allow(clippy::unnecessary_wraps)]
139    fn identity_fn_socketaddr(addr: std::net::SocketAddr) -> IoResult<std::net::SocketAddr> {
140        Ok(addr)
141    }
142
143    stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
144    #[cfg(unix)]
145    stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
146
147    /// Wrap a Tokio UdpSocket
148    pub struct UdpSocket {
149        /// The underelying UdpSocket
150        socket: TokioUdpSocket,
151    }
152
153    impl UdpSocket {
154        /// Bind a UdpSocket
155        pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
156            TokioUdpSocket::bind(addr)
157                .await
158                .map(|socket| UdpSocket { socket })
159        }
160    }
161
162    #[async_trait]
163    impl traits::UdpSocket for UdpSocket {
164        async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
165            self.socket.recv_from(buf).await
166        }
167
168        async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
169            self.socket.send_to(buf, target).await
170        }
171
172        fn local_addr(&self) -> IoResult<SocketAddr> {
173            self.socket.local_addr()
174        }
175    }
176
177    impl traits::StreamOps for TcpStream {
178        fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
179            impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
180        }
181
182        #[cfg(target_os = "linux")]
183        fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
184            Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
185        }
186    }
187
188    #[cfg(unix)]
189    impl traits::StreamOps for UnixStream {
190        fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
191            Err(traits::UnsupportedStreamOp::new(
192                "set_tcp_notsent_lowat",
193                "unsupported on Unix streams",
194            )
195            .into())
196        }
197    }
198}

// ==============================

202use crate::traits::*;
203use async_trait::async_trait;
204use futures::Future;
205use std::io::Result as IoResult;
206use std::time::Duration;
207#[cfg(unix)]
208use tor_general_addr::unix;
209use tracing::instrument;
210
211impl SleepProvider for TokioRuntimeHandle {
212    type SleepFuture = tokio_crate::time::Sleep;
213    fn sleep(&self, duration: Duration) -> Self::SleepFuture {
214        tokio_crate::time::sleep(duration)
215    }
216}
217
218#[async_trait]
219impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
220    type Stream = net::TcpStream;
221    type Listener = net::TcpListener;
222
223    #[instrument(skip_all, level = "trace")]
224    async fn connect(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Stream> {
225        let s = net::TokioTcpStream::connect(addr).await?;
226        Ok(s.into())
227    }
228    async fn listen(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Listener> {
229        // Use an implementation that's the same across all runtimes.
230        let lis = net::TokioTcpListener::from_std(super::tcp_listen(addr)?)?;
231
232        Ok(net::TcpListener { lis })
233    }
234}
235
236#[cfg(unix)]
237#[async_trait]
238impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
239    type Stream = net::UnixStream;
240    type Listener = net::UnixListener;
241
242    #[instrument(skip_all, level = "trace")]
243    async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
244        let path = addr
245            .as_pathname()
246            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
247        let s = net::TokioUnixStream::connect(path).await?;
248        Ok(s.into())
249    }
250    async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
251        let path = addr
252            .as_pathname()
253            .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
254        let lis = net::TokioUnixListener::bind(path)?;
255        Ok(net::UnixListener { lis })
256    }
257}
258
// On platforms without AF_UNIX support, generate the placeholder
// Unix-socket provider implementation for this runtime.
#[cfg(not(unix))]
crate::impls::impl_unix_non_provider!(TokioRuntimeHandle);
261
262#[async_trait]
263impl crate::traits::UdpProvider for TokioRuntimeHandle {
264    type UdpSocket = net::UdpSocket;
265
266    async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
267        net::UdpSocket::bind(*addr).await
268    }
269}
270
271/// Create and return a new Tokio multithreaded runtime.
272pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
273    let runtime = async_executors::exec::TokioTp::new().map_err(std::io::Error::other)?;
274    Ok(runtime.into())
275}
276
277/// Wrapper around a Handle to a tokio runtime.
278///
279/// Ideally, this type would go away, and we would just use
280/// `tokio::runtime::Handle` directly.  Unfortunately, we can't implement
281/// `futures::Spawn` on it ourselves because of Rust's orphan rules, so we need
282/// to define a new type here.
283///
284/// # Limitations
285///
286/// Note that Arti requires that the runtime should have working implementations
287/// for Tokio's time, net, and io facilities, but we have no good way to check
288/// that when creating this object.
289#[derive(Clone, Debug)]
290pub struct TokioRuntimeHandle {
291    /// If present, the tokio executor that we've created (and which we own).
292    ///
293    /// We never access this directly; only through `handle`.  We keep it here
294    /// so that our Runtime types can be agnostic about whether they own the
295    /// executor.
296    owned: Option<async_executors::TokioTp>,
297    /// The underlying Handle.
298    handle: tokio_crate::runtime::Handle,
299}
300
301impl TokioRuntimeHandle {
302    /// Wrap a tokio runtime handle into a format that Arti can use.
303    ///
304    /// # Limitations
305    ///
306    /// Note that Arti requires that the runtime should have working
307    /// implementations for Tokio's time, net, and io facilities, but we have
308    /// no good way to check that when creating this object.
309    pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
310        handle.into()
311    }
312
313    /// Return true if this handle owns the executor that it points to.
314    pub fn is_owned(&self) -> bool {
315        self.owned.is_some()
316    }
317}
318
319impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
320    fn from(handle: tokio_crate::runtime::Handle) -> Self {
321        Self {
322            owned: None,
323            handle,
324        }
325    }
326}
327
328impl From<async_executors::TokioTp> for TokioRuntimeHandle {
329    fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
330        let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
331        Self {
332            owned: Some(owner),
333            handle,
334        }
335    }
336}
337
338impl ToplevelBlockOn for TokioRuntimeHandle {
339    #[track_caller]
340    fn block_on<F: Future>(&self, f: F) -> F::Output {
341        self.handle.block_on(f)
342    }
343}
344
345impl Blocking for TokioRuntimeHandle {
346    type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
347
348    #[track_caller]
349    fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
350    where
351        F: FnOnce() -> T + Send + 'static,
352        T: Send + 'static,
353    {
354        async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
355    }
356
357    #[track_caller]
358    fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
359        self.handle.block_on(future)
360    }
361
362    #[track_caller]
363    fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
364    where
365        F: FnOnce() -> T + Send + 'static,
366        T: Send + 'static,
367    {
368        let r = tokio_crate::task::block_in_place(f);
369        std::future::ready(r)
370    }
371}
372
373impl futures::task::Spawn for TokioRuntimeHandle {
374    #[track_caller]
375    fn spawn_obj(
376        &self,
377        future: futures::task::FutureObj<'static, ()>,
378    ) -> Result<(), futures::task::SpawnError> {
379        let join_handle = self.handle.spawn(future);
380        drop(join_handle); // this makes the task detached.
381        Ok(())
382    }
383}