diff --git a/src/server/conn/auto/mod.rs b/src/server/conn/auto/mod.rs index bdadf4b1..867927dc 100644 --- a/src/server/conn/auto/mod.rs +++ b/src/server/conn/auto/mod.rs @@ -7,15 +7,16 @@ use std::future::Future; use std::marker::PhantomPinned; use std::mem::MaybeUninit; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll, ready}; -use std::{error::Error as StdError, io, time::Duration}; +use std::{error::Error as StdError, fmt, io, time::Duration}; use bytes::Bytes; use http::{Request, Response}; use http_body::Body; use hyper::{ body::Incoming, - rt::{Read, ReadBuf, Timer, Write}, + rt::{Read, ReadBuf, Sleep, Timer, Write}, service::Service, }; @@ -52,6 +53,19 @@ pub trait HttpServerConnExec {} #[cfg(not(feature = "http2"))] impl HttpServerConnExec for T {} +/// A type-erased, shareable [`Timer`] used to bound the protocol-detection read. +/// +/// Wrapping the `Arc` in a newtype lets [`Builder`] keep its `Clone`/`Debug` +/// derives even though `dyn Timer` is neither. +#[derive(Clone)] +struct ReadVersionTimer(Arc); + +impl fmt::Debug for ReadVersionTimer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadVersionTimer").finish_non_exhaustive() + } +} + /// Http1 or Http2 connection builder. #[derive(Clone, Debug)] pub struct Builder { @@ -61,6 +75,13 @@ pub struct Builder { http2: http2::Builder, #[cfg(any(feature = "http1", feature = "http2"))] version: Option, + /// Timer used to drive [`Self::read_version_timeout`]; `None` disables it. + #[cfg(any(feature = "http1", feature = "http2"))] + read_version_timer: Option, + /// How long the protocol-detection read may take before the connection is + /// closed. Only enforced when `read_version_timer` is also set. + #[cfg(any(feature = "http1", feature = "http2"))] + read_version_timeout: Option, #[cfg(not(feature = "http2"))] _executor: E, } @@ -98,11 +119,39 @@ impl Builder { http2: http2::Builder::new(executor), #[cfg(any(feature = "http1", feature = "http2"))] version: None, + #[cfg(any(feature = "http1", feature = "http2"))] + read_version_timer: None, + #[cfg(any(feature = "http1", feature = "http2"))] + read_version_timeout: None, #[cfg(not(feature = "http2"))] _executor: executor, } } + /// Set a timeout for the initial protocol-detection read. + /// + /// `serve_connection`/`serve_connection_with_upgrades` first read a few + /// bytes to decide whether the connection is HTTP/1 or HTTP/2. Until that + /// read completes the per-protocol settings — including + /// [`Http1Builder::header_read_timeout`] — are not yet active, so a peer + /// that connects but never sends data would otherwise keep the connection + /// open indefinitely. + /// + /// With this set, the connection is closed if the protocol is not detected + /// within `duration`. Has no effect when `http1_only`/`http2_only` is used, + /// since no detection read happens then. + /// + /// [`Http1Builder::header_read_timeout`]: Http1Builder::header_read_timeout + #[cfg(any(feature = "http1", feature = "http2"))] + pub fn read_version_timeout(&mut self, timer: M, duration: Duration) -> &mut Self + where + M: Timer + Send + Sync + 'static, + { + self.read_version_timer = Some(ReadVersionTimer(Arc::new(timer))); + self.read_version_timeout = Some(duration); + self + } + /// Http1 configuration. #[cfg(feature = "http1")] pub fn http1(&mut self) -> Http1Builder<'_, E> { @@ -244,7 +293,11 @@ impl Builder { } #[cfg(any(feature = "http1", feature = "http2"))] _ => ConnState::ReadVersion { - read_version: read_version(io), + read_version: read_version( + io, + self.read_version_timer.clone(), + self.read_version_timeout, + ), builder: Cow::Borrowed(self), service: Some(service), }, @@ -278,7 +331,11 @@ impl Builder { { UpgradeableConnection { state: UpgradeableConnState::ReadVersion { - read_version: read_version(io), + read_version: read_version( + io, + self.read_version_timer.clone(), + self.read_version_timeout, + ), builder: Cow::Borrowed(self), service: Some(service), }, @@ -303,15 +360,31 @@ impl Version { } } -fn read_version(io: I) -> ReadVersion +#[cfg_attr( + not(any(feature = "http1", feature = "http2")), + expect(dead_code, reason = "only constructed when a protocol feature is on") +)] +fn read_version( + io: I, + timer: Option, + timeout: Option, +) -> ReadVersion where I: Read + Unpin, { + // Arm the detection-read timeout eagerly so it counts from the moment we + // start serving, not from the first byte received. + let timeout = match (timer, timeout) { + (Some(timer), Some(duration)) => Some(timer.0.sleep(duration)), + _ => None, + }; + ReadVersion { io: Some(io), buf: [MaybeUninit::uninit(); 24], filled: 0, version: Version::H2, + timeout, cancelled: false, _pin: PhantomPinned, } @@ -324,6 +397,9 @@ pin_project! { // the amount of `buf` thats been filled filled: usize, version: Version, + // Optional deadline for the protocol-detection read. A `Pin>` + // is already heap-pinned, so it needs no `#[pin]` projection. + timeout: Option>>, cancelled: bool, // Make this future `!Unpin` for compatibility with async trait methods. #[pin] @@ -349,6 +425,17 @@ where return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "Cancelled"))); } + // Close the connection if the protocol isn't detected in time. Polling + // here registers the timer's waker, so an idle peer is woken on expiry. + if let Some(timeout) = this.timeout.as_mut() { + if timeout.as_mut().poll(cx).is_ready() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::TimedOut, + "timed out detecting connection protocol", + ))); + } + } + let mut buf = ReadBuf::uninit(&mut *this.buf); // SAFETY: `this.filled` tracks how many bytes have been read (and thus initialized) and // we're only advancing by that many. @@ -1315,6 +1402,41 @@ mod tests { assert_eq!(connection_error.kind(), std::io::ErrorKind::Interrupted); } + // A peer that connects but never sends the protocol-detection bytes must be + // dropped once `read_version_timeout` elapses, rather than held open + // indefinitely (the gap behind hyperium/hyper#3756). + #[cfg(not(miri))] + #[tokio::test] + async fn read_version_timeout_closes_idle_connection() { + use crate::rt::TokioTimer; + + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let listener_addr = listener.local_addr().unwrap(); + + let listen_task = tokio::spawn(async move { listener.accept().await.unwrap() }); + // Connect but never send any bytes, so protocol detection cannot complete. + let _stream = TcpStream::connect(listener_addr).await.unwrap(); + + let (stream, _) = listen_task.await.unwrap(); + let stream = TokioIo::new(stream); + + let mut builder = auto::Builder::new(TokioExecutor::new()); + builder.read_version_timeout(TokioTimer::new(), Duration::from_millis(50)); + let connection = builder.serve_connection_with_upgrades(stream, service_fn(hello)); + + let connection_error = tokio::time::timeout(Duration::from_secs(1), connection) + .await + .expect("connection should close promptly once the read-version timeout elapses") + .expect_err("connection should error out on the read-version timeout"); + + let io_err = connection_error + .downcast_ref::() + .expect("The error should have been `std::io::Error`."); + assert_eq!(io_err.kind(), std::io::ErrorKind::TimedOut); + } + async fn connect_h1(addr: SocketAddr) -> client::conn::http1::SendRequest where B: Body + Send + 'static,