diff --git a/src/server/conn/auto/mod.rs b/src/server/conn/auto/mod.rs
index bdadf4b1..867927dc 100644
--- a/src/server/conn/auto/mod.rs
+++ b/src/server/conn/auto/mod.rs
@@ -7,15 +7,16 @@ use std::future::Future;
use std::marker::PhantomPinned;
use std::mem::MaybeUninit;
use std::pin::Pin;
+use std::sync::Arc;
use std::task::{Context, Poll, ready};
-use std::{error::Error as StdError, io, time::Duration};
+use std::{error::Error as StdError, fmt, io, time::Duration};
use bytes::Bytes;
use http::{Request, Response};
use http_body::Body;
use hyper::{
body::Incoming,
- rt::{Read, ReadBuf, Timer, Write},
+ rt::{Read, ReadBuf, Sleep, Timer, Write},
service::Service,
};
@@ -52,6 +53,19 @@ pub trait HttpServerConnExec {}
#[cfg(not(feature = "http2"))]
impl HttpServerConnExec for T {}
+/// A type-erased, shareable [`Timer`] used to bound the protocol-detection read.
+///
+/// Wrapping the `Arc` in a newtype lets [`Builder`] keep its `Clone`/`Debug`
+/// derives even though `dyn Timer` is neither.
+#[derive(Clone)]
+struct ReadVersionTimer(Arc);
+
+impl fmt::Debug for ReadVersionTimer {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ReadVersionTimer").finish_non_exhaustive()
+ }
+}
+
/// Http1 or Http2 connection builder.
#[derive(Clone, Debug)]
pub struct Builder {
@@ -61,6 +75,13 @@ pub struct Builder {
http2: http2::Builder,
#[cfg(any(feature = "http1", feature = "http2"))]
version: Option,
+ /// Timer used to drive [`Self::read_version_timeout`]; `None` disables it.
+ #[cfg(any(feature = "http1", feature = "http2"))]
+ read_version_timer: Option,
+ /// How long the protocol-detection read may take before the connection is
+ /// closed. Only enforced when `read_version_timer` is also set.
+ #[cfg(any(feature = "http1", feature = "http2"))]
+ read_version_timeout: Option,
#[cfg(not(feature = "http2"))]
_executor: E,
}
@@ -98,11 +119,39 @@ impl Builder {
http2: http2::Builder::new(executor),
#[cfg(any(feature = "http1", feature = "http2"))]
version: None,
+ #[cfg(any(feature = "http1", feature = "http2"))]
+ read_version_timer: None,
+ #[cfg(any(feature = "http1", feature = "http2"))]
+ read_version_timeout: None,
#[cfg(not(feature = "http2"))]
_executor: executor,
}
}
+ /// Set a timeout for the initial protocol-detection read.
+ ///
+ /// `serve_connection`/`serve_connection_with_upgrades` first read a few
+ /// bytes to decide whether the connection is HTTP/1 or HTTP/2. Until that
+ /// read completes the per-protocol settings — including
+ /// [`Http1Builder::header_read_timeout`] — are not yet active, so a peer
+ /// that connects but never sends data would otherwise keep the connection
+ /// open indefinitely.
+ ///
+ /// With this set, the connection is closed if the protocol is not detected
+ /// within `duration`. Has no effect when `http1_only`/`http2_only` is used,
+ /// since no detection read happens then.
+ ///
+ /// [`Http1Builder::header_read_timeout`]: Http1Builder::header_read_timeout
+ #[cfg(any(feature = "http1", feature = "http2"))]
+ pub fn read_version_timeout(&mut self, timer: M, duration: Duration) -> &mut Self
+ where
+ M: Timer + Send + Sync + 'static,
+ {
+ self.read_version_timer = Some(ReadVersionTimer(Arc::new(timer)));
+ self.read_version_timeout = Some(duration);
+ self
+ }
+
/// Http1 configuration.
#[cfg(feature = "http1")]
pub fn http1(&mut self) -> Http1Builder<'_, E> {
@@ -244,7 +293,11 @@ impl Builder {
}
#[cfg(any(feature = "http1", feature = "http2"))]
_ => ConnState::ReadVersion {
- read_version: read_version(io),
+ read_version: read_version(
+ io,
+ self.read_version_timer.clone(),
+ self.read_version_timeout,
+ ),
builder: Cow::Borrowed(self),
service: Some(service),
},
@@ -278,7 +331,11 @@ impl Builder {
{
UpgradeableConnection {
state: UpgradeableConnState::ReadVersion {
- read_version: read_version(io),
+ read_version: read_version(
+ io,
+ self.read_version_timer.clone(),
+ self.read_version_timeout,
+ ),
builder: Cow::Borrowed(self),
service: Some(service),
},
@@ -303,15 +360,31 @@ impl Version {
}
}
-fn read_version(io: I) -> ReadVersion
+#[cfg_attr(
+ not(any(feature = "http1", feature = "http2")),
+ expect(dead_code, reason = "only constructed when a protocol feature is on")
+)]
+fn read_version(
+ io: I,
+ timer: Option,
+ timeout: Option,
+) -> ReadVersion
where
I: Read + Unpin,
{
+ // Arm the detection-read timeout eagerly so it counts from the moment we
+ // start serving, not from the first byte received.
+ let timeout = match (timer, timeout) {
+ (Some(timer), Some(duration)) => Some(timer.0.sleep(duration)),
+ _ => None,
+ };
+
ReadVersion {
io: Some(io),
buf: [MaybeUninit::uninit(); 24],
filled: 0,
version: Version::H2,
+ timeout,
cancelled: false,
_pin: PhantomPinned,
}
@@ -324,6 +397,9 @@ pin_project! {
// the amount of `buf` thats been filled
filled: usize,
version: Version,
+ // Optional deadline for the protocol-detection read. A `Pin>`
+ // is already heap-pinned, so it needs no `#[pin]` projection.
+ timeout: Option>>,
cancelled: bool,
// Make this future `!Unpin` for compatibility with async trait methods.
#[pin]
@@ -349,6 +425,17 @@ where
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "Cancelled")));
}
+ // Close the connection if the protocol isn't detected in time. Polling
+ // here registers the timer's waker, so an idle peer is woken on expiry.
+ if let Some(timeout) = this.timeout.as_mut() {
+ if timeout.as_mut().poll(cx).is_ready() {
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::TimedOut,
+ "timed out detecting connection protocol",
+ )));
+ }
+ }
+
let mut buf = ReadBuf::uninit(&mut *this.buf);
// SAFETY: `this.filled` tracks how many bytes have been read (and thus initialized) and
// we're only advancing by that many.
@@ -1315,6 +1402,41 @@ mod tests {
assert_eq!(connection_error.kind(), std::io::ErrorKind::Interrupted);
}
+ // A peer that connects but never sends the protocol-detection bytes must be
+ // dropped once `read_version_timeout` elapses, rather than held open
+ // indefinitely (the gap behind hyperium/hyper#3756).
+ #[cfg(not(miri))]
+ #[tokio::test]
+ async fn read_version_timeout_closes_idle_connection() {
+ use crate::rt::TokioTimer;
+
+ let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
+ .await
+ .unwrap();
+ let listener_addr = listener.local_addr().unwrap();
+
+ let listen_task = tokio::spawn(async move { listener.accept().await.unwrap() });
+ // Connect but never send any bytes, so protocol detection cannot complete.
+ let _stream = TcpStream::connect(listener_addr).await.unwrap();
+
+ let (stream, _) = listen_task.await.unwrap();
+ let stream = TokioIo::new(stream);
+
+ let mut builder = auto::Builder::new(TokioExecutor::new());
+ builder.read_version_timeout(TokioTimer::new(), Duration::from_millis(50));
+ let connection = builder.serve_connection_with_upgrades(stream, service_fn(hello));
+
+ let connection_error = tokio::time::timeout(Duration::from_secs(1), connection)
+ .await
+ .expect("connection should close promptly once the read-version timeout elapses")
+ .expect_err("connection should error out on the read-version timeout");
+
+ let io_err = connection_error
+ .downcast_ref::()
+ .expect("The error should have been `std::io::Error`.");
+ assert_eq!(io_err.kind(), std::io::ErrorKind::TimedOut);
+ }
+
async fn connect_h1(addr: SocketAddr) -> client::conn::http1::SendRequest
where
B: Body + Send + 'static,