tor_rtcompat/impls/
tokio.rs1pub(crate) mod net {
8 use crate::{impls, traits};
9 use async_trait::async_trait;
10 #[cfg(unix)]
11 use tor_general_addr::unix;
12
13 pub(crate) use tokio_crate::net::{
14 TcpListener as TokioTcpListener, TcpStream as TokioTcpStream, UdpSocket as TokioUdpSocket,
15 };
16 #[cfg(unix)]
17 pub(crate) use tokio_crate::net::{
18 UnixListener as TokioUnixListener, UnixStream as TokioUnixStream,
19 };
20
21 use futures::io::{AsyncRead, AsyncWrite};
22 use paste::paste;
23 use tokio_util::compat::{Compat, TokioAsyncReadCompatExt as _};
24
25 use std::io::Result as IoResult;
26 use std::net::SocketAddr;
27 use std::pin::Pin;
28 use std::task::{Context, Poll};
29
/// Generate the stream, listener, and incoming-stream wrappers for one kind
/// of Tokio stream socket.
///
/// Arguments:
///  * `kind`: prefix of the Tokio type names to wrap (the macro pastes it
///    into `Tokio<kind>Stream` / `Tokio<kind>Listener`).
///  * `addr`: the address type exposed by the generated listener.
///  * `cvt_addr`: a function converting the address type reported by Tokio
///    into `IoResult<addr>`.
macro_rules! stream_impl {
    {
        $kind:ident,
        $addr:ty,
        $cvt_addr:ident
    } => {paste!{
        /// Wrapper for Tokio's
        #[doc = stringify!($kind)]
        /// streams, implementing the `futures` `AsyncRead` and `AsyncWrite` traits.
        pub struct [<$kind Stream>] {
            /// The Tokio stream, held inside a `tokio_util` `Compat` adapter.
            s: Compat<[<Tokio $kind Stream>]>,
        }
        impl From<[<Tokio $kind Stream>]> for [<$kind Stream>] {
            fn from(s: [<Tokio $kind Stream>]) -> [<$kind Stream>] {
                [<$kind Stream>] { s: s.compat() }
            }
        }
        impl AsyncRead for [<$kind Stream>] {
            fn poll_read(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut [u8],
            ) -> Poll<IoResult<usize>> {
                // Pure delegation to the compat-wrapped stream.
                AsyncRead::poll_read(Pin::new(&mut self.s), cx, buf)
            }
        }
        impl AsyncWrite for [<$kind Stream>] {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<IoResult<usize>> {
                AsyncWrite::poll_write(Pin::new(&mut self.s), cx, buf)
            }
            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
                AsyncWrite::poll_flush(Pin::new(&mut self.s), cx)
            }
            fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
                AsyncWrite::poll_close(Pin::new(&mut self.s), cx)
            }
        }

        /// Wrapper for Tokio's
        #[doc = stringify!($kind)]
        /// listeners, implementing `NetStreamListener`.
        pub struct [<$kind Listener>] {
            /// The underlying Tokio listener.
            pub(super) lis: [<Tokio $kind Listener>],
        }

        /// Asynchronous stream of the connections accepted by a
        #[doc = stringify!($kind)]
        /// listener.
        pub struct [<Incoming $kind Streams>] {
            /// The Tokio listener that connections are accepted from.
            pub(super) lis: [<Tokio $kind Listener>],
        }

        impl futures::stream::Stream for [<Incoming $kind Streams>] {
            type Item = IoResult<([<$kind Stream>], $addr)>;

            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                // Settle the "not ready" and "accept failed" outcomes first.
                let (stream, peer) = match self.lis.poll_accept(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))),
                    Poll::Ready(Ok(accepted)) => accepted,
                };
                // A failed address conversion is yielded as an `Err` item.
                let item = $cvt_addr(peer).map(|peer| (stream.into(), peer));
                Poll::Ready(Some(item))
            }
        }
        impl traits::NetStreamListener<$addr> for [<$kind Listener>] {
            type Stream = [<$kind Stream>];
            type Incoming = [<Incoming $kind Streams>];
            fn incoming(self) -> Self::Incoming {
                let lis = self.lis;
                [<Incoming $kind Streams>] { lis }
            }
            fn local_addr(&self) -> IoResult<$addr> {
                let tokio_addr = self.lis.local_addr()?;
                $cvt_addr(tokio_addr)
            }
        }
    }}
}
118
119 #[cfg(unix)]
124 #[allow(clippy::needless_pass_by_value)]
125 fn try_cvt_tokio_unix_addr(
126 addr: tokio_crate::net::unix::SocketAddr,
127 ) -> IoResult<unix::SocketAddr> {
128 if addr.is_unnamed() {
129 crate::unix::new_unnamed_socketaddr()
130 } else if let Some(p) = addr.as_pathname() {
131 unix::SocketAddr::from_pathname(p)
132 } else {
133 Err(crate::unix::UnsupportedAfUnixAddressType.into())
134 }
135 }
136
/// Address "conversion" for sockets whose address is already a
/// `std::net::SocketAddr`: hand the address back unchanged.
///
/// Always returns `Ok`; the `IoResult` wrapper exists so the function has the
/// same signature shape as the fallible converter passed to `stream_impl!`.
#[allow(clippy::unnecessary_wraps)]
fn identity_fn_socketaddr(addr: SocketAddr) -> IoResult<SocketAddr> {
    Ok(addr)
}
142
// TCP wrappers: Tokio already reports `std::net::SocketAddr`, so the address
// converter is the identity function.
stream_impl! { Tcp, std::net::SocketAddr, identity_fn_socketaddr }
// AF_UNIX wrappers (Unix targets only): addresses go through the fallible
// Tokio-to-`unix::SocketAddr` conversion.
#[cfg(unix)]
stream_impl! { Unix, unix::SocketAddr, try_cvt_tokio_unix_addr }
146
147 pub struct UdpSocket {
149 socket: TokioUdpSocket,
151 }
152
153 impl UdpSocket {
154 pub async fn bind(addr: SocketAddr) -> IoResult<Self> {
156 TokioUdpSocket::bind(addr)
157 .await
158 .map(|socket| UdpSocket { socket })
159 }
160 }
161
162 #[async_trait]
163 impl traits::UdpSocket for UdpSocket {
164 async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
165 self.socket.recv_from(buf).await
166 }
167
168 async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
169 self.socket.send_to(buf, target).await
170 }
171
172 fn local_addr(&self) -> IoResult<SocketAddr> {
173 self.socket.local_addr()
174 }
175 }
176
177 impl traits::StreamOps for TcpStream {
178 fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
179 impls::streamops::set_tcp_notsent_lowat(&self.s, notsent_lowat)
180 }
181
182 #[cfg(target_os = "linux")]
183 fn new_handle(&self) -> Box<dyn traits::StreamOps + Send + Unpin> {
184 Box::new(impls::streamops::TcpSockFd::from_fd(&self.s))
185 }
186 }
187
188 #[cfg(unix)]
189 impl traits::StreamOps for UnixStream {
190 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
191 Err(traits::UnsupportedStreamOp::new(
192 "set_tcp_notsent_lowat",
193 "unsupported on Unix streams",
194 )
195 .into())
196 }
197 }
198}
199
200use crate::traits::*;
203use async_trait::async_trait;
204use futures::Future;
205use std::io::Result as IoResult;
206use std::time::Duration;
207#[cfg(unix)]
208use tor_general_addr::unix;
209use tracing::instrument;
210
211impl SleepProvider for TokioRuntimeHandle {
212 type SleepFuture = tokio_crate::time::Sleep;
213 fn sleep(&self, duration: Duration) -> Self::SleepFuture {
214 tokio_crate::time::sleep(duration)
215 }
216}
217
218#[async_trait]
219impl crate::traits::NetStreamProvider for TokioRuntimeHandle {
220 type Stream = net::TcpStream;
221 type Listener = net::TcpListener;
222
223 #[instrument(skip_all, level = "trace")]
224 async fn connect(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Stream> {
225 let s = net::TokioTcpStream::connect(addr).await?;
226 Ok(s.into())
227 }
228 async fn listen(&self, addr: &std::net::SocketAddr) -> IoResult<Self::Listener> {
229 let lis = net::TokioTcpListener::from_std(super::tcp_listen(addr)?)?;
231
232 Ok(net::TcpListener { lis })
233 }
234}
235
236#[cfg(unix)]
237#[async_trait]
238impl crate::traits::NetStreamProvider<unix::SocketAddr> for TokioRuntimeHandle {
239 type Stream = net::UnixStream;
240 type Listener = net::UnixListener;
241
242 #[instrument(skip_all, level = "trace")]
243 async fn connect(&self, addr: &unix::SocketAddr) -> IoResult<Self::Stream> {
244 let path = addr
245 .as_pathname()
246 .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
247 let s = net::TokioUnixStream::connect(path).await?;
248 Ok(s.into())
249 }
250 async fn listen(&self, addr: &unix::SocketAddr) -> IoResult<Self::Listener> {
251 let path = addr
252 .as_pathname()
253 .ok_or(crate::unix::UnsupportedAfUnixAddressType)?;
254 let lis = net::TokioUnixListener::bind(path)?;
255 Ok(net::UnixListener { lis })
256 }
257}
258
// On non-Unix targets the `NetStreamProvider<unix::SocketAddr>` impl above is
// compiled out, so this macro is invoked for `TokioRuntimeHandle` instead.
// NOTE(review): presumably it supplies a stand-in provider impl; confirm
// against the macro definition in `crate::impls`.
#[cfg(not(unix))]
crate::impls::impl_unix_non_provider! { TokioRuntimeHandle }
261
262#[async_trait]
263impl crate::traits::UdpProvider for TokioRuntimeHandle {
264 type UdpSocket = net::UdpSocket;
265
266 async fn bind(&self, addr: &std::net::SocketAddr) -> IoResult<Self::UdpSocket> {
267 net::UdpSocket::bind(*addr).await
268 }
269}
270
271pub(crate) fn create_runtime() -> IoResult<TokioRuntimeHandle> {
273 let runtime = async_executors::exec::TokioTp::new().map_err(std::io::Error::other)?;
274 Ok(runtime.into())
275}
276
277#[derive(Clone, Debug)]
290pub struct TokioRuntimeHandle {
291 owned: Option<async_executors::TokioTp>,
297 handle: tokio_crate::runtime::Handle,
299}
300
301impl TokioRuntimeHandle {
302 pub(crate) fn new(handle: tokio_crate::runtime::Handle) -> Self {
310 handle.into()
311 }
312
313 pub fn is_owned(&self) -> bool {
315 self.owned.is_some()
316 }
317}
318
319impl From<tokio_crate::runtime::Handle> for TokioRuntimeHandle {
320 fn from(handle: tokio_crate::runtime::Handle) -> Self {
321 Self {
322 owned: None,
323 handle,
324 }
325 }
326}
327
328impl From<async_executors::TokioTp> for TokioRuntimeHandle {
329 fn from(owner: async_executors::TokioTp) -> TokioRuntimeHandle {
330 let handle = owner.block_on(async { tokio_crate::runtime::Handle::current() });
331 Self {
332 owned: Some(owner),
333 handle,
334 }
335 }
336}
337
338impl ToplevelBlockOn for TokioRuntimeHandle {
339 #[track_caller]
340 fn block_on<F: Future>(&self, f: F) -> F::Output {
341 self.handle.block_on(f)
342 }
343}
344
345impl Blocking for TokioRuntimeHandle {
346 type ThreadHandle<T: Send + 'static> = async_executors::BlockingHandle<T>;
347
348 #[track_caller]
349 fn spawn_blocking<F, T>(&self, f: F) -> async_executors::BlockingHandle<T>
350 where
351 F: FnOnce() -> T + Send + 'static,
352 T: Send + 'static,
353 {
354 async_executors::BlockingHandle::tokio(self.handle.spawn_blocking(f))
355 }
356
357 #[track_caller]
358 fn reenter_block_on<F: Future>(&self, future: F) -> F::Output {
359 self.handle.block_on(future)
360 }
361
362 #[track_caller]
363 fn blocking_io<F, T>(&self, f: F) -> impl Future<Output = T>
364 where
365 F: FnOnce() -> T + Send + 'static,
366 T: Send + 'static,
367 {
368 let r = tokio_crate::task::block_in_place(f);
369 std::future::ready(r)
370 }
371}
372
373impl futures::task::Spawn for TokioRuntimeHandle {
374 #[track_caller]
375 fn spawn_obj(
376 &self,
377 future: futures::task::FutureObj<'static, ()>,
378 ) -> Result<(), futures::task::SpawnError> {
379 let join_handle = self.handle.spawn(future);
380 drop(join_handle); Ok(())
382 }
383}