1mod conn;
4mod display_error_stack;
5mod incoming;
6mod io_stream;
7mod service;
8#[cfg(feature = "_tls-any")]
9mod tls;
10#[cfg(unix)]
11mod unix;
12
13use tokio_stream::StreamExt as _;
14use tracing::{debug, trace};
15
16#[cfg(feature = "router")]
17use crate::{server::NamedService, service::Routes};
18
19#[cfg(feature = "router")]
20use std::convert::Infallible;
21
22pub use conn::{Connected, TcpConnectInfo};
23use hyper_util::{
24 rt::{TokioExecutor, TokioIo, TokioTimer},
25 server::conn::auto::{Builder as ConnectionBuilder, HttpServerConnExec},
26 service::TowerToHyperService,
27};
28#[cfg(feature = "_tls-any")]
29pub use tls::ServerTlsConfig;
30
31#[cfg(feature = "_tls-any")]
32pub use conn::TlsConnectInfo;
33
34#[cfg(feature = "_tls-any")]
35use self::service::TlsAcceptor;
36
37#[cfg(unix)]
38pub use unix::UdsConnectInfo;
39
40pub use incoming::TcpIncoming;
41
42#[cfg(feature = "_tls-any")]
43use crate::transport::Error;
44
45use self::service::{ConnectInfoLayer, ServerIo};
46use super::service::GrpcTimeout;
47use crate::body::Body;
48use crate::service::RecoverErrorLayer;
49use crate::transport::server::display_error_stack::DisplayErrorStack;
50use bytes::Bytes;
51use http::{Request, Response};
52use http_body_util::BodyExt;
53use hyper::{body::Incoming, service::Service as HyperService};
54use pin_project::pin_project;
55use std::{
56 fmt,
57 future::{self, Future},
58 marker::PhantomData,
59 net::SocketAddr,
60 pin::{pin, Pin},
61 sync::Arc,
62 task::{ready, Context, Poll},
63 time::Duration,
64};
65use tokio::io::{AsyncRead, AsyncWrite};
66use tokio_stream::Stream;
67use tower::{
68 layer::util::{Identity, Stack},
69 layer::Layer,
70 limit::concurrency::ConcurrencyLimitLayer,
71 load_shed::LoadShedLayer,
72 util::BoxCloneService,
73 Service, ServiceBuilder, ServiceExt,
74};
75
76type BoxService = tower::util::BoxCloneService<Request<Body>, Response<Body>, crate::BoxError>;
77type TraceInterceptor = Arc<dyn Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static>;
78
79const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);
80
81#[derive(Clone)]
90pub struct Server<L = Identity> {
91 trace_interceptor: Option<TraceInterceptor>,
92 concurrency_limit: Option<usize>,
93 load_shed: bool,
94 timeout: Option<Duration>,
95 #[cfg(feature = "_tls-any")]
96 tls: Option<TlsAcceptor>,
97 init_stream_window_size: Option<u32>,
98 init_connection_window_size: Option<u32>,
99 max_concurrent_streams: Option<u32>,
100 tcp_keepalive: Option<Duration>,
101 tcp_keepalive_interval: Option<Duration>,
102 tcp_keepalive_retries: Option<u32>,
103 tcp_nodelay: bool,
104 http2_keepalive_interval: Option<Duration>,
105 http2_keepalive_timeout: Duration,
106 http2_adaptive_window: Option<bool>,
107 http2_max_pending_accept_reset_streams: Option<usize>,
108 http2_max_local_error_reset_streams: Option<usize>,
109 http2_max_header_list_size: Option<u32>,
110 max_frame_size: Option<u32>,
111 accept_http1: bool,
112 service_builder: ServiceBuilder<L>,
113 max_connection_age: Option<Duration>,
114 max_connection_age_grace: Option<Duration>,
115}
116
117impl Default for Server<Identity> {
118 fn default() -> Self {
119 Self {
120 trace_interceptor: None,
121 concurrency_limit: None,
122 load_shed: false,
123 timeout: None,
124 #[cfg(feature = "_tls-any")]
125 tls: None,
126 init_stream_window_size: None,
127 init_connection_window_size: None,
128 max_concurrent_streams: None,
129 tcp_keepalive: None,
130 tcp_keepalive_interval: None,
131 tcp_keepalive_retries: None,
132 tcp_nodelay: true,
133 http2_keepalive_interval: None,
134 http2_keepalive_timeout: DEFAULT_HTTP2_KEEPALIVE_TIMEOUT,
135 http2_adaptive_window: None,
136 http2_max_pending_accept_reset_streams: None,
137 http2_max_local_error_reset_streams: None,
138 http2_max_header_list_size: None,
139 max_frame_size: None,
140 accept_http1: false,
141 service_builder: Default::default(),
142 max_connection_age: None,
143 max_connection_age_grace: None,
144 }
145 }
146}
147
#[cfg(feature = "router")]
/// A server paired with the set of routes it will serve.
///
/// Created by [`Server::add_service`] and friends; more services can be
/// chained before calling one of the `serve*` methods.
#[derive(Clone, Debug)]
pub struct Router<L = Identity> {
    /// Server configuration used when serving.
    server: Server<L>,
    /// Services registered so far.
    routes: Routes,
}
155
156impl Server {
157 pub fn builder() -> Self {
159 Self::default()
160 }
161}
162
163impl<L> Server<L> {
164 #[cfg(feature = "_tls-any")]
166 pub fn tls_config(self, tls_config: ServerTlsConfig) -> Result<Self, Error> {
167 Ok(Server {
168 tls: Some(tls_config.tls_acceptor().map_err(Error::from_source)?),
169 ..self
170 })
171 }
172
173 #[must_use]
184 pub fn concurrency_limit_per_connection(self, limit: usize) -> Self {
185 Server {
186 concurrency_limit: Some(limit),
187 ..self
188 }
189 }
190
191 #[must_use]
208 pub fn load_shed(self, load_shed: bool) -> Self {
209 Server { load_shed, ..self }
210 }
211
212 #[must_use]
224 pub fn timeout(self, timeout: Duration) -> Self {
225 Server {
226 timeout: Some(timeout),
227 ..self
228 }
229 }
230
231 #[must_use]
238 pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
239 Server {
240 init_stream_window_size: sz.into(),
241 ..self
242 }
243 }
244
245 #[must_use]
249 pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
250 Server {
251 init_connection_window_size: sz.into(),
252 ..self
253 }
254 }
255
256 #[must_use]
263 pub fn max_concurrent_streams(self, max: impl Into<Option<u32>>) -> Self {
264 Server {
265 max_concurrent_streams: max.into(),
266 ..self
267 }
268 }
269
270 #[must_use]
284 pub fn max_connection_age(self, max_connection_age: Duration) -> Self {
285 Server {
286 max_connection_age: Some(max_connection_age),
287 ..self
288 }
289 }
290
291 #[must_use]
314 pub fn max_connection_age_grace(self, max_connection_age_grace: Duration) -> Self {
315 Server {
316 max_connection_age_grace: Some(max_connection_age_grace),
317 ..self
318 }
319 }
320
321 #[must_use]
331 pub fn http2_keepalive_interval(self, http2_keepalive_interval: Option<Duration>) -> Self {
332 Server {
333 http2_keepalive_interval,
334 ..self
335 }
336 }
337
338 #[must_use]
346 pub fn http2_keepalive_timeout(mut self, http2_keepalive_timeout: Option<Duration>) -> Self {
347 if let Some(timeout) = http2_keepalive_timeout {
348 self.http2_keepalive_timeout = timeout;
349 }
350 self
351 }
352
353 #[must_use]
357 pub fn http2_adaptive_window(self, enabled: Option<bool>) -> Self {
358 Server {
359 http2_adaptive_window: enabled,
360 ..self
361 }
362 }
363
364 #[must_use]
370 pub fn http2_max_pending_accept_reset_streams(self, max: Option<usize>) -> Self {
371 Server {
372 http2_max_pending_accept_reset_streams: max,
373 ..self
374 }
375 }
376
377 #[must_use]
381 pub fn http2_max_local_error_reset_streams(self, max: Option<usize>) -> Self {
382 Server {
383 http2_max_local_error_reset_streams: max,
384 ..self
385 }
386 }
387
388 #[must_use]
399 pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
400 Server {
401 tcp_keepalive,
402 ..self
403 }
404 }
405
406 #[must_use]
417 pub fn tcp_keepalive_interval(self, tcp_keepalive_interval: Option<Duration>) -> Self {
418 Server {
419 tcp_keepalive_interval,
420 ..self
421 }
422 }
423
424 #[must_use]
436 pub fn tcp_keepalive_retries(self, tcp_keepalive_retries: Option<u32>) -> Self {
437 Server {
438 tcp_keepalive_retries,
439 ..self
440 }
441 }
442
443 #[must_use]
447 pub fn tcp_nodelay(self, enabled: bool) -> Self {
448 Server {
449 tcp_nodelay: enabled,
450 ..self
451 }
452 }
453
454 #[must_use]
458 pub fn http2_max_header_list_size(self, max: impl Into<Option<u32>>) -> Self {
459 Server {
460 http2_max_header_list_size: max.into(),
461 ..self
462 }
463 }
464
465 #[must_use]
471 pub fn max_frame_size(self, frame_size: impl Into<Option<u32>>) -> Self {
472 Server {
473 max_frame_size: frame_size.into(),
474 ..self
475 }
476 }
477
478 #[must_use]
487 pub fn accept_http1(self, accept_http1: bool) -> Self {
488 Server {
489 accept_http1,
490 ..self
491 }
492 }
493
494 #[must_use]
496 pub fn trace_fn<F>(self, f: F) -> Self
497 where
498 F: Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static,
499 {
500 Server {
501 trace_interceptor: Some(Arc::new(f)),
502 ..self
503 }
504 }
505
506 #[cfg(feature = "router")]
511 pub fn add_service<S>(&mut self, svc: S) -> Router<L>
512 where
513 S: Service<Request<Body>, Error = Infallible>
514 + NamedService
515 + Clone
516 + Send
517 + Sync
518 + 'static,
519 S::Response: axum::response::IntoResponse,
520 S::Future: Send + 'static,
521 L: Clone,
522 {
523 Router::new(self.clone(), Routes::new(svc))
524 }
525
526 #[cfg(feature = "router")]
535 pub fn add_optional_service<S>(&mut self, svc: Option<S>) -> Router<L>
536 where
537 S: Service<Request<Body>, Error = Infallible>
538 + NamedService
539 + Clone
540 + Send
541 + Sync
542 + 'static,
543 S::Response: axum::response::IntoResponse,
544 S::Future: Send + 'static,
545 L: Clone,
546 {
547 let routes = svc.map(Routes::new).unwrap_or_default();
548 Router::new(self.clone(), routes)
549 }
550
551 #[cfg(feature = "router")]
556 pub fn add_routes(&mut self, routes: Routes) -> Router<L>
557 where
558 L: Clone,
559 {
560 Router::new(self.clone(), routes)
561 }
562
563 pub fn layer<NewLayer>(self, new_layer: NewLayer) -> Server<Stack<NewLayer, L>> {
625 Server {
626 service_builder: self.service_builder.layer(new_layer),
627 trace_interceptor: self.trace_interceptor,
628 concurrency_limit: self.concurrency_limit,
629 load_shed: self.load_shed,
630 timeout: self.timeout,
631 #[cfg(feature = "_tls-any")]
632 tls: self.tls,
633 init_stream_window_size: self.init_stream_window_size,
634 init_connection_window_size: self.init_connection_window_size,
635 max_concurrent_streams: self.max_concurrent_streams,
636 tcp_keepalive: self.tcp_keepalive,
637 tcp_keepalive_interval: self.tcp_keepalive_interval,
638 tcp_keepalive_retries: self.tcp_keepalive_retries,
639 tcp_nodelay: self.tcp_nodelay,
640 http2_keepalive_interval: self.http2_keepalive_interval,
641 http2_keepalive_timeout: self.http2_keepalive_timeout,
642 http2_adaptive_window: self.http2_adaptive_window,
643 http2_max_pending_accept_reset_streams: self.http2_max_pending_accept_reset_streams,
644 http2_max_header_list_size: self.http2_max_header_list_size,
645 http2_max_local_error_reset_streams: self.http2_max_local_error_reset_streams,
646 max_frame_size: self.max_frame_size,
647 accept_http1: self.accept_http1,
648 max_connection_age: self.max_connection_age,
649 max_connection_age_grace: self.max_connection_age_grace,
650 }
651 }
652
653 fn bind_incoming(&self, addr: SocketAddr) -> Result<TcpIncoming, super::Error> {
654 Ok(TcpIncoming::bind(addr)
655 .map_err(super::Error::from_source)?
656 .with_nodelay(Some(self.tcp_nodelay))
657 .with_keepalive(self.tcp_keepalive)
658 .with_keepalive_interval(self.tcp_keepalive_interval)
659 .with_keepalive_retries(self.tcp_keepalive_retries))
660 }
661
662 pub async fn serve<S, ResBody>(self, addr: SocketAddr, svc: S) -> Result<(), super::Error>
664 where
665 L: Layer<S>,
666 L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
667 <<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
668 <<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
669 Into<crate::BoxError> + Send + 'static,
670 ResBody: http_body::Body<Data = Bytes> + Send + 'static,
671 ResBody::Error: Into<crate::BoxError>,
672 {
673 let incoming = self.bind_incoming(addr)?;
674 self.serve_with_incoming(svc, incoming).await
675 }
676
677 pub async fn serve_with_shutdown<S, F, ResBody>(
679 self,
680 addr: SocketAddr,
681 svc: S,
682 signal: F,
683 ) -> Result<(), super::Error>
684 where
685 L: Layer<S>,
686 L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
687 <<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
688 <<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
689 Into<crate::BoxError> + Send + 'static,
690 F: Future<Output = ()>,
691 ResBody: http_body::Body<Data = Bytes> + Send + 'static,
692 ResBody::Error: Into<crate::BoxError>,
693 {
694 let incoming = self.bind_incoming(addr)?;
695 self.serve_with_incoming_shutdown(svc, incoming, signal)
696 .await
697 }
698
699 pub async fn serve_with_incoming<S, I, IO, IE, ResBody>(
703 self,
704 svc: S,
705 incoming: I,
706 ) -> Result<(), super::Error>
707 where
708 L: Layer<S>,
709 L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
710 <<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
711 <<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
712 Into<crate::BoxError> + Send + 'static,
713 I: Stream<Item = Result<IO, IE>>,
714 IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
715 IE: Into<crate::BoxError>,
716 ResBody: http_body::Body<Data = Bytes> + Send + 'static,
717 ResBody::Error: Into<crate::BoxError>,
718 {
719 self.serve_internal(svc, incoming, Option::<future::Ready<()>>::None)
720 .await
721 }
722
723 pub async fn serve_with_incoming_shutdown<S, I, F, IO, IE, ResBody>(
725 self,
726 svc: S,
727 incoming: I,
728 signal: F,
729 ) -> Result<(), super::Error>
730 where
731 L: Layer<S>,
732 L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
733 <<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
734 <<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
735 Into<crate::BoxError> + Send + 'static,
736 I: Stream<Item = Result<IO, IE>>,
737 IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
738 IE: Into<crate::BoxError>,
739 F: Future<Output = ()>,
740 ResBody: http_body::Body<Data = Bytes> + Send + 'static,
741 ResBody::Error: Into<crate::BoxError>,
742 {
743 self.serve_internal(svc, incoming, Some(signal)).await
744 }
745
746 async fn serve_internal<S, I, F, IO, IE, ResBody>(
747 self,
748 svc: S,
749 incoming: I,
750 signal: Option<F>,
751 ) -> Result<(), super::Error>
752 where
753 L: Layer<S>,
754 L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
755 <<L as Layer<S>>::Service as Service<Request<Body>>>::Future: Send,
756 <<L as Layer<S>>::Service as Service<Request<Body>>>::Error:
757 Into<crate::BoxError> + Send + 'static,
758 I: Stream<Item = Result<IO, IE>>,
759 IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
760 IE: Into<crate::BoxError>,
761 F: Future<Output = ()>,
762 ResBody: http_body::Body<Data = Bytes> + Send + 'static,
763 ResBody::Error: Into<crate::BoxError>,
764 {
765 let trace_interceptor = self.trace_interceptor.clone();
766 let concurrency_limit = self.concurrency_limit;
767 let load_shed = self.load_shed;
768 let init_connection_window_size = self.init_connection_window_size;
769 let init_stream_window_size = self.init_stream_window_size;
770 let max_concurrent_streams = self.max_concurrent_streams;
771 let timeout = self.timeout;
772 let max_header_list_size = self.http2_max_header_list_size;
773 let max_frame_size = self.max_frame_size;
774 let http2_only = !self.accept_http1;
775
776 let http2_keepalive_interval = self.http2_keepalive_interval;
777 let http2_keepalive_timeout = self.http2_keepalive_timeout;
778 let http2_adaptive_window = self.http2_adaptive_window;
779 let http2_max_pending_accept_reset_streams = self.http2_max_pending_accept_reset_streams;
780 let http2_max_local_error_reset_streams = self.http2_max_local_error_reset_streams;
781 let max_connection_age = self.max_connection_age;
782 let max_connection_age_grace = self.max_connection_age_grace;
783
784 let svc = self.service_builder.service(svc);
785
786 let incoming = io_stream::ServerIoStream::new(
787 incoming,
788 #[cfg(feature = "_tls-any")]
789 self.tls,
790 );
791 let mut svc = MakeSvc {
792 inner: svc,
793 concurrency_limit,
794 load_shed,
795 timeout,
796 trace_interceptor,
797 _io: PhantomData,
798 };
799
800 let server = {
801 let mut builder = ConnectionBuilder::new(TokioExecutor::new());
802
803 if http2_only {
804 builder = builder.http2_only();
805 }
806
807 builder
808 .http2()
809 .timer(TokioTimer::new())
810 .initial_connection_window_size(init_connection_window_size)
811 .initial_stream_window_size(init_stream_window_size)
812 .max_concurrent_streams(max_concurrent_streams)
813 .keep_alive_interval(http2_keepalive_interval)
814 .keep_alive_timeout(http2_keepalive_timeout)
815 .adaptive_window(http2_adaptive_window.unwrap_or_default())
816 .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams)
817 .max_local_error_reset_streams(http2_max_local_error_reset_streams)
818 .max_frame_size(max_frame_size);
819
820 if let Some(max_header_list_size) = max_header_list_size {
821 builder.http2().max_header_list_size(max_header_list_size);
822 }
823
824 builder
825 };
826
827 let (signal_tx, signal_rx) = tokio::sync::watch::channel(());
828 let signal_tx = Arc::new(signal_tx);
829
830 let graceful = signal.is_some();
831 let mut sig = pin!(Fuse { inner: signal });
832 let mut incoming = pin!(incoming);
833
834 loop {
835 tokio::select! {
836 _ = &mut sig => {
837 trace!("signal received, shutting down");
838 break;
839 },
840 io = incoming.next() => {
841 let io = match io {
842 Some(Ok(io)) => io,
843 Some(Err(e)) => {
844 trace!("error accepting connection: {}", DisplayErrorStack(&*e));
845 continue;
846 },
847 None => {
848 break
849 },
850 };
851
852 trace!("connection accepted");
853
854 let req_svc = svc
855 .call(&io)
856 .await
857 .map_err(super::Error::from_source)?;
858
859 let hyper_io = TokioIo::new(io);
860 let hyper_svc = TowerToHyperService::new(req_svc.map_request(|req: Request<Incoming>| req.map(Body::new)));
861
862 serve_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone()), max_connection_age, max_connection_age_grace);
863 }
864 }
865 }
866
867 if graceful {
868 let _ = signal_tx.send(());
869 drop(signal_rx);
870 trace!(
871 "waiting for {} connections to close",
872 signal_tx.receiver_count()
873 );
874
875 signal_tx.closed().await;
877 }
878
879 Ok(())
880 }
881}
882
/// What to do with a connection once its age-based timer fires.
enum TimeoutAction {
    /// Ask the connection to finish in-flight work and then close.
    GracefulShutdown,
    /// Drop the connection immediately.
    ForcefulShutdown,
}
887
888async fn connection_timeout_future(
889 max_connection_age: Option<Duration>,
890 max_connection_age_grace: Option<Duration>,
891) -> TimeoutAction {
892 if let Some(age) = max_connection_age {
893 tokio::time::sleep(age).await;
894
895 if let Some(grace) = max_connection_age_grace {
896 tokio::time::sleep(grace).await;
897 TimeoutAction::ForcefulShutdown
898 } else {
899 TimeoutAction::GracefulShutdown
900 }
901 } else {
902 future::pending().await
903 }
904}
905
906fn serve_connection<B, IO, S, E>(
909 hyper_io: IO,
910 hyper_svc: S,
911 builder: ConnectionBuilder<E>,
912 mut watcher: Option<tokio::sync::watch::Receiver<()>>,
913 max_connection_age: Option<Duration>,
914 max_connection_age_grace: Option<Duration>,
915) where
916 B: http_body::Body + Send + 'static,
917 B::Data: Send,
918 B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
919 IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
920 S: HyperService<Request<Incoming>, Response = Response<B>> + Clone + Send + 'static,
921 S::Future: Send + 'static,
922 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
923 E: HttpServerConnExec<S::Future, B> + Send + Sync + 'static,
924{
925 tokio::spawn(async move {
926 {
927 let mut sig = pin!(Fuse {
928 inner: watcher.as_mut().map(|w| w.changed()),
929 });
930
931 let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc));
932
933 let mut connection_timeout = pin!(connection_timeout_future(
934 max_connection_age,
935 max_connection_age_grace,
936 ));
937
938 loop {
939 tokio::select! {
940 rv = &mut conn => {
941 if let Err(err) = rv {
942 debug!("failed serving connection: {}", DisplayErrorStack(&*err));
943 }
944 break;
945 },
946 timeout_action = &mut connection_timeout => {
947 match timeout_action {
948 TimeoutAction::GracefulShutdown => {
949 conn.as_mut().graceful_shutdown();
950 },
951 TimeoutAction::ForcefulShutdown => {
952 debug!("forcefully closed connection");
953 break;
954 }
955 }
956 },
957 _ = &mut sig => {
958 conn.as_mut().graceful_shutdown();
959 },
960 }
961 }
962 }
963
964 drop(watcher);
965 trace!("connection closed");
966 });
967}
968
#[cfg(feature = "router")]
impl<L> Router<L> {
    /// Pairs a server configuration with the routes it will serve.
    pub(crate) fn new(server: Server<L>, routes: Routes) -> Self {
        Router { server, routes }
    }
}
975
#[cfg(feature = "router")]
impl<L> Router<L> {
    /// Registers another service with this router.
    pub fn add_service<S>(self, svc: S) -> Self
    where
        S: Service<Request<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Response: axum::response::IntoResponse,
        S::Future: Send + 'static,
    {
        let Router { server, routes } = self;
        Router {
            server,
            routes: routes.add_service(svc),
        }
    }

