Skip to main content

tor_rtmock/
util.rs

//! Internal utilities for `tor_rtmock`
2
3use derive_deftly::define_derive_deftly;
4use futures::channel::mpsc;
5
6define_derive_deftly! {
7/// Implements `Runtime` for a struct made of multiple sub-providers
8///
9/// The type must be a struct containing
10/// field(s) which implement `SleepProvider`, `NetProvider`, etc.
11///
12/// The corresponding fields must be decorated with:
13///
14///  * `#[deftly(mock(task))]` to indicate the field implementing `Spawn + BlockOn`
15///  * `#[deftly(mock(net))]` to indicate the field implementing `NetProvider`
16///  * `#[deftly(mock(sleep))]` to indicate the field implementing `SleepProvider`
17///     and `CoarseTimeProvider`.
18///  * `#[deftly(mock(toplevel))]` to indicate the field implementing `ToplevelBlockOn`
19///     unconditionally.
20///  * `#[deftly(mock(toplevel_where = "BOUND"))]` to indicate the field implementing
21///    `ToplevelBlockOn` only if BOUND is satisfied.
22///    For example, `#[deftly(mock(toplevel_where = "R: ToplevelBlockOn"))] runtime: R,`.
23// This could perhaps be further reduced:
24// ambassador might be able to remove most of the body (although does it do async well?)
25    SomeMockRuntime for struct, expect items, beta_deftly:
26
27 $(
28  ${when fmeta(mock(task))}
29
30    impl <$tgens> Spawn for $ttype {
31        fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
32            self.$fname.spawn_obj(future)
33        }
34    }
35
36    impl <$tgens> Blocking for $ttype {
37        type ThreadHandle<T: Send + 'static> = <$ftype as Blocking>::ThreadHandle<T>;
38
39        fn spawn_blocking<F, T>(&self, f: F) -> <$ftype as Blocking>::ThreadHandle<T>
40        where
41            F: FnOnce() -> T + Send + 'static,
42            T: Send + 'static {
43            self.$fname.spawn_blocking(f)
44        }
45
46        fn reenter_block_on<F>(&self, future: F) -> F::Output
47        where
48            F: Future,
49            F::Output: Send + 'static
50        {
51            self.$fname.reenter_block_on(future)
52        }
53    }
54
55 )
56 $(
57  ${when any(fmeta(mock(toplevel)), fmeta(mock(toplevel_where)))}
58
59    impl <$tgens> ToplevelBlockOn for $ttype
60    where ${fmeta(mock(toplevel_where)) as token_stream, default {}}
61    {
62        fn block_on<F: Future>(&self, future: F) -> F::Output {
63            self.$fname.block_on(future)
64        }
65    }
66
67 )
68 $(
69  ${when fmeta(mock(net))}
70
71    #[async_trait]
72    impl <$tgens> NetStreamProvider for $ttype {
73        type Stream = <$ftype as NetStreamProvider>::Stream;
74        type Listener = <$ftype as NetStreamProvider>::Listener;
75
76        async fn connect(&self, addr: &SocketAddr) -> IoResult<Self::Stream> {
77            self.$fname.connect(addr).await
78        }
79        async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
80            self.$fname.listen(addr).await
81        }
82    }
83
84    #[async_trait]
85    impl <$tgens> NetStreamProvider<tor_general_addr::unix::SocketAddr> for $ttype {
86        type Stream = FakeStream;
87        type Listener = FakeListener<tor_general_addr::unix::SocketAddr>;
88
89        async fn connect(&self, _addr: &tor_general_addr::unix::SocketAddr) -> IoResult<Self::Stream> {
90            Err(tor_general_addr::unix::NoAfUnixSocketSupport::default().into())
91        }
92        async fn listen(&self, _addr: &tor_general_addr::unix::SocketAddr) -> IoResult<Self::Listener> {
93            Err(tor_general_addr::unix::NoAfUnixSocketSupport::default().into())
94        }
95    }
96
97    impl <$tgens> TlsProvider<<$ftype as NetStreamProvider>::Stream> for $ttype {
98        type Connector = <$ftype as TlsProvider<
99            <$ftype as NetStreamProvider>::Stream
100            >>::Connector;
101        type TlsStream = <$ftype as TlsProvider<
102            <$ftype as NetStreamProvider>::Stream
103            >>::TlsStream;
104        type Acceptor = <$ftype as TlsProvider<
105            <$ftype as NetStreamProvider>::Stream
106            >>::Acceptor;
107        type TlsServerStream = <$ftype as TlsProvider<
108            <$ftype as NetStreamProvider>::Stream
109            >>::TlsServerStream;
110
111        fn tls_connector(&self) -> Self::Connector {
112            self.$fname.tls_connector()
113        }
114        fn tls_acceptor(&self, settings: tor_rtcompat::tls::TlsAcceptorSettings) -> std::io::Result<Self::Acceptor> {
115            self.$fname.tls_acceptor(settings)
116        }
117        fn supports_keying_material_export(&self) -> bool {
118            self.$fname.supports_keying_material_export()
119        }
120    }
121
122    #[async_trait]
123    impl <$tgens> UdpProvider for $ttype {
124        type UdpSocket = <$ftype as UdpProvider>::UdpSocket;
125
126        #[inline]
127        async fn bind(&self, addr: &SocketAddr) -> IoResult<Self::UdpSocket> {
128            self.$fname.bind(addr).await
129        }
130    }
131
132 )
133 $(
134  ${when fmeta(mock(sleep))}
135
136    impl <$tgens> SleepProvider for $ttype {
137        type SleepFuture = <$ftype as SleepProvider>::SleepFuture;
138
139        fn sleep(&self, dur: Duration) -> Self::SleepFuture {
140            self.$fname.sleep(dur)
141        }
142        fn now(&self) -> Instant {
143            self.$fname.now()
144        }
145        fn wallclock(&self) -> SystemTime {
146            self.$fname.wallclock()
147        }
148        fn block_advance<T: Into<String>>(&self, reason: T) {
149            self.$fname.block_advance(reason);
150        }
151        fn release_advance<T: Into<String>>(&self, reason: T) {
152            self.$fname.release_advance(reason);
153        }
154        fn allow_one_advance(&self, dur: Duration) {
155            self.$fname.allow_one_advance(dur);
156        }
157    }
158
159    impl <$tgens> CoarseTimeProvider for $ttype {
160        fn now_coarse(&self) -> CoarseInstant {
161            self.$fname.now_coarse()
162        }
163    }
164
165 )
166
167   // TODO this wants to be assert_impl but it fails at generics
168   #[allow(unused)]
169   const _: fn() = || {
170       fn x(_: impl Runtime) { }
171       fn check_impl_runtime<$tgens>(t: $ttype) { x(t) }
172   };
173}
174
175/// Prelude that must be imported to derive
176/// [`SomeMockRuntime`](derive_deftly_template_SomeMockRuntime)
177//
178// This could have been part of the expansion of `impl_runtime!`,
179// but it seems rather too exciting for a macro to import things as a side gig.
180//
181// Arguably this ought to be an internal crate::prelude instead.
182// But crate-internal preludes are controversial within the Arti team.  -Diziet
183//
184// For macro visibility reasons, this must come *lexically after* the macro,
185// to allow it to refer to the macro in the doc comment.
186pub(crate) mod impl_runtime_prelude {
187    pub(crate) use async_trait::async_trait;
188    pub(crate) use derive_deftly::Deftly;
189    pub(crate) use futures::Future;
190    pub(crate) use futures::task::{FutureObj, Spawn, SpawnError};
191    pub(crate) use std::io::Result as IoResult;
192    pub(crate) use std::net::SocketAddr;
193    pub(crate) use tor_rtcompat::{
194        Blocking, CoarseInstant, CoarseTimeProvider, NetStreamProvider, Runtime, SleepProvider,
195        TlsProvider, ToplevelBlockOn, UdpProvider, unimpl::FakeListener, unimpl::FakeStream,
196    };
197    pub(crate) use web_time_compat::{Duration, Instant, SystemTime, SystemTimeExt};
198}
199
200/// Wrapper for `futures::channel::mpsc::channel` that embodies the `#[allow]`
201///
202/// We don't care about mq tracking in this test crate.
203///
204/// Exactly like `tor_async_utils::mpsc_channel_no_memquota`,
205/// but we can't use that here for crate hierarchy reasons.
206#[allow(clippy::disallowed_methods)]
207pub(crate) fn mpsc_channel<T>(buffer: usize) -> (mpsc::Sender<T>, mpsc::Receiver<T>) {
208    mpsc::channel(buffer)
209}