1use futures::lock::Mutex;
8use futures::stream::StreamExt;
9use hickory_proto::op::{Message, OpCode, Query, ResponseCode};
10use hickory_proto::rr::{DNSClass, Name, RData, Record, RecordType, rdata};
11use hickory_proto::serialize::binary::{BinDecodable, BinEncodable};
12use std::collections::HashMap;
13use std::net::{IpAddr, SocketAddr};
14use std::sync::Arc;
15use tor_rtcompat::{SpawnExt, UdpProvider};
16use tracing::{debug, error, info, warn};
17
18use arti_client::{Error, HasKind, StreamPrefs, TorClient};
19use safelog::sensitive as sv;
20use tor_config::Listen;
21use tor_error::{error_report, warn_report};
22use tor_rtcompat::{Runtime, UdpSocket};
23
24use anyhow::{Result, anyhow};
25
26use crate::proxy::port_info;
27
/// Size, in bytes, of the buffer each incoming DNS datagram is received into.
const MAX_DATAGRAM_SIZE: usize = 1_536;
30
/// Stream-isolation key for DNS requests.
///
/// Field 0 is the index of the listening socket the request arrived on; field 1 is the
/// IP address of the client that sent it. Requests only share circuits when both match.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct DnsIsolationKey(usize, IpAddr);
37
38impl arti_client::isolation::IsolationHelper for DnsIsolationKey {
39 fn compatible_same_type(&self, other: &Self) -> bool {
40 self == other
41 }
42
43 fn join_same_type(&self, other: &Self) -> Option<Self> {
44 if self == other {
45 Some(self.clone())
46 } else {
47 None
48 }
49 }
50
51 fn enables_long_lived_circuits(&self) -> bool {
52 false
53 }
54}
55
56#[derive(Debug, Clone, PartialEq, Eq, Hash)]
58struct DnsCacheKey(DnsIsolationKey, Vec<Query>);
59
/// Everything needed to deliver a reply to one waiting requester.
#[derive(Clone, Debug)]
struct DnsResponseTarget<U> {
    /// Message ID of the requester's query, copied into the reply sent to it.
    id: u16,
    /// Address the reply is sent to.
    addr: SocketAddr,
    /// Socket the reply is sent from.
    socket: Arc<U>,
}
70
71async fn do_query<R>(
73 tor_client: TorClient<R>,
74 queries: &[Query],
75 prefs: &StreamPrefs,
76) -> Result<Vec<Record>, ResponseCode>
77where
78 R: Runtime,
79{
80 let mut answers = Vec::new();
81
82 let err_conv = |error: Error| {
83 if tor_error::ErrorKind::RemoteHostNotFound == error.kind() {
84 ResponseCode::NoError
86 } else {
87 ResponseCode::ServFail
88 }
89 };
90 for query in queries {
91 let mut a = Vec::new();
92 let mut ptr = Vec::new();
93
94 match query.query_class() {
97 DNSClass::IN => {
98 match query.query_type() {
99 typ @ RecordType::A | typ @ RecordType::AAAA => {
100 let mut name = query.name().clone();
101 name.set_fqdn(false);
103 let res = tor_client
104 .resolve_with_prefs(&name.to_utf8(), prefs)
105 .await
106 .map_err(err_conv)?;
107 for ip in res {
108 a.push((query.name().clone(), ip, typ));
109 }
110 }
111 RecordType::PTR => {
112 let addr = query
113 .name()
114 .parse_arpa_name()
115 .map_err(|_| ResponseCode::FormErr)?
116 .addr();
117 let res = tor_client
118 .resolve_ptr_with_prefs(addr, prefs)
119 .await
120 .map_err(err_conv)?;
121 for domain in res {
122 let domain =
123 Name::from_utf8(domain).map_err(|_| ResponseCode::ServFail)?;
124 ptr.push((query.name().clone(), domain));
125 }
126 }
127 _ => {
128 return Err(ResponseCode::NotImp);
129 }
130 }
131 }
132 _ => {
133 return Err(ResponseCode::NotImp);
134 }
135 }
136 for (name, ip, typ) in a {
137 match (ip, typ) {
138 (IpAddr::V4(v4), RecordType::A) => {
139 answers.push(Record::from_rdata(name, 3600, RData::A(rdata::A(v4))));
140 }
141 (IpAddr::V6(v6), RecordType::AAAA) => {
142 answers.push(Record::from_rdata(name, 3600, RData::AAAA(rdata::AAAA(v6))));
143 }
144 _ => (),
145 }
146 }
147 for (ptr, name) in ptr {
148 answers.push(Record::from_rdata(ptr, 3600, RData::PTR(rdata::PTR(name))));
149 }
150 }
151
152 Ok(answers)
153}
154
155#[allow(clippy::cognitive_complexity)] async fn handle_dns_req<R, U>(
159 tor_client: TorClient<R>,
160 socket_id: usize,
161 packet: &[u8],
162 addr: SocketAddr,
163 socket: Arc<U>,
164 current_requests: &Mutex<HashMap<DnsCacheKey, Vec<DnsResponseTarget<U>>>>,
165) -> Result<()>
166where
167 R: Runtime,
168 U: UdpSocket,
169{
170 let query = Message::from_bytes(packet)?;
172 let id = query.metadata.id;
173 let queries = query.queries;
174 let isolation = DnsIsolationKey(socket_id, addr.ip());
175
176 let request_id = {
177 let request_id = DnsCacheKey(isolation.clone(), queries.clone());
178
179 let response_target = DnsResponseTarget { id, addr, socket };
180
181 let mut current_requests = current_requests.lock().await;
182
183 let req = current_requests.entry(request_id.clone()).or_default();
184 req.push(response_target);
185
186 if req.len() > 1 {
187 debug!("Received a query already being served");
188 return Ok(());
189 }
190 debug!("Received a new query");
191
192 request_id
193 };
194
195 let mut prefs = StreamPrefs::new();
196 prefs.set_isolation(isolation);
197
198 let mut response = match do_query(tor_client, &queries, &prefs).await {
199 Ok(answers) => {
200 let mut response = Message::response(id, OpCode::Query);
201 response.metadata.recursion_desired = query.metadata.recursion_desired;
202 response.metadata.recursion_available = true;
203 response.add_queries(queries).add_answers(answers);
204 response
206 }
207 Err(error_type) => Message::error_msg(id, OpCode::Query, error_type),
208 };
209
210 let targets = current_requests
212 .lock()
213 .await
214 .remove(&request_id)
215 .unwrap_or_default();
216
217 for target in targets {
218 response.metadata.id = target.id;
219 let response = match response.to_bytes() {
221 Ok(r) => r,
222 Err(e) => {
223 error_report!(e, "Failed to serialize DNS packet: {:?}", sv(&response));
228 continue;
229 }
230 };
231 let _ = target.socket.send(&response, &target.addr).await;
232 }
233 Ok(())
234}
235
236#[cfg_attr(feature = "experimental-api", visibility::make(pub))]
238#[must_use]
239pub(crate) struct DnsProxy<R: Runtime> {
240 udp_sockets: Vec<<R as UdpProvider>::UdpSocket>,
242 tor_client: TorClient<R>,
244}
245
246#[cfg_attr(feature = "experimental-api", visibility::make(pub))]
250#[allow(clippy::cognitive_complexity)] pub(crate) async fn bind_dns_resolver<R: Runtime>(
252 runtime: R,
253 tor_client: TorClient<R>,
254 listen: Listen,
255) -> Result<DnsProxy<R>> {
256 if !listen.is_loopback_only() {
257 warn!(
258 "Configured to listen for DNS on non-local addresses. This is usually insecure! We recommend listening on localhost only."
259 );
260 }
261
262 let mut listeners = Vec::new();
263
264 match listen.ip_addrs() {
266 Ok(addrgroups) => {
267 for addrgroup in addrgroups {
268 for addr in addrgroup {
269 match runtime.bind(&addr).await {
272 Ok(listener) => {
273 let bound_addr = listener.local_addr()?;
274 info!("Listening on {:?}.", bound_addr);
275 listeners.push(listener);
276 }
277 #[cfg(unix)]
278 Err(ref e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => {
279 warn_report!(e, "Address family not supported {}", addr);
280 }
281 Err(ref e) => {
282 return Err(anyhow!("Can't listen on {}: {e}", addr));
283 }
284 }
285 }
286 }
288 }
289 Err(e) => warn_report!(e, "Invalid listen spec"),
290 }
291 if listeners.is_empty() {
293 error!("Couldn't open any DNS listeners.");
294 return Err(anyhow!("Couldn't open any DNS listeners"));
295 }
296
297 Ok(DnsProxy {
298 tor_client,
299 udp_sockets: listeners,
300 })
301}
302
303impl<R: Runtime> DnsProxy<R> {
304 pub(crate) async fn run_dns_proxy(self) -> Result<()> {
306 let DnsProxy {
307 tor_client,
308 udp_sockets,
309 } = self;
310 run_dns_resolver_with_listeners(tor_client.runtime().clone(), tor_client, udp_sockets).await
311 }
312
313 pub(crate) fn port_info(&self) -> Result<Vec<port_info::Port>> {
315 Ok(self
316 .udp_sockets
317 .iter()
318 .map(|socket| {
319 socket.local_addr().map(|address| port_info::Port {
320 protocol: port_info::SupportedProtocol::DnsUdp,
321 address: address.into(),
322 })
323 })
324 .collect::<Result<Vec<_>, _>>()?)
325 }
326}
327
328async fn run_dns_resolver_with_listeners<R: Runtime>(
330 runtime: R,
331 tor_client: TorClient<R>,
332 listeners: Vec<<R as tor_rtcompat::UdpProvider>::UdpSocket>,
333) -> Result<()> {
334 let mut incoming = futures::stream::select_all(
335 listeners
336 .into_iter()
337 .map(|socket| {
338 futures::stream::unfold(Arc::new(socket), |socket| async {
339 let mut packet = [0; MAX_DATAGRAM_SIZE];
340 let packet = socket
341 .recv(&mut packet)
342 .await
343 .map(|(size, remote)| (packet, size, remote, socket.clone()));
344 Some((packet, socket))
345 })
346 })
347 .enumerate()
348 .map(|(listener_id, incoming_packet)| {
349 Box::pin(incoming_packet.map(move |packet| (packet, listener_id)))
350 }),
351 );
352
353 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
354 while let Some((packet, id)) = incoming.next().await {
355 let (packet, size, addr, socket) = match packet {
356 Ok(packet) => packet,
357 Err(err) => {
358 warn_report!(err, "Incoming datagram failed");
360 continue;
361 }
362 };
363
364 let client_ref = tor_client.clone();
365 runtime.spawn({
366 let pending_requests = pending_requests.clone();
367 async move {
368 let res = handle_dns_req(
369 client_ref,
370 id,
371 &packet[..size],
372 addr,
373 socket,
374 &pending_requests,
375 )
376 .await;
377 if let Err(e) = res {
378 warn!("connection exited with error: {}", tor_error::Report(e));
380 }
381 }
382 })?;
383 }
384
385 Ok(())
386}