Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 56 additions & 21 deletions src/windows/common/HandleIO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "HandleIO.h"
#pragma hdrstop

using wsl::windows::common::io::AcceptHandle;
using wsl::windows::common::io::BufferWrapper;
using wsl::windows::common::io::DockerIORelayHandle;
using wsl::windows::common::io::EventHandle;
Expand All @@ -16,7 +17,6 @@ using wsl::windows::common::io::OverlappedIOHandle;
using wsl::windows::common::io::ReadHandle;
using wsl::windows::common::io::ReadNamedPipe;
using wsl::windows::common::io::ReadSocketMessageHandle;
using wsl::windows::common::io::SingleAcceptHandle;
using wsl::windows::common::io::WriteHandle;
using wsl::windows::common::io::WriteNamedPipe;

Expand Down Expand Up @@ -360,41 +360,77 @@ void ReadNamedPipe::Collect()
ReadHandle::Collect();
}

// SingleAcceptHandle
// AcceptHandle

SingleAcceptHandle::SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function<void()>&& OnAccepted) :
ListenSocket(std::move(ListenSocket)), AcceptedSocket(std::move(AcceptedSocket)), OnAccepted(std::move(OnAccepted))
AcceptHandle::AcceptHandle(HandleWrapper&& ListenSocket, bool AcceptOnce, std::function<void(wil::unique_socket&&)>&& OnAccepted) :
ListenSocket(std::move(ListenSocket)), AcceptOnce(AcceptOnce), OnAccepted(std::move(OnAccepted))
{
Overlapped.hEvent = Event.get();

// Query the listen socket so accepted sockets can be created with a matching address family, type, and protocol.
WSAPROTOCOL_INFOW protocolInfo{};
int length = sizeof(protocolInfo);
THROW_LAST_ERROR_IF(
getsockopt(reinterpret_cast<SOCKET>(this->ListenSocket.Get()), SOL_SOCKET, SO_PROTOCOL_INFOW, reinterpret_cast<char*>(&protocolInfo), &length) ==
SOCKET_ERROR);

AddressFamily = protocolInfo.iAddressFamily;
SocketType = protocolInfo.iSocketType;
Protocol = protocolInfo.iProtocol;
}

SingleAcceptHandle::~SingleAcceptHandle()
AcceptHandle::~AcceptHandle()
{
if (State == IOHandleStatus::Pending)
{
LOG_IF_WIN32_BOOL_FALSE(CancelIoEx(ListenSocket.Get(), &Overlapped));
CancelPendingIo(reinterpret_cast<SOCKET>(ListenSocket.Get()), Overlapped);
}
}

DWORD bytesProcessed{};
DWORD flagsReturned{};
if (!WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesProcessed, TRUE, &flagsReturned))
{
auto error = GetLastError();
LOG_LAST_ERROR_IF(error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED);
}
void AcceptHandle::CreateAcceptSocket()
{
AcceptedSocket.reset(WSASocketW(AddressFamily, SocketType, Protocol, nullptr, 0, WSA_FLAG_OVERLAPPED));
THROW_LAST_ERROR_IF(!AcceptedSocket);

if (AddressFamily == AF_HYPERV)
{
ULONG enable = 1;
THROW_LAST_ERROR_IF(
setsockopt(AcceptedSocket.get(), HV_PROTOCOL_RAW, HVSOCKET_CONNECTED_SUSPEND, reinterpret_cast<char*>(&enable), sizeof(enable)) ==
SOCKET_ERROR);
}
}

void AcceptHandle::OnComplete()
{
wsl::windows::common::socket::SetAcceptContext(AcceptedSocket.get(), reinterpret_cast<SOCKET>(ListenSocket.Get()));

OnAccepted(std::move(AcceptedSocket));

if (AcceptOnce)
{
State = IOHandleStatus::Completed;
}
else
{
State = IOHandleStatus::Standby;
}
}

void SingleAcceptHandle::Schedule()
void AcceptHandle::Schedule()
{
WI_ASSERT(State == IOHandleStatus::Standby);

CreateAcceptSocket();

Event.ResetEvent();

// Schedule the accept.
DWORD bytesReturned{};
if (AcceptEx((SOCKET)ListenSocket.Get(), (SOCKET)AcceptedSocket.Get(), &AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped))
if (AcceptEx((SOCKET)ListenSocket.Get(), AcceptedSocket.get(), &AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped))
{
// Accept completed immediately.
State = IOHandleStatus::Completed;
OnAccepted();
OnComplete();
}
else
{
Expand All @@ -405,7 +441,7 @@ void SingleAcceptHandle::Schedule()
}
}

void SingleAcceptHandle::Collect()
void AcceptHandle::Collect()
{
WI_ASSERT(State == IOHandleStatus::Pending);

Expand All @@ -414,11 +450,10 @@ void SingleAcceptHandle::Collect()

THROW_IF_WIN32_BOOL_FALSE(WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesReceived, false, &flagsReturned));

State = IOHandleStatus::Completed;
OnAccepted();
OnComplete();
}

HANDLE SingleAcceptHandle::GetHandle() const
HANDLE AcceptHandle::GetHandle() const
{
return Event.get();
}
Expand Down
21 changes: 14 additions & 7 deletions src/windows/common/HandleIO.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,25 +139,32 @@ class ReadNamedPipe : public ReadHandle
bool m_connected = false;
};

