diff --git a/src/windows/common/HandleIO.cpp b/src/windows/common/HandleIO.cpp index 532c0dc5c..d94c5cf9a 100644 --- a/src/windows/common/HandleIO.cpp +++ b/src/windows/common/HandleIO.cpp @@ -4,6 +4,7 @@ #include "HandleIO.h" #pragma hdrstop +using wsl::windows::common::io::AcceptHandle; using wsl::windows::common::io::BufferWrapper; using wsl::windows::common::io::DockerIORelayHandle; using wsl::windows::common::io::EventHandle; @@ -16,7 +17,6 @@ using wsl::windows::common::io::OverlappedIOHandle; using wsl::windows::common::io::ReadHandle; using wsl::windows::common::io::ReadNamedPipe; using wsl::windows::common::io::ReadSocketMessageHandle; -using wsl::windows::common::io::SingleAcceptHandle; using wsl::windows::common::io::WriteHandle; using wsl::windows::common::io::WriteNamedPipe; @@ -360,41 +360,77 @@ void ReadNamedPipe::Collect() ReadHandle::Collect(); } -// SingleAcceptHandle +// AcceptHandle -SingleAcceptHandle::SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function&& OnAccepted) : - ListenSocket(std::move(ListenSocket)), AcceptedSocket(std::move(AcceptedSocket)), OnAccepted(std::move(OnAccepted)) +AcceptHandle::AcceptHandle(HandleWrapper&& ListenSocket, bool AcceptOnce, std::function&& OnAccepted) : + ListenSocket(std::move(ListenSocket)), AcceptOnce(AcceptOnce), OnAccepted(std::move(OnAccepted)) { Overlapped.hEvent = Event.get(); + + // Query the listen socket so accepted sockets can be created with a matching address family, type, and protocol. + WSAPROTOCOL_INFOW protocolInfo{}; + int length = sizeof(protocolInfo); + THROW_LAST_ERROR_IF( + getsockopt(reinterpret_cast(this->ListenSocket.Get()), SOL_SOCKET, SO_PROTOCOL_INFOW, reinterpret_cast(&protocolInfo), &length) == + SOCKET_ERROR); + + AddressFamily = protocolInfo.iAddressFamily; + SocketType = protocolInfo.iSocketType; + Protocol = protocolInfo.iProtocol; } -SingleAcceptHandle::~SingleAcceptHandle() +AcceptHandle::~AcceptHandle() { if (State == IOHandleStatus::Pending) { - LOG_IF_WIN32_BOOL_FALSE(CancelIoEx(ListenSocket.Get(), &Overlapped)); + CancelPendingIo(reinterpret_cast(ListenSocket.Get()), Overlapped); + } +} - DWORD bytesProcessed{}; - DWORD flagsReturned{}; - if (!WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesProcessed, TRUE, &flagsReturned)) - { - auto error = GetLastError(); - LOG_LAST_ERROR_IF(error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED); - } +void AcceptHandle::CreateAcceptSocket() +{ + AcceptedSocket.reset(WSASocketW(AddressFamily, SocketType, Protocol, nullptr, 0, WSA_FLAG_OVERLAPPED)); + THROW_LAST_ERROR_IF(!AcceptedSocket); + + if (AddressFamily == AF_HYPERV) + { + ULONG enable = 1; + THROW_LAST_ERROR_IF( + setsockopt(AcceptedSocket.get(), HV_PROTOCOL_RAW, HVSOCKET_CONNECTED_SUSPEND, reinterpret_cast(&enable), sizeof(enable)) == + SOCKET_ERROR); + } +} + +void AcceptHandle::OnComplete() +{ + wsl::windows::common::socket::SetAcceptContext(AcceptedSocket.get(), reinterpret_cast(ListenSocket.Get())); + + OnAccepted(std::move(AcceptedSocket)); + + if (AcceptOnce) + { + State = IOHandleStatus::Completed; + } + else + { + State = IOHandleStatus::Standby; } } -void SingleAcceptHandle::Schedule() +void AcceptHandle::Schedule() { WI_ASSERT(State == IOHandleStatus::Standby); + CreateAcceptSocket(); + + Event.ResetEvent(); + // Schedule the accept. DWORD bytesReturned{}; - if (AcceptEx((SOCKET)ListenSocket.Get(), (SOCKET)AcceptedSocket.Get(), &AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped)) + if (AcceptEx((SOCKET)ListenSocket.Get(), AcceptedSocket.get(), &AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped)) { // Accept completed immediately. - State = IOHandleStatus::Completed; - OnAccepted(); + OnComplete(); } else { @@ -405,7 +441,7 @@ void SingleAcceptHandle::Schedule() } } -void SingleAcceptHandle::Collect() +void AcceptHandle::Collect() { WI_ASSERT(State == IOHandleStatus::Pending); @@ -414,11 +450,10 @@ void SingleAcceptHandle::Collect() THROW_IF_WIN32_BOOL_FALSE(WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesReceived, false, &flagsReturned)); - State = IOHandleStatus::Completed; - OnAccepted(); + OnComplete(); } -HANDLE SingleAcceptHandle::GetHandle() const +HANDLE AcceptHandle::GetHandle() const { return Event.get(); } diff --git a/src/windows/common/HandleIO.h b/src/windows/common/HandleIO.h index 14d168c4a..1803a07c1 100644 --- a/src/windows/common/HandleIO.h +++ b/src/windows/common/HandleIO.h @@ -139,25 +139,32 @@ class ReadNamedPipe : public ReadHandle bool m_connected = false; }; -class SingleAcceptHandle : public OverlappedIOHandle +class AcceptHandle : public OverlappedIOHandle { public: - NON_COPYABLE(SingleAcceptHandle) - NON_MOVABLE(SingleAcceptHandle) + NON_COPYABLE(AcceptHandle) + NON_MOVABLE(AcceptHandle) - SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function&& OnAccepted); - ~SingleAcceptHandle(); + AcceptHandle(HandleWrapper&& ListenSocket, bool AcceptOnce, std::function&& OnAccepted); + ~AcceptHandle(); void Schedule() override; void Collect() override; HANDLE GetHandle() const override; private: + void CreateAcceptSocket(); + void OnComplete(); + HandleWrapper ListenSocket; - HandleWrapper AcceptedSocket; + wil::unique_socket AcceptedSocket; + int AddressFamily{}; + int SocketType{}; + int Protocol{}; + bool AcceptOnce{}; wil::unique_event Event{wil::EventOptions::ManualReset}; OVERLAPPED Overlapped{}; - std::function OnAccepted; + std::function OnAccepted; char AcceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]; }; diff --git a/src/windows/common/hvsocket.cpp b/src/windows/common/hvsocket.cpp index 204de0c84..5375226c5 100644 --- a/src/windows/common/hvsocket.cpp +++ b/src/windows/common/hvsocket.cpp @@ -37,18 +37,6 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address) } } // namespace -std::optional wsl::windows::common::hvsocket::CancellableAccept( - _In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) -{ - wil::unique_socket Socket = Create(); - if (!socket::CancellableAccept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location)) - { - return {}; - } - - return Socket; -} - wil::unique_socket wsl::windows::common::hvsocket::Connect( _In_ const GUID& VmId, _In_ unsigned long Port, _In_opt_ HANDLE ExitHandle, _In_opt_ ULONG Timeout, _In_ const std::source_location& Location) { diff --git a/src/windows/common/hvsocket.hpp b/src/windows/common/hvsocket.hpp index 9aa67253b..9f4ab5650 100644 --- a/src/windows/common/hvsocket.hpp +++ b/src/windows/common/hvsocket.hpp @@ -19,12 +19,6 @@ Module Name: namespace wsl::windows::common::hvsocket { -std::optional CancellableAccept( - _In_ SOCKET ListenSocket, - _In_ DWORD Timeout, - _In_opt_ HANDLE ExitHandle = nullptr, - const std::source_location& Location = std::source_location::current()); - wil::unique_socket Connect( _In_ const GUID& VmId, _In_ unsigned long Port, diff --git a/src/windows/common/socket.cpp b/src/windows/common/socket.cpp index d47601899..e8b351049 100644 --- a/src/windows/common/socket.cpp +++ b/src/windows/common/socket.cpp @@ -26,30 +26,41 @@ void wsl::windows::common::socket::SetAcceptContext(_In_ SOCKET AcceptedSocket, std::format("{}", Location).c_str()); } -bool wsl::windows::common::socket::CancellableAccept( - _In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) +std::optional wsl::windows::common::socket::CancellableAccept( + _In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) { io::MultiHandleWait io; - bool accepted = false; + std::optional accepted; - io.AddHandle(std::make_unique(ListenSocket, Socket, [&]() { accepted = true; }), io::MultiHandleWait::CancelOnCompleted); + io.AddHandle( + std::make_unique( + ListenSocket, true, [&accepted](wil::unique_socket&& socket) { accepted = std::move(socket); }), + io::MultiHandleWait::CancelOnCompleted); if (ExitHandle != nullptr) { io.AddHandle(std::make_unique(ExitHandle), io::MultiHandleWait::CancelOnCompleted); } - io.Run(std::chrono::milliseconds(Timeout)); - - if (!accepted) + std::optional timeout; + if (Timeout != INFINITE) { - return false; // Accept was cancelled by the exit event. + timeout = std::chrono::milliseconds(Timeout); } - SetAcceptContext(Socket, ListenSocket, Location); + try + { + + io.Run(timeout); + } + catch (...) + { + auto hr = wil::ResultFromCaughtException(); + THROW_HR_MSG(hr, "Failed to accept socket. From: %hs", std::format("{}", Location).c_str()); + } - return true; + return accepted; } std::pair wsl::windows::common::socket::GetResult( diff --git a/src/windows/common/socket.hpp b/src/windows/common/socket.hpp index d14696084..7af77b842 100644 --- a/src/windows/common/socket.hpp +++ b/src/windows/common/socket.hpp @@ -21,9 +21,8 @@ namespace wsl::windows::common::socket { // Sets SO_UPDATE_ACCEPT_CONTEXT on a socket accepted via AcceptEx to mark it as connected. void SetAcceptContext(_In_ SOCKET AcceptedSocket, _In_ SOCKET ListenSocket, _In_ const std::source_location& Location = std::source_location::current()); -bool CancellableAccept( +std::optional CancellableAccept( _In_ SOCKET ListenSocket, - _In_ SOCKET Socket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location = std::source_location::current()); diff --git a/src/windows/service/exe/HcsVirtualMachine.cpp b/src/windows/service/exe/HcsVirtualMachine.cpp index 74f4e0d89..2f073db59 100644 --- a/src/windows/service/exe/HcsVirtualMachine.cpp +++ b/src/windows/service/exe/HcsVirtualMachine.cpp @@ -387,7 +387,7 @@ try { RETURN_HR_IF_NULL(E_POINTER, Socket); - auto socket = wsl::windows::common::hvsocket::CancellableAccept(m_listenSocket.get(), m_bootTimeoutMs, m_vmExitEvent.get()); + auto socket = socket::CancellableAccept(m_listenSocket.get(), m_bootTimeoutMs, m_vmExitEvent.get()); THROW_HR_IF(E_ABORT, !socket.has_value()); *Socket = reinterpret_cast(socket->release()); @@ -912,4 +912,4 @@ try } CATCH_RETURN() -} // namespace wsl::windows::service::wslc \ No newline at end of file +} // namespace wsl::windows::service::wslc diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index 050a519e2..e6531e186 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -856,7 +856,7 @@ WslCoreVm::~WslCoreVm() noexcept wil::unique_socket WslCoreVm::AcceptConnection(_In_ DWORD ReceiveTimeout, _In_ const std::source_location& Location) const { - auto socket = hvsocket::CancellableAccept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location); + auto socket = socket::CancellableAccept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location); THROW_HR_IF(E_ABORT, !socket.has_value()); if (ReceiveTimeout != 0) @@ -1087,7 +1087,7 @@ void WslCoreVm::CollectCrashDumps(wil::unique_socket&& listenSocket) const { try { - auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); + auto socket = socket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); if (!socket.has_value()) { break; // VM is exiting. @@ -2617,92 +2617,116 @@ try { wsl::windows::common::wslutil::SetThreadDescription(L"VirtioFs - Worker"); - for (;;) - { - // Create a worker thread to handle each request. + io::MultiHandleWait io; - auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); - if (!socket.has_value()) - { - break; - } + io.AddHandle(std::make_unique(listenSocket.get(), false, [this, &io](wil::unique_socket&& socket) { + auto channel = std::make_shared(std::move(socket), "VirtioFs"); + auto buffer = std::make_shared>(); + auto pendingBytes = std::make_shared>(); - wsl::shared::SocketChannel channel{std::move(socket.value()), "VirtioFs", {m_terminatingEvent.get()}}; - std::thread([this, channel = std::move(channel)]() mutable { - try - { - wsl::windows::common::wslutil::SetThreadDescription(L"VirtioFs - Request"); + io.AddHandle( + std::make_unique( + io::HandleWrapper(channel->Socket()), + *buffer, + *pendingBytes, + [this, &io, channel, buffer, pendingBytes](const gsl::span& message) { + if (message.empty()) + { + return; // Channel closed, exit. + } - auto transaction = channel.ReceiveTransaction(); - auto [message, span] = transaction.ReceiveOrClosed(); - if (message == nullptr) - { - return; - } + THROW_HR_IF_MSG( + E_UNEXPECTED, !pendingBytes->empty(), "Received message with additional bytes: %zu", pendingBytes->size()); - auto respondWithTag = [&](const std::wstring& tag, const std::wstring& source, HRESULT result) { - // Respond to the guest with the tag that should be used to mount the device. + try + { + auto response = ProcessVirtioFsRequest(message); - wsl::shared::MessageWriter response(LxInitMessageAddVirtioFsDeviceResponse); - response->Result = SUCCEEDED(result) ? 0 : EINVAL; // TODO: Improved HRESULT -> errno mapping. - response.WriteString(response->TagOffset, tag); - response.WriteString(response->SourceOffset, source); + // Move the socket out of the channel into the WriteHandle so it is closed once the reply is sent. + io.AddHandle(std::make_unique(channel->Release(), response), io::MultiHandleWait::IgnoreErrors); + } + CATCH_LOG(); + }), + io::MultiHandleWait::IgnoreErrors); + })); - transaction.Send(response.Span()); - }; + io.AddHandle(std::make_unique(m_terminatingEvent.get()), io::MultiHandleWait::CancelOnCompleted); - if (message->MessageType == LxInitMessageAddVirtioFsDevice) - { - std::wstring tag; - std::wstring source; - const auto result = wil::ResultFromException([this, span, &tag, &source]() { - const auto* addShare = gslhelpers::try_get_struct(span); - THROW_HR_IF(E_UNEXPECTED, !addShare); - - const auto path = wsl::shared::string::FromSpan(span, addShare->PathOffset); - const auto pathWide = wsl::shared::string::MultiByteToWide(path); - const auto options = wsl::shared::string::FromSpan(span, addShare->OptionsOffset); - const auto optionsWide = wsl::shared::string::MultiByteToWide(options); - - // Acquire the lock and attempt to add the device. - auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); - std::tie(tag, source) = AddVirtioFsShare(addShare->Admin, pathWide.c_str(), optionsWide.c_str()); - }); - - respondWithTag(tag, source, result); - } - else if (message->MessageType == LxInitMessageRemountVirtioFsDevice) - { - std::wstring newTag; - std::wstring source; - const auto result = wil::ResultFromException([this, span, &newTag, &source]() { - const auto* remountShare = gslhelpers::try_get_struct(span); - THROW_HR_IF(E_UNEXPECTED, !remountShare); + io.Run({}); +} +CATCH_LOG() - const std::string tag = wsl::shared::string::FromSpan(span, remountShare->TagOffset); - const auto tagWide = wsl::shared::string::MultiByteToWide(tag); - auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); - const auto foundShare = FindVirtioFsShare(tagWide.c_str(), !remountShare->Admin); - THROW_HR_IF_MSG(E_UNEXPECTED, !foundShare.has_value(), "Unknown tag %ls", tagWide.c_str()); +std::vector WslCoreVm::ProcessVirtioFsRequest(_In_ gsl::span Request) +{ + const auto* header = gslhelpers::try_get_struct(Request); + THROW_HR_IF(E_UNEXPECTED, !header); - std::tie(newTag, source) = - AddVirtioFsShare(remountShare->Admin, foundShare->Path.c_str(), foundShare->OptionsString().c_str()); + WSL_LOG("VirtiofsMessageRequest", TraceLoggingValue(header->PrettyPrint().c_str(), "Content")); - WI_ASSERT(source == foundShare->Path); - }); + auto buildResponse = [header](const std::wstring& tag, const std::wstring& source, HRESULT result) { + // Respond to the guest with the tag that should be used to mount the device. + wsl::shared::MessageWriter response(LxInitMessageAddVirtioFsDeviceResponse); + response->Result = SUCCEEDED(result) ? 0 : EINVAL; // TODO: Improved HRESULT -> errno mapping. + response.WriteString(response->TagOffset, tag); + response.WriteString(response->SourceOffset, source); - respondWithTag(newTag, source, result); - } - else - { - THROW_HR_MSG(E_UNEXPECTED, "Unexpected MessageType %d", message->MessageType); - } - } - CATCH_LOG() - }).detach(); + // Echo the request's transaction id and mark the message as the first (and only) reply. + response->Header.TransactionId = header->TransactionId; + response->Header.TransactionStep = static_cast(TRANSACTION_STEP::FIRST_REPLY); + + WSL_LOG("VirtiofsMessageResponse", TraceLoggingValue(response->PrettyPrint().c_str(), "Content")); + + const auto span = response.Span(); + return std::vector(reinterpret_cast(span.data()), reinterpret_cast(span.data()) + span.size()); + }; + + if (header->MessageType == LxInitMessageAddVirtioFsDevice) + { + std::wstring tag; + std::wstring source; + const auto result = wil::ResultFromException([&]() { + const auto* addShare = gslhelpers::try_get_struct(Request); + THROW_HR_IF(E_UNEXPECTED, !addShare); + + const auto path = wsl::shared::string::FromSpan(Request, addShare->PathOffset); + const auto pathWide = wsl::shared::string::MultiByteToWide(path); + const auto options = wsl::shared::string::FromSpan(Request, addShare->OptionsOffset); + const auto optionsWide = wsl::shared::string::MultiByteToWide(options); + + // Acquire the lock and attempt to add the device. + auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); + std::tie(tag, source) = AddVirtioFsShare(addShare->Admin, pathWide.c_str(), optionsWide.c_str()); + }); + + return buildResponse(tag, source, result); + } + else if (header->MessageType == LxInitMessageRemountVirtioFsDevice) + { + std::wstring newTag; + std::wstring source; + const auto result = wil::ResultFromException([&]() { + const auto* remountShare = gslhelpers::try_get_struct(Request); + THROW_HR_IF(E_UNEXPECTED, !remountShare); + + const std::string tag = wsl::shared::string::FromSpan(Request, remountShare->TagOffset); + const auto tagWide = wsl::shared::string::MultiByteToWide(tag); + auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); + const auto foundShare = FindVirtioFsShare(tagWide.c_str(), !remountShare->Admin); + THROW_HR_IF_MSG(E_UNEXPECTED, !foundShare.has_value(), "Unknown tag %ls", tagWide.c_str()); + + std::tie(newTag, source) = + AddVirtioFsShare(remountShare->Admin, foundShare->Path.c_str(), foundShare->OptionsString().c_str()); + + WI_ASSERT(source == foundShare->Path); + }); + + return buildResponse(newTag, source, result); + } + else + { + THROW_HR_MSG(E_UNEXPECTED, "Unexpected MessageType %d", header->MessageType); } } -CATCH_LOG() std::string WslCoreVm::s_GetMountTargetName(_In_ PCWSTR Disk, _In_opt_ PCWSTR Name, _In_ int PartitionIndex) { diff --git a/src/windows/service/exe/WslCoreVm.h b/src/windows/service/exe/WslCoreVm.h index b3bb810b6..30ce6b463 100644 --- a/src/windows/service/exe/WslCoreVm.h +++ b/src/windows/service/exe/WslCoreVm.h @@ -251,6 +251,8 @@ class WslCoreVm void VirtioFsWorker(_In_ const wil::unique_socket& socket); + std::vector ProcessVirtioFsRequest(_In_ gsl::span Request); + static std::string s_GetMountTargetName(_In_ PCWSTR Disk, _In_opt_ PCWSTR Name, _In_ int PartitionIndex); static LX_INIT_DRVFS_MOUNT s_InitializeDrvFs(_Inout_ WslCoreVm* VmContext, _In_ HANDLE UserToken); diff --git a/src/windows/wslcsession/WSLCVirtualMachine.cpp b/src/windows/wslcsession/WSLCVirtualMachine.cpp index 5c5766f5e..593198d16 100644 --- a/src/windows/wslcsession/WSLCVirtualMachine.cpp +++ b/src/windows/wslcsession/WSLCVirtualMachine.cpp @@ -1259,7 +1259,7 @@ void WSLCVirtualMachine::CollectCrashDumps(wil::unique_socket&& listenSocket) { try { - auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_vmTerminatingEvent.get()); + auto socket = socket::CancellableAccept(listenSocket.get(), INFINITE, m_vmTerminatingEvent.get()); if (!socket) { // VM is exiting. diff --git a/src/windows/wslrelay/localhost.cpp b/src/windows/wslrelay/localhost.cpp index 1b2af8329..378ee4767 100644 --- a/src/windows/wslrelay/localhost.cpp +++ b/src/windows/wslrelay/localhost.cpp @@ -291,15 +291,11 @@ try { // Begin accepting connections until the relay is stopped. - const int AddressFamily = WindowsAddressFamily(Arguments->Family); - for (;;) { - wil::unique_socket InetSocket(WSASocket(AddressFamily, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED)); - THROW_LAST_ERROR_IF(!InetSocket); - - if (!wsl::windows::common::socket::CancellableAccept( - Arguments->ListenSocket.get(), InetSocket.get(), INFINITE, Arguments->ExitEvent.get())) + auto InetSocket = + wsl::windows::common::socket::CancellableAccept(Arguments->ListenSocket.get(), INFINITE, Arguments->ExitEvent.get()); + if (!InetSocket) { break; // Exit event was signaled, exit. } @@ -308,7 +304,7 @@ try WSL_LOG("PortRelayUsage", TraceLoggingValue(Arguments->Family, "family"), TraceLoggingValue(Arguments->Port, "port"), TraceLoggingLevel(WINEVENT_LEVEL_INFO)); - auto RelayThread = std::thread([Arguments, InetSocket = std::move(InetSocket)]() { + auto RelayThread = std::thread([Arguments, InetSocket = std::move(*InetSocket)]() { try { wsl::windows::common::wslutil::SetThreadDescription(L"Port relay"); diff --git a/src/windows/wslrelay/main.cpp b/src/windows/wslrelay/main.cpp index ea5080be2..c5657fd4b 100644 --- a/src/windows/wslrelay/main.cpp +++ b/src/windows/wslrelay/main.cpp @@ -123,17 +123,15 @@ try THROW_LAST_ERROR_IF(listen(listenSocket.get(), 1) == SOCKET_ERROR); - const wil::unique_socket socket(WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED)); - THROW_LAST_ERROR_IF(!socket); - - if (!wsl::windows::common::socket::CancellableAccept(listenSocket.get(), socket.get(), INFINITE, exitEvent.get())) + auto socket = wsl::windows::common::socket::CancellableAccept(listenSocket.get(), INFINITE, exitEvent.get()); + if (!socket) { return 1; } // Begin the relay. wsl::windows::common::relay::BidirectionalRelay( - reinterpret_cast(socket.get()), pipe.get(), 0x1000, wsl::windows::common::relay::RelayFlags::LeftIsSocket); + reinterpret_cast(socket->get()), pipe.get(), 0x1000, wsl::windows::common::relay::RelayFlags::LeftIsSocket); break; } diff --git a/test/windows/DrvFsTests.cpp b/test/windows/DrvFsTests.cpp index 9025ba6a3..7838b9719 100644 --- a/test/windows/DrvFsTests.cpp +++ b/test/windows/DrvFsTests.cpp @@ -409,6 +409,59 @@ class DrvFsTests VERIFY_IS_TRUE(out.find(L"test-file.txt") != std::wstring::npos); } + void DrvfsMountManyVirtioFsShares(DrvFsMode Mode) + { + if (Mode != DrvFsMode::VirtioFs) + { + LogSkipped("This test is only applicable to VirtioFs"); + return; + } + + WINDOWS_11_TEST_ONLY(); + SKIP_TEST_ARM64(); + + constexpr auto c_iterations = 20; + auto testDir = std::filesystem::current_path() / "virtiofs-loop-test"; + + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + LxsstuLaunchWsl(L"umount /tmp/virtiofs-loop-test-*"); + + std::error_code ec; + std::filesystem::remove_all(testDir, ec); + }); + + for (int i = 0; i < c_iterations; ++i) + { + const auto sourceDir = testDir / std::to_string(i); + std::filesystem::create_directories(sourceDir); + + const auto expected = std::format("virtiofs share {}", i); + { + std::ofstream markerFile(std::filesystem::path(sourceDir) / L"marker"); + markerFile << expected; + } + + const auto mountPoint = std::format(L"/tmp/virtiofs-loop-test-{}", i); + + // Mount the share. + VERIFY_ARE_EQUAL(LxsstuLaunchWsl(std::format(L"mkdir -p '{}'", mountPoint)), 0); + VERIFY_ARE_EQUAL(LxsstuLaunchWsl(std::format(L"mount -t drvfs '{}' '{}'", sourceDir.string(), mountPoint)), 0); + + // Validate that it can be accessed. + { + auto [out, err] = LxsstuLaunchWslAndCaptureOutput(std::format(L"cat '{}/marker'", mountPoint)); + VERIFY_ARE_EQUAL(out, wsl::shared::string::MultiByteToWide(expected)); + } + + // Validate the mount options. + { + auto [out, err] = LxsstuLaunchWslAndCaptureOutput(std::format(L"findmnt -ln '{}'", mountPoint)); + + VerifyPatternMatch(wsl::shared::string::WideToMultiByte(out.c_str()), std::format("{} * virtiofs rw,relatime\n", mountPoint)); + } + } + } + // DrvFsTests Private Methods private: static VOID CreateDrvFsTestFiles(bool Metadata) @@ -1303,6 +1356,11 @@ class WSL1 : public DrvFsTests { \ DrvFsTests::DrvFsMountUnicodePath(DrvFsMode::##_mode##); \ } \ +\ + WSL2_TEST_METHOD(DrvfsMountManyVirtioFsShares) \ + { \ + DrvFsTests::DrvfsMountManyVirtioFsShares(DrvFsMode::##_mode##); \ + } \ } WSL2_DRVFS_TEST_CLASS(Plan9); @@ -1313,4 +1371,4 @@ WSL2_DRVFS_TEST_CLASS(VirtioFs); // TODO: Enable again once the issue is resolved // WSL2_DRVFS_TEST_CLASS(Virtio9p); -} // namespace DrvFsTests \ No newline at end of file +} // namespace DrvFsTests