Skip to main content

tor_chanmgr/transport/
default.rs

1//! Implement the default transport, which opens TCP connections using a
2//! happy-eyeballs style parallel algorithm.
3
4use std::{net::SocketAddr, sync::Arc, time::Duration};
5
6use async_trait::async_trait;
7use futures::{FutureExt, StreamExt, TryFutureExt, stream::FuturesUnordered};
8use safelog::sensitive as sv;
9use tor_error::bad_api_usage;
10use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
11use tor_proto::peer::PeerAddr;
12use tor_rtcompat::{NetStreamProvider, Runtime};
13use tracing::{instrument, trace};
14
15use crate::Error;
16
17/// A default transport object that opens TCP connections for a
18/// `ChannelMethod::Direct`.
19///
20/// It opens almost-simultaneous parallel TCP connections to each address, and
21/// chooses the first one to succeed.
22#[derive(Clone, Debug)]
23pub(crate) struct DefaultTransport<R: Runtime> {
24    /// The runtime that we use for connecting.
25    runtime: R,
26    /// The outbound proxy to use, if any
27    outbound_proxy: Option<crate::config::ProxyProtocol>,
28}
29
30impl<R: Runtime> DefaultTransport<R> {
31    /// Construct a new DefaultTransport
32    pub(crate) fn new(runtime: R, outbound_proxy: Option<crate::config::ProxyProtocol>) -> Self {
33        Self {
34            runtime,
35            outbound_proxy,
36        }
37    }
38}
39
40#[async_trait]
41impl<R: Runtime> crate::transport::TransportImplHelper for DefaultTransport<R> {
42    type Stream = <R as NetStreamProvider>::Stream;
43
44    /// Implements the transport: makes a TCP connection (possibly
45    /// tunneled over whatever protocol) if possible.
46    #[instrument(skip_all, level = "trace")]
47    async fn connect(&self, target: &OwnedChanTarget) -> crate::Result<(PeerAddr, Self::Stream)> {
48        let direct_addrs: Vec<_> = match target.chan_method() {
49            ChannelMethod::Direct(addrs) => addrs,
50            #[allow(unreachable_patterns)]
51            _ => {
52                return Err(Error::UnusableTarget(bad_api_usage!(
53                    "Used default transport implementation for an unsupported transport."
54                )));
55            }
56        };
57
58        trace!("Launching direct connection for {}", target);
59
60        let (stream, addr) =
61            connect_to_one(&self.runtime, &direct_addrs, &self.outbound_proxy).await?;
62        Ok((addr.into(), stream))
63    }
64}
65
/// Time to wait between starting parallel connections to the same relay.
///
/// The attempt for the `i`th address is launched after `i` times this delay.
/// This is a `const` rather than a `static`: it is a plain `Copy` value that
/// needs neither a fixed memory address nor interior mutability.
const CONNECTION_DELAY: Duration = Duration::from_millis(150);
68
69/// Connect to one of the addresses in `addrs` by running connections in parallel until one works.
70///
71/// This implements a basic version of RFC 8305 "happy eyeballs".
72#[instrument(skip_all, level = "trace")]
73async fn connect_to_one<R: Runtime>(
74    rt: &R,
75    addrs: &[SocketAddr],
76    outbound_proxy: &Option<crate::config::ProxyProtocol>,
77) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
78    // We need *some* addresses to connect to.
79    if addrs.is_empty() {
80        return Err(Error::UnusableTarget(bad_api_usage!(
81            "No addresses for chosen relay"
82        )));
83    }
84
85    // Turn each address into a future that waits (i * CONNECTION_DELAY), then
86    // attempts to connect to the address using the runtime (where i is the
87    // array index). Shove all of these into a `FuturesUnordered`, polling them
88    // simultaneously and returning the results in completion order.
89    //
90    // This is basically the concurrent-connection stuff from RFC 8305, ish.
91    // TODO(eta): sort the addresses first?
92    let mut connections = addrs
93        .iter()
94        .enumerate()
95        .map(|(i, a)| {
96            let delay = rt.sleep(CONNECTION_DELAY * i as u32);
97            let proxy = outbound_proxy.clone();
98            delay.then(move |_| {
99                tracing::debug!("Connecting to {}", a);
100                let a = *a;
101                async move {
102                    let stream = if let Some(ref protocol) = proxy {
103                        // Use proxy - extract address and protocol details
104                        let target = tor_linkspec::PtTargetAddr::IpPort(a);
105                        match protocol {
106                            crate::config::ProxyProtocol::Socks {
107                                version,
108                                auth,
109                                addr,
110                            } => {
111                                let proto = super::proxied::Protocol::Socks(*version, auth.clone());
112                                super::proxied::connect_via_proxy(rt, addr, &proto, &target).await
113                            }
114                            crate::config::ProxyProtocol::HttpConnect { addr, credentials } => {
115                                // Wrap credentials in Sensitive to avoid accidental logging.
116                                let auth = credentials.as_ref().map(|cred| {
117                                    (
118                                        safelog::Sensitive::new(cred.username.clone()),
119                                        safelog::Sensitive::new(
120                                            cred.password.clone().unwrap_or_default(),
121                                        ),
122                                    )
123                                });
124                                let proto = super::proxied::Protocol::HttpConnect { auth };
125                                super::proxied::connect_via_proxy(rt, addr, &proto, &target).await
126                            }
127                        }
128                    } else {
129                        // Direct connection
130                        rt.connect(&a)
131                            .await
132                            .map_err(super::proxied::ProxyError::from)
133                    }?;
134                    Ok((stream, a))
135                }
136                .map_err(move |e: super::proxied::ProxyError| (e, a))
137            })
138        })
139        .collect::<FuturesUnordered<_>>();
140
141    let mut ret = None;
142    let mut errors = vec![];
143
144    while let Some(result) = connections.next().await {
145        match result {
146            Ok(s) => {
147                // We got a stream (and address).
148                ret = Some(s);
149                break;
150            }
151            Err((e, a)) => {
152                // We got a failure on one of the streams. Store the error.
153                // TODO(eta): ideally we'd start the next connection attempt immediately.
154                errors.push((e, a));
155            }
156        }
157    }
158
159    // Ensure we don't continue trying to make connections.
160    drop(connections);
161
162    ret.ok_or_else(|| Error::ChannelBuild {
163        addresses: errors
164            .into_iter()
165            .map(|(e, a)| (sv(a), Arc::new(std::io::Error::from(e))))
166            .collect(),
167    })
168}
169
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    use std::str::FromStr;

    use tor_rtcompat::{SleepProviderExt, test_with_one_runtime};
    use tor_rtmock::net::MockNetwork;

    use super::*;

    /// Exercise `connect_to_one` against a mock network containing two
    /// listening "relays", an address with nothing behind it, and a black hole.
    #[test]
    fn test_connect_one() {
        let client_addr = "192.0.1.16".parse().unwrap();
        // We'll put a "relay" at this address
        let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap();
        // We'll put nothing at this address, to generate errors.
        let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap();
        // We'll put a black hole at this address, to generate timeouts.
        let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap();
        // We'll put a "relay" at this address too
        let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap();

        test_with_one_runtime!(|rt| async move {
            // Stub out the internet so that this connection can work.
            let network = MockNetwork::new();

            // Set up a client and server runtime with a given IP
            let client_rt = network
                .builder()
                .add_address(client_addr)
                .runtime(rt.clone());
            let server_rt = network
                .builder()
                .add_address(addr1.ip())
                .add_address(addr4.ip())
                .runtime(rt.clone());
            let _listener = server_rt.mock_net().listen(&addr1).await.unwrap();
            let _listener2 = server_rt.mock_net().listen(&addr4).await.unwrap();
            // TODO: Because this test doesn't mock time, there will actually be
            // delays as we wait for connections to this address to time out. It
            // would be good to use MockSleepProvider instead, once we figure
            // out how to make it both reliable and convenient.
            network.add_blackhole(addr3).unwrap();

            // No addresses? Can't succeed, and the error should say the
            // target is unusable.
            let failure = connect_to_one(&client_rt, &[], &None).await;
            assert!(matches!(failure, Err(Error::UnusableTarget(_))));

            // Connect to a set of addresses including addr1? That's a success.
            for addresses in [
                &[addr1][..],
                &[addr1, addr2][..],
                &[addr2, addr1][..],
                &[addr1, addr3][..],
                &[addr3, addr1][..],
                &[addr1, addr2, addr3][..],
                &[addr3, addr2, addr1][..],
            ] {
                let (_conn, addr) = connect_to_one(&client_rt, addresses, &None).await.unwrap();
                assert_eq!(addr, addr1);
            }

            // Connect to a set of addresses including addr2 but not addr1?
            // That's an error of one kind or another.
            for addresses in [
                &[addr2][..],
                &[addr2, addr3][..],
                &[addr3, addr2][..],
                &[addr3][..],
            ] {
                let expect_timeout = addresses.contains(&addr3);
                let failure = rt
                    .timeout(
                        Duration::from_millis(300),
                        connect_to_one(&client_rt, addresses, &None),
                    )
                    .await;
                if expect_timeout {
                    // The black hole never answers, so only our timeout ends this.
                    assert!(failure.is_err());
                } else {
                    // Every attempt failed outright: expect a ChannelBuild error.
                    assert!(matches!(
                        failure.unwrap(),
                        Err(Error::ChannelBuild { .. })
                    ));
                }
            }

            // Connect to addr1 and addr4? The first one should win.
            let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4], &None)
                .await
                .unwrap();
            assert_eq!(addr, addr1);
            let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1], &None)
                .await
                .unwrap();
            assert_eq!(addr, addr4);
        });
    }
}