Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 127 additions & 5 deletions src/server/conn/auto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ use std::future::Future;
use std::marker::PhantomPinned;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use std::{error::Error as StdError, io, time::Duration};
use std::{error::Error as StdError, fmt, io, time::Duration};

use bytes::Bytes;
use http::{Request, Response};
use http_body::Body;
use hyper::{
body::Incoming,
rt::{Read, ReadBuf, Timer, Write},
rt::{Read, ReadBuf, Sleep, Timer, Write},
service::Service,
};

Expand Down Expand Up @@ -52,6 +53,19 @@ pub trait HttpServerConnExec<A, B: Body> {}
#[cfg(not(feature = "http2"))]
impl<A, B: Body, T> HttpServerConnExec<A, B> for T {}

/// A type-erased, shareable [`Timer`] used to bound the protocol-detection read.
///
/// Wrapping the `Arc` in a newtype lets [`Builder`] keep its `Clone`/`Debug`
/// derives even though `dyn Timer` is neither.
#[derive(Clone)]
struct ReadVersionTimer(Arc<dyn Timer + Send + Sync>);

impl fmt::Debug for ReadVersionTimer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ReadVersionTimer").finish_non_exhaustive()
}
}

/// Http1 or Http2 connection builder.
#[derive(Clone, Debug)]
pub struct Builder<E> {
Expand All @@ -61,6 +75,13 @@ pub struct Builder<E> {
http2: http2::Builder<E>,
#[cfg(any(feature = "http1", feature = "http2"))]
version: Option<Version>,
/// Timer used to drive [`Self::read_version_timeout`]; `None` disables it.
#[cfg(any(feature = "http1", feature = "http2"))]
read_version_timer: Option<ReadVersionTimer>,
/// How long the protocol-detection read may take before the connection is
/// closed. Only enforced when `read_version_timer` is also set.
#[cfg(any(feature = "http1", feature = "http2"))]
read_version_timeout: Option<Duration>,
#[cfg(not(feature = "http2"))]
_executor: E,
}
Expand Down Expand Up @@ -98,11 +119,39 @@ impl<E> Builder<E> {
http2: http2::Builder::new(executor),
#[cfg(any(feature = "http1", feature = "http2"))]
version: None,
#[cfg(any(feature = "http1", feature = "http2"))]
read_version_timer: None,
#[cfg(any(feature = "http1", feature = "http2"))]
read_version_timeout: None,
#[cfg(not(feature = "http2"))]
_executor: executor,
}
}

/// Set a timeout for the initial protocol-detection read.
///
/// `serve_connection`/`serve_connection_with_upgrades` first read a few
/// bytes to decide whether the connection is HTTP/1 or HTTP/2. Until that
/// read completes the per-protocol settings — including
/// [`Http1Builder::header_read_timeout`] — are not yet active, so a peer
/// that connects but never sends data would otherwise keep the connection
/// open indefinitely.
///
/// With this set, the connection is closed if the protocol is not detected
/// within `duration`. Has no effect when `http1_only`/`http2_only` is used,
/// since no detection read happens then.
///
/// [`Http1Builder::header_read_timeout`]: Http1Builder::header_read_timeout
#[cfg(any(feature = "http1", feature = "http2"))]
pub fn read_version_timeout<M>(&mut self, timer: M, duration: Duration) -> &mut Self
where
M: Timer + Send + Sync + 'static,
{
self.read_version_timer = Some(ReadVersionTimer(Arc::new(timer)));
self.read_version_timeout = Some(duration);
self
}