    /// Registers `svc` if it is `Some`; otherwise returns the router unchanged.
    pub fn add_optional_service<S>(self, svc: Option<S>) -> Self
    where
        S: Service<Request<Body>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Response: axum::response::IntoResponse,
        S::Future: Send + 'static,
    {
        match svc {
            Some(svc) => self.add_service(svc),
            None => self,
        }
    }

    /// Binds `addr` and serves the registered routes.
    ///
    /// # Errors
    ///
    /// See [`Server::serve`].
    pub async fn serve<ResBody>(self, addr: SocketAddr) -> Result<(), super::Error>
    where
        L: Layer<Routes> + Clone,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Future: Send,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let routes = self.routes.prepare();
        self.server.serve(addr, routes).await
    }

    /// Binds `addr` and serves the registered routes until `signal` completes.
    ///
    /// # Errors
    ///
    /// See [`Server::serve_with_shutdown`].
    pub async fn serve_with_shutdown<F: Future<Output = ()>, ResBody>(
        self,
        addr: SocketAddr,
        signal: F,
    ) -> Result<(), super::Error>
    where
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Future: Send,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let routes = self.routes.prepare();
        self.server.serve_with_shutdown(addr, routes, signal).await
    }

    /// Serves the registered routes on connections from `incoming`.
    ///
    /// # Errors
    ///
    /// See [`Server::serve_with_incoming`].
    pub async fn serve_with_incoming<I, IO, IE, ResBody>(
        self,
        incoming: I,
    ) -> Result<(), super::Error>
    where
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IE: Into<crate::BoxError>,
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Future: Send,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let routes = self.routes.prepare();
        self.server.serve_with_incoming(routes, incoming).await
    }

    /// Serves the registered routes on connections from `incoming` until
    /// `signal` completes.
    ///
    /// # Errors
    ///
    /// See [`Server::serve_with_incoming_shutdown`].
    pub async fn serve_with_incoming_shutdown<I, IO, IE, F, ResBody>(
        self,
        incoming: I,
        signal: F,
    ) -> Result<(), super::Error>
    where
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IE: Into<crate::BoxError>,
        F: Future<Output = ()>,
        L: Layer<Routes>,
        L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Future: Send,
        <<L as Layer<Routes>>::Service as Service<Request<Body>>>::Error:
            Into<crate::BoxError> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::BoxError>,
    {
        let routes = self.routes.prepare();
        self.server
            .serve_with_incoming_shutdown(routes, incoming, signal)
            .await
    }
}
1118
1119impl<L> fmt::Debug for Server<L> {
1120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1121 f.debug_struct("Builder").finish()
1122 }
1123}
1124
1125#[derive(Clone)]
1126struct Svc<S> {
1127 inner: S,
1128 trace_interceptor: Option<TraceInterceptor>,
1129}
1130
1131impl<S, ResBody> Service<Request<Body>> for Svc<S>
1132where
1133 S: Service<Request<Body>, Response = Response<ResBody>>,
1134 S::Error: Into<crate::BoxError>,
1135 ResBody: http_body::Body<Data = Bytes> + Send + 'static,
1136 ResBody::Error: Into<crate::BoxError>,
1137{
1138 type Response = Response<Body>;
1139 type Error = crate::BoxError;
1140 type Future = SvcFuture<S::Future>;
1141
1142 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1143 self.inner.poll_ready(cx).map_err(Into::into)
1144 }
1145
1146 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
1147 let span = if let Some(trace_interceptor) = &self.trace_interceptor {
1148 let (parts, body) = req.into_parts();
1149 let bodyless_request = Request::from_parts(parts, ());
1150
1151 let span = trace_interceptor(&bodyless_request);
1152
1153 let (parts, _) = bodyless_request.into_parts();
1154 req = Request::from_parts(parts, body);
1155
1156 span
1157 } else {
1158 tracing::Span::none()
1159 };
1160
1161 SvcFuture {
1162 inner: self.inner.call(req),
1163 span,
1164 }
1165 }
1166}
1167
1168#[pin_project]
1169struct SvcFuture<F> {
1170 #[pin]
1171 inner: F,
1172 span: tracing::Span,
1173}
1174
1175impl<F, E, ResBody> Future for SvcFuture<F>
1176where
1177 F: Future<Output = Result<Response<ResBody>, E>>,
1178 E: Into<crate::BoxError>,
1179 ResBody: http_body::Body<Data = Bytes> + Send + 'static,
1180 ResBody::Error: Into<crate::BoxError>,
1181{
1182 type Output = Result<Response<Body>, crate::BoxError>;
1183
1184 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1185 let this = self.project();
1186 let _guard = this.span.enter();
1187
1188 let response: Response<ResBody> = ready!(this.inner.poll(cx)).map_err(Into::into)?;
1189 let response = response.map(|body| Body::new(body.map_err(Into::into)));
1190 Poll::Ready(Ok(response))
1191 }
1192}
1193
1194impl<S> fmt::Debug for Svc<S> {
1195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1196 f.debug_struct("Svc").finish()
1197 }
1198}
1199
1200#[derive(Clone)]
1201struct MakeSvc<S, IO> {
1202 concurrency_limit: Option<usize>,
1203 load_shed: bool,
1204 timeout: Option<Duration>,
1205 inner: S,
1206 trace_interceptor: Option<TraceInterceptor>,
1207 _io: PhantomData<fn() -> IO>,
1208}
1209
1210impl<S, ResBody, IO> Service<&ServerIo<IO>> for MakeSvc<S, IO>
1211where
1212 IO: Connected + 'static,
1213 S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
1214 S::Future: Send,
1215 S::Error: Into<crate::BoxError> + Send,
1216 ResBody: http_body::Body<Data = Bytes> + Send + 'static,
1217 ResBody::Error: Into<crate::BoxError>,
1218{
1219 type Response = BoxService;
1220 type Error = crate::BoxError;
1221 type Future = future::Ready<Result<Self::Response, Self::Error>>;
1222
1223 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1224 Ok(()).into()
1225 }
1226
1227 fn call(&mut self, io: &ServerIo<IO>) -> Self::Future {
1228 let conn_info = io.connect_info();
1229
1230 let svc = self.inner.clone();
1231 let concurrency_limit = self.concurrency_limit;
1232 let timeout = self.timeout;
1233 let trace_interceptor = self.trace_interceptor.clone();
1234
1235 let svc = ServiceBuilder::new()
1236 .layer(RecoverErrorLayer::new())
1237 .option_layer(self.load_shed.then_some(LoadShedLayer::new()))
1238 .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
1239 .layer_fn(|s| GrpcTimeout::new(s, timeout))
1240 .service(svc);
1241
1242 let svc = ServiceBuilder::new()
1243 .layer(BoxCloneService::layer())
1244 .layer(ConnectInfoLayer::new(conn_info.clone()))
1245 .service(Svc {
1246 inner: svc,
1247 trace_interceptor,
1248 });
1249
1250 future::ready(Ok(svc))
1251 }
1252}
1253
1254#[pin_project]
1258struct Fuse<F> {
1259 #[pin]
1260 inner: Option<F>,
1261}
1262
1263impl<F> Future for Fuse<F>
1264where
1265 F: Future,
1266{
1267 type Output = F::Output;
1268
1269 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1270 match self.as_mut().project().inner.as_pin_mut() {
1271 Some(fut) => fut.poll(cx).map(|output| {
1272 self.project().inner.set(None);
1273 output
1274 }),
1275 None => Poll::Pending,
1276 }
1277 }
1278}
1279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::Server;
    use std::time::Duration;

    #[tokio::test(start_paused = true)]
    async fn test_connection_timeout_no_max_age() {
        // Time is paused and auto-advances, so 1000 s elapse instantly; the
        // timer must still be pending.
        let outcome = tokio::time::timeout(
            Duration::from_secs(1000),
            connection_timeout_future(None, None),
        )
        .await;
        assert!(
            outcome.is_err(),
            "timeout future should never complete when max_connection_age is None"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_connection_timeout_with_max_connection_age() {
        let action = connection_timeout_future(Some(Duration::from_secs(10)), None).await;
        assert!(matches!(action, TimeoutAction::GracefulShutdown));
    }

    #[tokio::test(start_paused = true)]
    async fn test_connection_timeout_with_max_connection_age_grace() {
        let mut timer = pin!(connection_timeout_future(
            Some(Duration::from_secs(10)),
            Some(Duration::from_secs(5)),
        ));

        // t = 9 s: still inside max_connection_age.
        assert!(
            tokio::time::timeout(Duration::from_secs(9), &mut timer)
                .await
                .is_err(),
            "should not complete before max_connection_age"
        );

        // t = 13 s: past the age but still inside the grace period.
        assert!(
            tokio::time::timeout(Duration::from_secs(4), &mut timer)
                .await
                .is_err(),
            "should not complete before max_connection_age_grace"
        );

        assert!(matches!(timer.await, TimeoutAction::ForcefulShutdown));
    }

    #[test]
    fn server_tcp_defaults() {
        const EXAMPLE_TCP_KEEPALIVE: Duration = Duration::from_secs(10);
        const EXAMPLE_TCP_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(5);
        const EXAMPLE_TCP_KEEPALIVE_RETRIES: u32 = 3;

        // `builder()` and `default()` must agree on the TCP defaults.
        for fresh in [Server::builder(), Server::default()] {
            assert!(fresh.tcp_nodelay);
            assert_eq!(fresh.tcp_keepalive, None);
            assert_eq!(fresh.tcp_keepalive_interval, None);
            assert_eq!(fresh.tcp_keepalive_retries, None);
        }

        let tuned = Server::builder()
            .tcp_nodelay(false)
            .tcp_keepalive(Some(EXAMPLE_TCP_KEEPALIVE))
            .tcp_keepalive_interval(Some(EXAMPLE_TCP_KEEPALIVE_INTERVAL))
            .tcp_keepalive_retries(Some(EXAMPLE_TCP_KEEPALIVE_RETRIES));
        assert!(!tuned.tcp_nodelay);
        assert_eq!(tuned.tcp_keepalive, Some(EXAMPLE_TCP_KEEPALIVE));
        assert_eq!(
            tuned.tcp_keepalive_interval,
            Some(EXAMPLE_TCP_KEEPALIVE_INTERVAL)
        );
        assert_eq!(
            tuned.tcp_keepalive_retries,
            Some(EXAMPLE_TCP_KEEPALIVE_RETRIES)
        );
    }
}