From d58e0f027c85f3493fe2d60fc4fe2949521248dd Mon Sep 17 00:00:00 2001 From: Blue Date: Thu, 11 Jun 2026 22:13:33 -0700 Subject: [PATCH 1/6] Save state --- src/windows/common/HandleIO.cpp | 81 +++++++++---- src/windows/common/HandleIO.h | 26 ++-- src/windows/common/hvsocket.cpp | 10 +- src/windows/common/socket.cpp | 22 ++-- src/windows/common/socket.hpp | 3 +- src/windows/service/exe/WslCoreVm.cpp | 166 +++++++++++++++----------- src/windows/service/exe/WslCoreVm.h | 2 + src/windows/wslrelay/localhost.cpp | 12 +- src/windows/wslrelay/main.cpp | 8 +- 9 files changed, 197 insertions(+), 133 deletions(-) diff --git a/src/windows/common/HandleIO.cpp b/src/windows/common/HandleIO.cpp index 532c0dc5c3..70ab971f0d 100644 --- a/src/windows/common/HandleIO.cpp +++ b/src/windows/common/HandleIO.cpp @@ -4,6 +4,7 @@ #include "HandleIO.h" #pragma hdrstop +using wsl::windows::common::io::AcceptHandle; using wsl::windows::common::io::BufferWrapper; using wsl::windows::common::io::DockerIORelayHandle; using wsl::windows::common::io::EventHandle; @@ -16,7 +17,6 @@ using wsl::windows::common::io::OverlappedIOHandle; using wsl::windows::common::io::ReadHandle; using wsl::windows::common::io::ReadNamedPipe; using wsl::windows::common::io::ReadSocketMessageHandle; -using wsl::windows::common::io::SingleAcceptHandle; using wsl::windows::common::io::WriteHandle; using wsl::windows::common::io::WriteNamedPipe; @@ -360,41 +360,81 @@ void ReadNamedPipe::Collect() ReadHandle::Collect(); } -// SingleAcceptHandle +// AcceptHandle -SingleAcceptHandle::SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function&& OnAccepted) : - ListenSocket(std::move(ListenSocket)), AcceptedSocket(std::move(AcceptedSocket)), OnAccepted(std::move(OnAccepted)) +AcceptHandle::AcceptHandle(HandleWrapper&& ListenSocket, bool AcceptOnce, std::function&& OnAccepted) : + ListenSocket(std::move(ListenSocket)), AcceptOnce(AcceptOnce), OnAccepted(std::move(OnAccepted)) { Overlapped.hEvent = Event.get(); + + // Query the listen socket so accepted sockets can be created with a matching address family, type, and protocol. + WSAPROTOCOL_INFOW protocolInfo{}; + int length = sizeof(protocolInfo); + THROW_LAST_ERROR_IF( + getsockopt(reinterpret_cast(this->ListenSocket.Get()), SOL_SOCKET, SO_PROTOCOL_INFOW, reinterpret_cast(&protocolInfo), &length) == SOCKET_ERROR); + + AddressFamily = protocolInfo.iAddressFamily; + SocketType = protocolInfo.iSocketType; + Protocol = protocolInfo.iProtocol; + + CreateAcceptSocket(); } -SingleAcceptHandle::~SingleAcceptHandle() +AcceptHandle::~AcceptHandle() { if (State == IOHandleStatus::Pending) { - LOG_IF_WIN32_BOOL_FALSE(CancelIoEx(ListenSocket.Get(), &Overlapped)); + CancelPendingIo(reinterpret_cast(ListenSocket.Get()), Overlapped); + } +} - DWORD bytesProcessed{}; - DWORD flagsReturned{}; - if (!WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesProcessed, TRUE, &flagsReturned)) - { - auto error = GetLastError(); - LOG_LAST_ERROR_IF(error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED); - } +void AcceptHandle::CreateAcceptSocket() +{ + AcceptedSocket.reset(WSASocketW(AddressFamily, SocketType, Protocol, nullptr, 0, WSA_FLAG_OVERLAPPED)); + THROW_LAST_ERROR_IF(!AcceptedSocket); + + // Configure the socket like hvsocket::Create() does so Hyper-V connections survive VM suspend. + if (AddressFamily == AF_HYPERV) + { + ULONG enable = 1; + THROW_LAST_ERROR_IF( + setsockopt(AcceptedSocket.get(), HV_PROTOCOL_RAW, HVSOCKET_CONNECTED_SUSPEND, reinterpret_cast(&enable), sizeof(enable)) == SOCKET_ERROR); + } +} + +void AcceptHandle::OnComplete() +{ + // Mark the accepted socket as connected (this also updates its context from the listen socket). + const auto listenSocket = reinterpret_cast(ListenSocket.Get()); + THROW_LAST_ERROR_IF( + setsockopt(AcceptedSocket.get(), SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, reinterpret_cast(&listenSocket), sizeof(listenSocket)) == SOCKET_ERROR); + + OnAccepted(std::move(AcceptedSocket)); + + if (AcceptOnce) + { + State = IOHandleStatus::Completed; + } + else + { + // Prepare a fresh socket for the next accept and return to standby so the loop reschedules. + CreateAcceptSocket(); + State = IOHandleStatus::Standby; } } -void SingleAcceptHandle::Schedule() +void AcceptHandle::Schedule() { WI_ASSERT(State == IOHandleStatus::Standby); + Event.ResetEvent(); + // Schedule the accept. DWORD bytesReturned{}; - if (AcceptEx((SOCKET)ListenSocket.Get(), (SOCKET)AcceptedSocket.Get(), &AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped)) + if (AcceptEx((SOCKET)ListenSocket.Get(), AcceptedSocket.get(), &AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped)) { // Accept completed immediately. - State = IOHandleStatus::Completed; - OnAccepted(); + OnComplete(); } else { @@ -405,7 +445,7 @@ void SingleAcceptHandle::Schedule() } } -void SingleAcceptHandle::Collect() +void AcceptHandle::Collect() { WI_ASSERT(State == IOHandleStatus::Pending); @@ -414,11 +454,10 @@ void SingleAcceptHandle::Collect() THROW_IF_WIN32_BOOL_FALSE(WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesReceived, false, &flagsReturned)); - State = IOHandleStatus::Completed; - OnAccepted(); + OnComplete(); } -HANDLE SingleAcceptHandle::GetHandle() const +HANDLE AcceptHandle::GetHandle() const { return Event.get(); } diff --git a/src/windows/common/HandleIO.h b/src/windows/common/HandleIO.h index 14d168c4af..0b2d504e0b 100644 --- a/src/windows/common/HandleIO.h +++ b/src/windows/common/HandleIO.h @@ -139,25 +139,37 @@ class ReadNamedPipe : public ReadHandle bool m_connected = false; }; -class SingleAcceptHandle : public OverlappedIOHandle +// Accepts connections on a listen socket using overlapped IO. The accepted sockets are created internally to +// match the address family, type, and protocol of the listen socket, so callers no longer need to pre-create them. +// Hyper-V sockets are additionally configured to match hvsocket::Create() (HVSOCKET_CONNECTED_SUSPEND). When +// AcceptOnce is true the handle completes after a single accept; otherwise it keeps accepting until the loop is +// cancelled. OnAccepted is invoked with ownership of each accepted socket. +class AcceptHandle : public OverlappedIOHandle { public: - NON_COPYABLE(SingleAcceptHandle) - NON_MOVABLE(SingleAcceptHandle) + NON_COPYABLE(AcceptHandle) + NON_MOVABLE(AcceptHandle) - SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function&& OnAccepted); - ~SingleAcceptHandle(); + AcceptHandle(HandleWrapper&& ListenSocket, bool AcceptOnce, std::function&& OnAccepted); + ~AcceptHandle(); void Schedule() override; void Collect() override; HANDLE GetHandle() const override; private: + void CreateAcceptSocket(); + void OnComplete(); + HandleWrapper ListenSocket; - HandleWrapper AcceptedSocket; + wil::unique_socket AcceptedSocket; + int AddressFamily{}; + int SocketType{}; + int Protocol{}; + bool AcceptOnce{}; wil::unique_event Event{wil::EventOptions::ManualReset}; OVERLAPPED Overlapped{}; - std::function OnAccepted; + std::function OnAccepted; char AcceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]; }; diff --git a/src/windows/common/hvsocket.cpp b/src/windows/common/hvsocket.cpp index 204de0c849..c2df1c8d37 100644 --- a/src/windows/common/hvsocket.cpp +++ b/src/windows/common/hvsocket.cpp @@ -40,13 +40,9 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address) std::optional wsl::windows::common::hvsocket::CancellableAccept( _In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) { - wil::unique_socket Socket = Create(); - if (!socket::CancellableAccept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location)) - { - return {}; - } - - return Socket; + // AcceptHandle creates the accepted socket from the listen socket's protocol info and, because the listen socket + // is a Hyper-V socket, configures it with HVSOCKET_CONNECTED_SUSPEND to match Create(). + return socket::CancellableAccept(ListenSocket, Timeout, ExitHandle, Location); } wil::unique_socket wsl::windows::common::hvsocket::Connect( diff --git a/src/windows/common/socket.cpp b/src/windows/common/socket.cpp index d476018992..ab6548ba1e 100644 --- a/src/windows/common/socket.cpp +++ b/src/windows/common/socket.cpp @@ -26,30 +26,32 @@ void wsl::windows::common::socket::SetAcceptContext(_In_ SOCKET AcceptedSocket, std::format("{}", Location).c_str()); } -bool wsl::windows::common::socket::CancellableAccept( - _In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) +std::optional wsl::windows::common::socket::CancellableAccept( + _In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) { io::MultiHandleWait io; - bool accepted = false; + std::optional accepted; - io.AddHandle(std::make_unique(ListenSocket, Socket, [&]() { accepted = true; }), io::MultiHandleWait::CancelOnCompleted); + io.AddHandle( + std::make_unique( + ListenSocket, true, [&accepted](wil::unique_socket&& socket) { accepted = std::move(socket); }), + io::MultiHandleWait::CancelOnCompleted); if (ExitHandle != nullptr) { io.AddHandle(std::make_unique(ExitHandle), io::MultiHandleWait::CancelOnCompleted); } - io.Run(std::chrono::milliseconds(Timeout)); - - if (!accepted) + std::optional timeout; + if (Timeout != INFINITE) { - return false; // Accept was cancelled by the exit event. + timeout = std::chrono::milliseconds(Timeout); } - SetAcceptContext(Socket, ListenSocket, Location); + io.Run(timeout); - return true; + return accepted; } std::pair wsl::windows::common::socket::GetResult( diff --git a/src/windows/common/socket.hpp b/src/windows/common/socket.hpp index d14696084d..7af77b8421 100644 --- a/src/windows/common/socket.hpp +++ b/src/windows/common/socket.hpp @@ -21,9 +21,8 @@ namespace wsl::windows::common::socket { // Sets SO_UPDATE_ACCEPT_CONTEXT on a socket accepted via AcceptEx to mark it as connected. void SetAcceptContext(_In_ SOCKET AcceptedSocket, _In_ SOCKET ListenSocket, _In_ const std::source_location& Location = std::source_location::current()); -bool CancellableAccept( +std::optional CancellableAccept( _In_ SOCKET ListenSocket, - _In_ SOCKET Socket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location = std::source_location::current()); diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index 050a519e2f..57fc250716 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -2617,92 +2617,112 @@ try { wsl::windows::common::wslutil::SetThreadDescription(L"VirtioFs - Worker"); - for (;;) - { - // Create a worker thread to handle each request. + io::MultiHandleWait io; + + io.AddHandle(std::make_unique(listenSocket.get(), false, [this, &io](wil::unique_socket&& socket) { + auto channel = std::make_shared(std::move(socket), "VirtioFs"); + auto buffer = std::make_shared>(); + auto pendingBytes = std::make_shared>(); + + io.AddHandle( + std::make_unique( + io::HandleWrapper(channel->Socket()), + *buffer, + *pendingBytes, + [this, &io, channel, buffer, pendingBytes](const gsl::span& message) { + if (message.empty()) + { + return; // Channel closed, exit. + } - auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); - if (!socket.has_value()) - { - break; - } + THROW_HR_IF_MSG( + E_UNEXPECTED, !pendingBytes->empty(), "Received message with additional bytes: %lu", pendingBytes->size()); - wsl::shared::SocketChannel channel{std::move(socket.value()), "VirtioFs", {m_terminatingEvent.get()}}; - std::thread([this, channel = std::move(channel)]() mutable { - try - { - wsl::windows::common::wslutil::SetThreadDescription(L"VirtioFs - Request"); + try + { + auto response = ProcessVirtioFsRequest(message); - auto transaction = channel.ReceiveTransaction(); - auto [message, span] = transaction.ReceiveOrClosed(); - if (message == nullptr) - { - return; - } + // Move the socket out of the channel into the WriteHandle so it is closed once the reply is sent. + io.AddHandle(std::make_unique(channel->Release(), response), io::MultiHandleWait::IgnoreErrors); + } + CATCH_LOG(); + }), + io::MultiHandleWait::IgnoreErrors); + })); - auto respondWithTag = [&](const std::wstring& tag, const std::wstring& source, HRESULT result) { - // Respond to the guest with the tag that should be used to mount the device. + io.AddHandle(std::make_unique(m_terminatingEvent.get()), io::MultiHandleWait::CancelOnCompleted); - wsl::shared::MessageWriter response(LxInitMessageAddVirtioFsDeviceResponse); - response->Result = SUCCEEDED(result) ? 0 : EINVAL; // TODO: Improved HRESULT -> errno mapping. - response.WriteString(response->TagOffset, tag); - response.WriteString(response->SourceOffset, source); + io.Run({}); +} +CATCH_LOG() - transaction.Send(response.Span()); - }; +std::vector WslCoreVm::ProcessVirtioFsRequest(_In_ gsl::span Request) +{ + const auto* header = gslhelpers::try_get_struct(Request); + THROW_HR_IF(E_UNEXPECTED, !header); + + auto buildResponse = [header](const std::wstring& tag, const std::wstring& source, HRESULT result) { + // Respond to the guest with the tag that should be used to mount the device. + wsl::shared::MessageWriter response(LxInitMessageAddVirtioFsDeviceResponse); + response->Result = SUCCEEDED(result) ? 0 : EINVAL; // TODO: Improved HRESULT -> errno mapping. + response.WriteString(response->TagOffset, tag); + response.WriteString(response->SourceOffset, source); + + // Echo the request's transaction id and mark the message as the first (and only) reply. + response->Header.TransactionId = header->TransactionId; + response->Header.TransactionStep = static_cast(TRANSACTION_STEP::FIRST_REPLY); + + const auto span = response.Span(); + return std::vector(reinterpret_cast(span.data()), reinterpret_cast(span.data()) + span.size()); + }; - if (message->MessageType == LxInitMessageAddVirtioFsDevice) - { - std::wstring tag; - std::wstring source; - const auto result = wil::ResultFromException([this, span, &tag, &source]() { - const auto* addShare = gslhelpers::try_get_struct(span); - THROW_HR_IF(E_UNEXPECTED, !addShare); - - const auto path = wsl::shared::string::FromSpan(span, addShare->PathOffset); - const auto pathWide = wsl::shared::string::MultiByteToWide(path); - const auto options = wsl::shared::string::FromSpan(span, addShare->OptionsOffset); - const auto optionsWide = wsl::shared::string::MultiByteToWide(options); - - // Acquire the lock and attempt to add the device. - auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); - std::tie(tag, source) = AddVirtioFsShare(addShare->Admin, pathWide.c_str(), optionsWide.c_str()); - }); - - respondWithTag(tag, source, result); - } - else if (message->MessageType == LxInitMessageRemountVirtioFsDevice) - { - std::wstring newTag; - std::wstring source; - const auto result = wil::ResultFromException([this, span, &newTag, &source]() { - const auto* remountShare = gslhelpers::try_get_struct(span); - THROW_HR_IF(E_UNEXPECTED, !remountShare); + if (header->MessageType == LxInitMessageAddVirtioFsDevice) + { + std::wstring tag; + std::wstring source; + const auto result = wil::ResultFromException([&]() { + const auto* addShare = gslhelpers::try_get_struct(Request); + THROW_HR_IF(E_UNEXPECTED, !addShare); - const std::string tag = wsl::shared::string::FromSpan(span, remountShare->TagOffset); - const auto tagWide = wsl::shared::string::MultiByteToWide(tag); - auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); - const auto foundShare = FindVirtioFsShare(tagWide.c_str(), !remountShare->Admin); - THROW_HR_IF_MSG(E_UNEXPECTED, !foundShare.has_value(), "Unknown tag %ls", tagWide.c_str()); + const auto path = wsl::shared::string::FromSpan(Request, addShare->PathOffset); + const auto pathWide = wsl::shared::string::MultiByteToWide(path); + const auto options = wsl::shared::string::FromSpan(Request, addShare->OptionsOffset); + const auto optionsWide = wsl::shared::string::MultiByteToWide(options); - std::tie(newTag, source) = - AddVirtioFsShare(remountShare->Admin, foundShare->Path.c_str(), foundShare->OptionsString().c_str()); + // Acquire the lock and attempt to add the device. + auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); + std::tie(tag, source) = AddVirtioFsShare(addShare->Admin, pathWide.c_str(), optionsWide.c_str()); + }); - WI_ASSERT(source == foundShare->Path); - }); + return buildResponse(tag, source, result); + } + else if (header->MessageType == LxInitMessageRemountVirtioFsDevice) + { + std::wstring newTag; + std::wstring source; + const auto result = wil::ResultFromException([&]() { + const auto* remountShare = gslhelpers::try_get_struct(Request); + THROW_HR_IF(E_UNEXPECTED, !remountShare); - respondWithTag(newTag, source, result); - } - else - { - THROW_HR_MSG(E_UNEXPECTED, "Unexpected MessageType %d", message->MessageType); - } - } - CATCH_LOG() - }).detach(); + const std::string tag = wsl::shared::string::FromSpan(Request, remountShare->TagOffset); + const auto tagWide = wsl::shared::string::MultiByteToWide(tag); + auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); + const auto foundShare = FindVirtioFsShare(tagWide.c_str(), !remountShare->Admin); + THROW_HR_IF_MSG(E_UNEXPECTED, !foundShare.has_value(), "Unknown tag %ls", tagWide.c_str()); + + std::tie(newTag, source) = + AddVirtioFsShare(remountShare->Admin, foundShare->Path.c_str(), foundShare->OptionsString().c_str()); + + WI_ASSERT(source == foundShare->Path); + }); + + return buildResponse(newTag, source, result); + } + else + { + THROW_HR_MSG(E_UNEXPECTED, "Unexpected MessageType %d", header->MessageType); } } -CATCH_LOG() std::string WslCoreVm::s_GetMountTargetName(_In_ PCWSTR Disk, _In_opt_ PCWSTR Name, _In_ int PartitionIndex) { diff --git a/src/windows/service/exe/WslCoreVm.h b/src/windows/service/exe/WslCoreVm.h index b3bb810b6f..30ce6b4637 100644 --- a/src/windows/service/exe/WslCoreVm.h +++ b/src/windows/service/exe/WslCoreVm.h @@ -251,6 +251,8 @@ class WslCoreVm void VirtioFsWorker(_In_ const wil::unique_socket& socket); + std::vector ProcessVirtioFsRequest(_In_ gsl::span Request); + static std::string s_GetMountTargetName(_In_ PCWSTR Disk, _In_opt_ PCWSTR Name, _In_ int PartitionIndex); static LX_INIT_DRVFS_MOUNT s_InitializeDrvFs(_Inout_ WslCoreVm* VmContext, _In_ HANDLE UserToken); diff --git a/src/windows/wslrelay/localhost.cpp b/src/windows/wslrelay/localhost.cpp index 1b2af83294..bcd549d071 100644 --- a/src/windows/wslrelay/localhost.cpp +++ b/src/windows/wslrelay/localhost.cpp @@ -291,15 +291,11 @@ try { // Begin accepting connections until the relay is stopped. - const int AddressFamily = WindowsAddressFamily(Arguments->Family); - for (;;) { - wil::unique_socket InetSocket(WSASocket(AddressFamily, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED)); - THROW_LAST_ERROR_IF(!InetSocket); - - if (!wsl::windows::common::socket::CancellableAccept( - Arguments->ListenSocket.get(), InetSocket.get(), INFINITE, Arguments->ExitEvent.get())) + auto InetSocket = wsl::windows::common::socket::CancellableAccept( + Arguments->ListenSocket.get(), INFINITE, Arguments->ExitEvent.get()); + if (!InetSocket) { break; // Exit event was signaled, exit. } @@ -308,7 +304,7 @@ try WSL_LOG("PortRelayUsage", TraceLoggingValue(Arguments->Family, "family"), TraceLoggingValue(Arguments->Port, "port"), TraceLoggingLevel(WINEVENT_LEVEL_INFO)); - auto RelayThread = std::thread([Arguments, InetSocket = std::move(InetSocket)]() { + auto RelayThread = std::thread([Arguments, InetSocket = std::move(*InetSocket)]() { try { wsl::windows::common::wslutil::SetThreadDescription(L"Port relay"); diff --git a/src/windows/wslrelay/main.cpp b/src/windows/wslrelay/main.cpp index ea5080be24..c5657fd4b0 100644 --- a/src/windows/wslrelay/main.cpp +++ b/src/windows/wslrelay/main.cpp @@ -123,17 +123,15 @@ try THROW_LAST_ERROR_IF(listen(listenSocket.get(), 1) == SOCKET_ERROR); - const wil::unique_socket socket(WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED)); - THROW_LAST_ERROR_IF(!socket); - - if (!wsl::windows::common::socket::CancellableAccept(listenSocket.get(), socket.get(), INFINITE, exitEvent.get())) + auto socket = wsl::windows::common::socket::CancellableAccept(listenSocket.get(), INFINITE, exitEvent.get()); + if (!socket) { return 1; } // Begin the relay. wsl::windows::common::relay::BidirectionalRelay( - reinterpret_cast(socket.get()), pipe.get(), 0x1000, wsl::windows::common::relay::RelayFlags::LeftIsSocket); + reinterpret_cast(socket->get()), pipe.get(), 0x1000, wsl::windows::common::relay::RelayFlags::LeftIsSocket); break; } From 76a48b236ceeb0eaf424aaf5044004ad21ffb5fc Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 12 Jun 2026 12:21:32 -0700 Subject: [PATCH 2/6] Add test coverage --- src/windows/common/HandleIO.cpp | 1 - src/windows/common/HandleIO.h | 5 --- src/windows/common/hvsocket.cpp | 2 - src/windows/service/exe/WslCoreVm.cpp | 6 ++- test/windows/DrvFsTests.cpp | 60 ++++++++++++++++++++++++++- 5 files changed, 64 insertions(+), 10 deletions(-) diff --git a/src/windows/common/HandleIO.cpp b/src/windows/common/HandleIO.cpp index 70ab971f0d..294d2a0f32 100644 --- a/src/windows/common/HandleIO.cpp +++ b/src/windows/common/HandleIO.cpp @@ -393,7 +393,6 @@ void AcceptHandle::CreateAcceptSocket() AcceptedSocket.reset(WSASocketW(AddressFamily, SocketType, Protocol, nullptr, 0, WSA_FLAG_OVERLAPPED)); THROW_LAST_ERROR_IF(!AcceptedSocket); - // Configure the socket like hvsocket::Create() does so Hyper-V connections survive VM suspend. if (AddressFamily == AF_HYPERV) { ULONG enable = 1; diff --git a/src/windows/common/HandleIO.h b/src/windows/common/HandleIO.h index 0b2d504e0b..1803a07c19 100644 --- a/src/windows/common/HandleIO.h +++ b/src/windows/common/HandleIO.h @@ -139,11 +139,6 @@ class ReadNamedPipe : public ReadHandle bool m_connected = false; }; -// Accepts connections on a listen socket using overlapped IO. The accepted sockets are created internally to -// match the address family, type, and protocol of the listen socket, so callers no longer need to pre-create them. -// Hyper-V sockets are additionally configured to match hvsocket::Create() (HVSOCKET_CONNECTED_SUSPEND). When -// AcceptOnce is true the handle completes after a single accept; otherwise it keeps accepting until the loop is -// cancelled. OnAccepted is invoked with ownership of each accepted socket. class AcceptHandle : public OverlappedIOHandle { public: diff --git a/src/windows/common/hvsocket.cpp b/src/windows/common/hvsocket.cpp index c2df1c8d37..8648a40667 100644 --- a/src/windows/common/hvsocket.cpp +++ b/src/windows/common/hvsocket.cpp @@ -40,8 +40,6 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address) std::optional wsl::windows::common::hvsocket::CancellableAccept( _In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) { - // AcceptHandle creates the accepted socket from the listen socket's protocol info and, because the listen socket - // is a Hyper-V socket, configures it with HVSOCKET_CONNECTED_SUSPEND to match Create(). return socket::CancellableAccept(ListenSocket, Timeout, ExitHandle, Location); } diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index 57fc250716..b3d856c427 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -2636,7 +2636,7 @@ try } THROW_HR_IF_MSG( - E_UNEXPECTED, !pendingBytes->empty(), "Received message with additional bytes: %lu", pendingBytes->size()); + E_UNEXPECTED, !pendingBytes->empty(), "Received message with additional bytes: %zu", pendingBytes->size()); try { @@ -2661,6 +2661,8 @@ std::vector WslCoreVm::ProcessVirtioFsRequest(_In_ gsl::span Re const auto* header = gslhelpers::try_get_struct(Request); THROW_HR_IF(E_UNEXPECTED, !header); + WSL_LOG("VirtiofsMessageRequest", TraceLoggingValue(header->PrettyPrint().c_str(), "Content")); + auto buildResponse = [header](const std::wstring& tag, const std::wstring& source, HRESULT result) { // Respond to the guest with the tag that should be used to mount the device. wsl::shared::MessageWriter response(LxInitMessageAddVirtioFsDeviceResponse); @@ -2672,6 +2674,8 @@ std::vector WslCoreVm::ProcessVirtioFsRequest(_In_ gsl::span Re response->Header.TransactionId = header->TransactionId; response->Header.TransactionStep = static_cast(TRANSACTION_STEP::FIRST_REPLY); + WSL_LOG("VirtiofsMessageResponse", TraceLoggingValue(response->PrettyPrint().c_str(), "Content")); + const auto span = response.Span(); return std::vector(reinterpret_cast(span.data()), reinterpret_cast(span.data()) + span.size()); }; diff --git a/test/windows/DrvFsTests.cpp b/test/windows/DrvFsTests.cpp index 9025ba6a36..7838b9719d 100644 --- a/test/windows/DrvFsTests.cpp +++ b/test/windows/DrvFsTests.cpp @@ -409,6 +409,59 @@ class DrvFsTests VERIFY_IS_TRUE(out.find(L"test-file.txt") != std::wstring::npos); } + void DrvfsMountManyVirtioFsShares(DrvFsMode Mode) + { + if (Mode != DrvFsMode::VirtioFs) + { + LogSkipped("This test is only applicable to VirtioFs"); + return; + } + + WINDOWS_11_TEST_ONLY(); + SKIP_TEST_ARM64(); + + constexpr auto c_iterations = 20; + auto testDir = std::filesystem::current_path() / "virtiofs-loop-test"; + + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + LxsstuLaunchWsl(L"umount /tmp/virtiofs-loop-test-*"); + + std::error_code ec; + std::filesystem::remove_all(testDir, ec); + }); + + for (int i = 0; i < c_iterations; ++i) + { + const auto sourceDir = testDir / std::to_string(i); + std::filesystem::create_directories(sourceDir); + + const auto expected = std::format("virtiofs share {}", i); + { + std::ofstream markerFile(std::filesystem::path(sourceDir) / L"marker"); + markerFile << expected; + } + + const auto mountPoint = std::format(L"/tmp/virtiofs-loop-test-{}", i); + + // Mount the share. + VERIFY_ARE_EQUAL(LxsstuLaunchWsl(std::format(L"mkdir -p '{}'", mountPoint)), 0); + VERIFY_ARE_EQUAL(LxsstuLaunchWsl(std::format(L"mount -t drvfs '{}' '{}'", sourceDir.string(), mountPoint)), 0); + + // Validate that it can be accessed. + { + auto [out, err] = LxsstuLaunchWslAndCaptureOutput(std::format(L"cat '{}/marker'", mountPoint)); + VERIFY_ARE_EQUAL(out, wsl::shared::string::MultiByteToWide(expected)); + } + + // Validate the mount options. + { + auto [out, err] = LxsstuLaunchWslAndCaptureOutput(std::format(L"findmnt -ln '{}'", mountPoint)); + + VerifyPatternMatch(wsl::shared::string::WideToMultiByte(out.c_str()), std::format("{} * virtiofs rw,relatime\n", mountPoint)); + } + } + } + // DrvFsTests Private Methods private: static VOID CreateDrvFsTestFiles(bool Metadata) @@ -1303,6 +1356,11 @@ class WSL1 : public DrvFsTests { \ DrvFsTests::DrvFsMountUnicodePath(DrvFsMode::##_mode##); \ } \ +\ + WSL2_TEST_METHOD(DrvfsMountManyVirtioFsShares) \ + { \ + DrvFsTests::DrvfsMountManyVirtioFsShares(DrvFsMode::##_mode##); \ + } \ } WSL2_DRVFS_TEST_CLASS(Plan9); @@ -1313,4 +1371,4 @@ WSL2_DRVFS_TEST_CLASS(VirtioFs); // TODO: Enable again once the issue is resolved // WSL2_DRVFS_TEST_CLASS(Virtio9p); -} // namespace DrvFsTests \ No newline at end of file +} // namespace DrvFsTests From cab0811409bbc693dc6abc2469ba1a797d233b8d Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 12 Jun 2026 12:21:49 -0700 Subject: [PATCH 3/6] Format --- src/windows/common/HandleIO.cpp | 9 ++++++--- src/windows/wslrelay/localhost.cpp | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/windows/common/HandleIO.cpp b/src/windows/common/HandleIO.cpp index 294d2a0f32..c6423be2a2 100644 --- a/src/windows/common/HandleIO.cpp +++ b/src/windows/common/HandleIO.cpp @@ -371,7 +371,8 @@ AcceptHandle::AcceptHandle(HandleWrapper&& ListenSocket, bool AcceptOnce, std::f WSAPROTOCOL_INFOW protocolInfo{}; int length = sizeof(protocolInfo); THROW_LAST_ERROR_IF( - getsockopt(reinterpret_cast(this->ListenSocket.Get()), SOL_SOCKET, SO_PROTOCOL_INFOW, reinterpret_cast(&protocolInfo), &length) == SOCKET_ERROR); + getsockopt(reinterpret_cast(this->ListenSocket.Get()), SOL_SOCKET, SO_PROTOCOL_INFOW, reinterpret_cast(&protocolInfo), &length) == + SOCKET_ERROR); AddressFamily = protocolInfo.iAddressFamily; SocketType = protocolInfo.iSocketType; @@ -397,7 +398,8 @@ void AcceptHandle::CreateAcceptSocket() { ULONG enable = 1; THROW_LAST_ERROR_IF( - setsockopt(AcceptedSocket.get(), HV_PROTOCOL_RAW, HVSOCKET_CONNECTED_SUSPEND, reinterpret_cast(&enable), sizeof(enable)) == SOCKET_ERROR); + setsockopt(AcceptedSocket.get(), HV_PROTOCOL_RAW, HVSOCKET_CONNECTED_SUSPEND, reinterpret_cast(&enable), sizeof(enable)) == + SOCKET_ERROR); } } @@ -406,7 +408,8 @@ void AcceptHandle::OnComplete() // Mark the accepted socket as connected (this also updates its context from the listen socket). const auto listenSocket = reinterpret_cast(ListenSocket.Get()); THROW_LAST_ERROR_IF( - setsockopt(AcceptedSocket.get(), SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, reinterpret_cast(&listenSocket), sizeof(listenSocket)) == SOCKET_ERROR); + setsockopt(AcceptedSocket.get(), SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, reinterpret_cast(&listenSocket), sizeof(listenSocket)) == + SOCKET_ERROR); OnAccepted(std::move(AcceptedSocket)); diff --git a/src/windows/wslrelay/localhost.cpp b/src/windows/wslrelay/localhost.cpp index bcd549d071..378ee47676 100644 --- a/src/windows/wslrelay/localhost.cpp +++ b/src/windows/wslrelay/localhost.cpp @@ -293,8 +293,8 @@ try for (;;) { - auto InetSocket = wsl::windows::common::socket::CancellableAccept( - Arguments->ListenSocket.get(), INFINITE, Arguments->ExitEvent.get()); + auto InetSocket = + wsl::windows::common::socket::CancellableAccept(Arguments->ListenSocket.get(), INFINITE, Arguments->ExitEvent.get()); if (!InetSocket) { break; // Exit event was signaled, exit. From 9f6579045cb96fb764b13e3806e59245187a1c6f Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 12 Jun 2026 12:25:27 -0700 Subject: [PATCH 4/6] Simplify socket creation --- src/windows/common/HandleIO.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/windows/common/HandleIO.cpp b/src/windows/common/HandleIO.cpp index c6423be2a2..02c57618e8 100644 --- a/src/windows/common/HandleIO.cpp +++ b/src/windows/common/HandleIO.cpp @@ -377,8 +377,6 @@ AcceptHandle::AcceptHandle(HandleWrapper&& ListenSocket, bool AcceptOnce, std::f AddressFamily = protocolInfo.iAddressFamily; SocketType = protocolInfo.iSocketType; Protocol = protocolInfo.iProtocol; - - CreateAcceptSocket(); } AcceptHandle::~AcceptHandle() @@ -419,8 +417,6 @@ void AcceptHandle::OnComplete() } else { - // Prepare a fresh socket for the next accept and return to standby so the loop reschedules. - CreateAcceptSocket(); State = IOHandleStatus::Standby; } } @@ -429,6 +425,8 @@ void AcceptHandle::Schedule() { WI_ASSERT(State == IOHandleStatus::Standby); + CreateAcceptSocket(); + Event.ResetEvent(); // Schedule the accept. From bc268e15f7ce23e6b4daac8f3f6155f1e7163c6f Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 12 Jun 2026 14:29:29 -0700 Subject: [PATCH 5/6] Apply PR feedback --- src/windows/common/HandleIO.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/windows/common/HandleIO.cpp b/src/windows/common/HandleIO.cpp index 02c57618e8..d94c5cf9a8 100644 --- a/src/windows/common/HandleIO.cpp +++ b/src/windows/common/HandleIO.cpp @@ -403,11 +403,7 @@ void AcceptHandle::CreateAcceptSocket() void AcceptHandle::OnComplete() { - // Mark the accepted socket as connected (this also updates its context from the listen socket). - const auto listenSocket = reinterpret_cast(ListenSocket.Get()); - THROW_LAST_ERROR_IF( - setsockopt(AcceptedSocket.get(), SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, reinterpret_cast(&listenSocket), sizeof(listenSocket)) == - SOCKET_ERROR); + wsl::windows::common::socket::SetAcceptContext(AcceptedSocket.get(), reinterpret_cast(ListenSocket.Get())); OnAccepted(std::move(AcceptedSocket)); From 32f51ab736ed111d0c0b94c303a2290b08bb776d Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 12 Jun 2026 16:49:09 -0700 Subject: [PATCH 6/6] Cleanup --- src/windows/common/hvsocket.cpp | 6 ------ src/windows/common/hvsocket.hpp | 6 ------ src/windows/common/socket.cpp | 11 ++++++++++- src/windows/service/exe/HcsVirtualMachine.cpp | 4 ++-- src/windows/service/exe/WslCoreVm.cpp | 4 ++-- src/windows/wslcsession/WSLCVirtualMachine.cpp | 2 +- 6 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/windows/common/hvsocket.cpp b/src/windows/common/hvsocket.cpp index 8648a40667..5375226c54 100644 --- a/src/windows/common/hvsocket.cpp +++ b/src/windows/common/hvsocket.cpp @@ -37,12 +37,6 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address) } } // namespace -std::optional wsl::windows::common::hvsocket::CancellableAccept( - _In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) -{ - return socket::CancellableAccept(ListenSocket, Timeout, ExitHandle, Location); -} - wil::unique_socket wsl::windows::common::hvsocket::Connect( _In_ const GUID& VmId, _In_ unsigned long Port, _In_opt_ HANDLE ExitHandle, _In_opt_ ULONG Timeout, _In_ const std::source_location& Location) { diff --git a/src/windows/common/hvsocket.hpp b/src/windows/common/hvsocket.hpp index 9aa67253ba..9f4ab56504 100644 --- a/src/windows/common/hvsocket.hpp +++ b/src/windows/common/hvsocket.hpp @@ -19,12 +19,6 @@ Module Name: namespace wsl::windows::common::hvsocket { -std::optional CancellableAccept( - _In_ SOCKET ListenSocket, - _In_ DWORD Timeout, - _In_opt_ HANDLE ExitHandle = nullptr, - const std::source_location& Location = std::source_location::current()); - wil::unique_socket Connect( _In_ const GUID& VmId, _In_ unsigned long Port, diff --git a/src/windows/common/socket.cpp b/src/windows/common/socket.cpp index ab6548ba1e..e8b3510495 100644 --- a/src/windows/common/socket.cpp +++ b/src/windows/common/socket.cpp @@ -49,7 +49,16 @@ std::optional wsl::windows::common::socket::CancellableAccep timeout = std::chrono::milliseconds(Timeout); } - io.Run(timeout); + try + { + + io.Run(timeout); + } + catch (...) + { + auto hr = wil::ResultFromCaughtException(); + THROW_HR_MSG(hr, "Failed to accept socket. From: %hs", std::format("{}", Location).c_str()); + } return accepted; } diff --git a/src/windows/service/exe/HcsVirtualMachine.cpp b/src/windows/service/exe/HcsVirtualMachine.cpp index 74f4e0d890..2f073db59c 100644 --- a/src/windows/service/exe/HcsVirtualMachine.cpp +++ b/src/windows/service/exe/HcsVirtualMachine.cpp @@ -387,7 +387,7 @@ try { RETURN_HR_IF_NULL(E_POINTER, Socket); - auto socket = wsl::windows::common::hvsocket::CancellableAccept(m_listenSocket.get(), m_bootTimeoutMs, m_vmExitEvent.get()); + auto socket = socket::CancellableAccept(m_listenSocket.get(), m_bootTimeoutMs, m_vmExitEvent.get()); THROW_HR_IF(E_ABORT, !socket.has_value()); *Socket = reinterpret_cast(socket->release()); @@ -912,4 +912,4 @@ try } CATCH_RETURN() -} // namespace wsl::windows::service::wslc \ No newline at end of file +} // namespace wsl::windows::service::wslc diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index b3d856c427..e6531e1862 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -856,7 +856,7 @@ WslCoreVm::~WslCoreVm() noexcept wil::unique_socket WslCoreVm::AcceptConnection(_In_ DWORD ReceiveTimeout, _In_ const std::source_location& Location) const { - auto socket = hvsocket::CancellableAccept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location); + auto socket = socket::CancellableAccept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location); THROW_HR_IF(E_ABORT, !socket.has_value()); if (ReceiveTimeout != 0) @@ -1087,7 +1087,7 @@ void WslCoreVm::CollectCrashDumps(wil::unique_socket&& listenSocket) const { try { - auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); + auto socket = socket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); if (!socket.has_value()) { break; // VM is exiting. diff --git a/src/windows/wslcsession/WSLCVirtualMachine.cpp b/src/windows/wslcsession/WSLCVirtualMachine.cpp index 5c5766f5ef..593198d16c 100644 --- a/src/windows/wslcsession/WSLCVirtualMachine.cpp +++ b/src/windows/wslcsession/WSLCVirtualMachine.cpp @@ -1259,7 +1259,7 @@ void WSLCVirtualMachine::CollectCrashDumps(wil::unique_socket&& listenSocket) { try { - auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_vmTerminatingEvent.get()); + auto socket = socket::CancellableAccept(listenSocket.get(), INFINITE, m_vmTerminatingEvent.get()); if (!socket) { // VM is exiting.