class SingleAcceptHandle : public OverlappedIOHandle
class AcceptHandle : public OverlappedIOHandle
{
public:
NON_COPYABLE(SingleAcceptHandle)
NON_MOVABLE(SingleAcceptHandle)
NON_COPYABLE(AcceptHandle)
NON_MOVABLE(AcceptHandle)

SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function<void()>&& OnAccepted);
~SingleAcceptHandle();
AcceptHandle(HandleWrapper&& ListenSocket, bool AcceptOnce, std::function<void(wil::unique_socket&&)>&& OnAccepted);
~AcceptHandle();

void Schedule() override;
void Collect() override;
HANDLE GetHandle() const override;

private:
void CreateAcceptSocket();
void OnComplete();

HandleWrapper ListenSocket;
HandleWrapper AcceptedSocket;
wil::unique_socket AcceptedSocket;
int AddressFamily{};
int SocketType{};
int Protocol{};
bool AcceptOnce{};
wil::unique_event Event{wil::EventOptions::ManualReset};
OVERLAPPED Overlapped{};
std::function<void()> OnAccepted;
std::function<void(wil::unique_socket&&)> OnAccepted;
char AcceptBuffer[2 * sizeof(SOCKADDR_STORAGE)];
};

Expand Down
12 changes: 0 additions & 12 deletions src/windows/common/hvsocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,6 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address)
}
} // namespace

std::optional<wil::unique_socket> wsl::windows::common::hvsocket::CancellableAccept(
_In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
{
wil::unique_socket Socket = Create();
if (!socket::CancellableAccept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location))
{
return {};
}

return Socket;
}

wil::unique_socket wsl::windows::common::hvsocket::Connect(
_In_ const GUID& VmId, _In_ unsigned long Port, _In_opt_ HANDLE ExitHandle, _In_opt_ ULONG Timeout, _In_ const std::source_location& Location)
{
Expand Down
6 changes: 0 additions & 6 deletions src/windows/common/hvsocket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ Module Name:

namespace wsl::windows::common::hvsocket {

std::optional<wil::unique_socket> CancellableAccept(
_In_ SOCKET ListenSocket,
_In_ DWORD Timeout,
_In_opt_ HANDLE ExitHandle = nullptr,
const std::source_location& Location = std::source_location::current());

wil::unique_socket Connect(
_In_ const GUID& VmId,
_In_ unsigned long Port,
Expand Down
31 changes: 21 additions & 10 deletions src/windows/common/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,41 @@ void wsl::windows::common::socket::SetAcceptContext(_In_ SOCKET AcceptedSocket,
std::format("{}", Location).c_str());
}

bool wsl::windows::common::socket::CancellableAccept(
_In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
std::optional<wil::unique_socket> wsl::windows::common::socket::CancellableAccept(
_In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
{
io::MultiHandleWait io;

bool accepted = false;
std::optional<wil::unique_socket> accepted;

io.AddHandle(std::make_unique<io::SingleAcceptHandle>(ListenSocket, Socket, [&]() { accepted = true; }), io::MultiHandleWait::CancelOnCompleted);
io.AddHandle(
std::make_unique<io::AcceptHandle>(
ListenSocket, true, [&accepted](wil::unique_socket&& socket) { accepted = std::move(socket); }),
io::MultiHandleWait::CancelOnCompleted);

Comment on lines +29 to 40
if (ExitHandle != nullptr)
{
io.AddHandle(std::make_unique<io::EventHandle>(ExitHandle), io::MultiHandleWait::CancelOnCompleted);
}

io.Run(std::chrono::milliseconds(Timeout));

if (!accepted)
std::optional<std::chrono::milliseconds> timeout;
if (Timeout != INFINITE)
{
return false; // Accept was cancelled by the exit event.
timeout = std::chrono::milliseconds(Timeout);
}

SetAcceptContext(Socket, ListenSocket, Location);
try
{

io.Run(timeout);
}
catch (...)
{
auto hr = wil::ResultFromCaughtException();
THROW_HR_MSG(hr, "Failed to accept socket. From: %hs", std::format("{}", Location).c_str());
}

return true;
return accepted;
}

std::pair<DWORD, DWORD> wsl::windows::common::socket::GetResult(
Expand Down
3 changes: 1 addition & 2 deletions src/windows/common/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ namespace wsl::windows::common::socket {
// Sets SO_UPDATE_ACCEPT_CONTEXT on a socket accepted via AcceptEx to mark it as connected.
void SetAcceptContext(_In_ SOCKET AcceptedSocket, _In_ SOCKET ListenSocket, _In_ const std::source_location& Location = std::source_location::current());

bool CancellableAccept(
std::optional<wil::unique_socket> CancellableAccept(
_In_ SOCKET ListenSocket,
_In_ SOCKET Socket,
_In_ DWORD Timeout,
_In_opt_ HANDLE ExitHandle,
_In_ const std::source_location& Location = std::source_location::current());
Expand Down
4 changes: 2 additions & 2 deletions src/windows/service/exe/HcsVirtualMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ try
{
RETURN_HR_IF_NULL(E_POINTER, Socket);

auto socket = wsl::windows::common::hvsocket::CancellableAccept(m_listenSocket.get(), m_bootTimeoutMs, m_vmExitEvent.get());
auto socket = socket::CancellableAccept(m_listenSocket.get(), m_bootTimeoutMs, m_vmExitEvent.get());
THROW_HR_IF(E_ABORT, !socket.has_value());

*Socket = reinterpret_cast<HANDLE>(socket->release());
Expand Down Expand Up @@ -912,4 +912,4 @@ try
}
CATCH_RETURN()

} // namespace wsl::windows::service::wslc
} // namespace wsl::windows::service::wslc
Loading
Loading