/// Http1 configuration.
#[cfg(feature = "http1")]
pub fn http1(&mut self) -> Http1Builder<'_, E> {
Expand Down Expand Up @@ -244,7 +293,11 @@ impl<E> Builder<E> {
}
#[cfg(any(feature = "http1", feature = "http2"))]
_ => ConnState::ReadVersion {
read_version: read_version(io),
read_version: read_version(
io,
self.read_version_timer.clone(),
self.read_version_timeout,
),
builder: Cow::Borrowed(self),
service: Some(service),
},
Expand Down Expand Up @@ -278,7 +331,11 @@ impl<E> Builder<E> {
{
UpgradeableConnection {
state: UpgradeableConnState::ReadVersion {
read_version: read_version(io),
read_version: read_version(
io,
self.read_version_timer.clone(),
self.read_version_timeout,
),
builder: Cow::Borrowed(self),
service: Some(service),
},
Expand All @@ -303,15 +360,31 @@ impl Version {
}
}

fn read_version<I>(io: I) -> ReadVersion<I>
#[cfg_attr(
not(any(feature = "http1", feature = "http2")),
expect(dead_code, reason = "only constructed when a protocol feature is on")
)]
fn read_version<I>(
io: I,
timer: Option<ReadVersionTimer>,
timeout: Option<Duration>,
) -> ReadVersion<I>
where
I: Read + Unpin,
{
// Arm the detection-read timeout eagerly so it counts from the moment we
// start serving, not from the first byte received.
let timeout = match (timer, timeout) {
(Some(timer), Some(duration)) => Some(timer.0.sleep(duration)),
_ => None,
};

ReadVersion {
io: Some(io),
buf: [MaybeUninit::uninit(); 24],
filled: 0,
version: Version::H2,
timeout,
cancelled: false,
_pin: PhantomPinned,
}
Expand All @@ -324,6 +397,9 @@ pin_project! {
// the amount of `buf` thats been filled
filled: usize,
version: Version,
// Optional deadline for the protocol-detection read. A `Pin<Box<dyn Sleep>>`
// is already heap-pinned, so it needs no `#[pin]` projection.
timeout: Option<Pin<Box<dyn Sleep>>>,
cancelled: bool,
// Make this future `!Unpin` for compatibility with async trait methods.
#[pin]
Expand All @@ -349,6 +425,17 @@ where
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "Cancelled")));
}

// Close the connection if the protocol isn't detected in time. Polling
// here registers the timer's waker, so an idle peer is woken on expiry.
if let Some(timeout) = this.timeout.as_mut() {
if timeout.as_mut().poll(cx).is_ready() {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::TimedOut,
"timed out detecting connection protocol",
)));
}
}

let mut buf = ReadBuf::uninit(&mut *this.buf);
// SAFETY: `this.filled` tracks how many bytes have been read (and thus initialized) and
// we're only advancing by that many.
Expand Down Expand Up @@ -1315,6 +1402,41 @@ mod tests {
assert_eq!(connection_error.kind(), std::io::ErrorKind::Interrupted);
}

// A peer that connects but never sends the protocol-detection bytes must be
// dropped once `read_version_timeout` elapses, rather than held open
// indefinitely (the gap behind hyperium/hyper#3756).
#[cfg(not(miri))]
#[tokio::test]
async fn read_version_timeout_closes_idle_connection() {
use crate::rt::TokioTimer;

let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
.await
.unwrap();
let listener_addr = listener.local_addr().unwrap();

let listen_task = tokio::spawn(async move { listener.accept().await.unwrap() });
// Connect but never send any bytes, so protocol detection cannot complete.
let _stream = TcpStream::connect(listener_addr).await.unwrap();

let (stream, _) = listen_task.await.unwrap();
let stream = TokioIo::new(stream);

let mut builder = auto::Builder::new(TokioExecutor::new());
builder.read_version_timeout(TokioTimer::new(), Duration::from_millis(50));
let connection = builder.serve_connection_with_upgrades(stream, service_fn(hello));

let connection_error = tokio::time::timeout(Duration::from_secs(1), connection)
.await
.expect("connection should close promptly once the read-version timeout elapses")
.expect_err("connection should error out on the read-version timeout");

let io_err = connection_error
.downcast_ref::<std::io::Error>()
.expect("The error should have been `std::io::Error`.");
assert_eq!(io_err.kind(), std::io::ErrorKind::TimedOut);
}

async fn connect_h1<B>(addr: SocketAddr) -> client::conn::http1::SendRequest<B>
where
B: Body + Send + 'static,
Expand Down