From b71c6a45fe59b0db6af69be3fc79151304ec53bd Mon Sep 17 00:00:00 2001 From: David Allison Date: Mon, 22 Jun 2026 19:58:27 -0700 Subject: [PATCH 1/3] Fix portability of SCM_RIGHTS --- toolbelt/sockets.cc | 43 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/toolbelt/sockets.cc b/toolbelt/sockets.cc index 582ba9d..b549445 100644 --- a/toolbelt/sockets.cc +++ b/toolbelt/sockets.cc @@ -573,16 +573,43 @@ absl::Status UnixSocket::ReceiveFds(std::vector &fds, absl::StrFormat("EOF from socket while reading fds\n")); } - struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - if (cmsg == nullptr) { - // This can happen, apparently. - return absl::OkStatus(); + if ((msg.msg_flags & MSG_CTRUNC) != 0) { + return absl::InternalError( + "Control data was truncated while reading fds from unix socket"); } - int *fdptr = reinterpret_cast(CMSG_DATA(cmsg)); - int num_fds = (cmsg->cmsg_len - sizeof(struct cmsghdr)) / sizeof(int); - for (int i = 0; i < num_fds; i++) { - fds.emplace_back(fdptr[i]); + + bool saw_rights = false; + int num_fds = 0; + for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; + cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) { + continue; + } + saw_rights = true; + if (cmsg->cmsg_len < CMSG_LEN(0)) { + return absl::InternalError(absl::StrFormat( + "Invalid SCM_RIGHTS control length %zu while reading fds", + static_cast(cmsg->cmsg_len))); + } + size_t data_len = cmsg->cmsg_len - CMSG_LEN(0); + if (data_len % sizeof(int) != 0) { + return absl::InternalError(absl::StrFormat( + "Misaligned SCM_RIGHTS control length %zu while reading fds", + static_cast(cmsg->cmsg_len))); + } + int *fdptr = reinterpret_cast(CMSG_DATA(cmsg)); + int fds_in_message = static_cast(data_len / sizeof(int)); + for (int i = 0; i < fds_in_message; i++) { + fds.emplace_back(fdptr[i]); + } + num_fds += fds_in_message; } + if (!saw_rights && total_fds > 0) { + return absl::InternalError(absl::StrFormat( + "Expected %d fds from unix socket but received no SCM_RIGHTS message", + total_fds)); + } + // Add the number we received in this message to the total. num_fds_received += num_fds; From cd72868b04bc8483b7131cb3b7f11bb43c820a0a Mon Sep 17 00:00:00 2001 From: David Allison Date: Mon, 22 Jun 2026 20:10:25 -0700 Subject: [PATCH 2/3] Version 2.1.3 --- MODULE.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MODULE.bazel b/MODULE.bazel index 0c3fd8d..2927f4d 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -1,6 +1,6 @@ module( name = "cpp_toolbelt", - version = "2.1.2", + version = "2.1.3", ) bazel_dep(name = "platforms", version = "1.0.0") From 04faab2f34ff5d3b460cad0ae6855f2da521d1eb Mon Sep 17 00:00:00 2001 From: Dave Allison Date: Tue, 23 Jun 2026 15:04:06 -0700 Subject: [PATCH 3/3] Don't rely on sendmsg boundaries for SCM_RIGHTS --- toolbelt/sockets.cc | 136 +++++++++++++++++++++------------------ toolbelt/sockets_test.cc | 99 ++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 63 deletions(-) diff --git a/toolbelt/sockets.cc b/toolbelt/sockets.cc index b549445..965a0af 100644 --- a/toolbelt/sockets.cc +++ b/toolbelt/sockets.cc @@ -485,23 +485,23 @@ absl::Status UnixSocket::SendFds(const std::vector &fds, #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmissing-field-initializers" #endif - struct msghdr msg = {.msg_iov = &iov, - .msg_iovlen = 1, - .msg_control = control_buf.data(), - .msg_controllen = - static_cast(CMSG_SPACE(fds_size))}; + struct msghdr msg = {.msg_iov = &iov, .msg_iovlen = 1}; #if defined(__clang__) #pragma clang diagnostic pop #elif defined(__GNUC__) #pragma GCC diagnostic pop #endif - struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(fds_size); - int *fdptr = reinterpret_cast(CMSG_DATA(cmsg)); - for (size_t i = first_fd; i < first_fd + fds_to_send; i++) { - *fdptr++ = fds[i].Fd(); + if (fds_to_send > 0) { + msg.msg_control = control_buf.data(); + msg.msg_controllen = static_cast(CMSG_SPACE(fds_size)); + struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(fds_size); + int *fdptr = reinterpret_cast(CMSG_DATA(cmsg)); + for (size_t i = first_fd; i < first_fd + fds_to_send; i++) { + *fdptr++ = fds[i].Fd(); + } } if (c != nullptr) { @@ -531,14 +531,19 @@ absl::Status UnixSocket::ReceiveFds(std::vector &fds, int32_t num_fds_received = 0; for (;;) { - std::fill(control_buf.begin(), control_buf.end(), 0); - // The total number of fds we need to see. This is // sent in each message, but each message contains only portion // of the total (there's a limit per message). - int32_t total_fds; - struct iovec iov = {.iov_base = reinterpret_cast(&total_fds), - .iov_len = sizeof(int32_t)}; + int32_t total_fds = 0; + size_t total_fds_bytes = 0; + bool saw_rights = false; + int num_fds = 0; + + while (total_fds_bytes < sizeof(total_fds)) { + std::fill(control_buf.begin(), control_buf.end(), 0); + struct iovec iov = { + .iov_base = reinterpret_cast(&total_fds) + total_fds_bytes, + .iov_len = sizeof(total_fds) - total_fds_bytes}; #if defined(__clang__) #pragma clang diagnostic push @@ -547,62 +552,67 @@ absl::Status UnixSocket::ReceiveFds(std::vector &fds, #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmissing-field-initializers" #endif - struct msghdr msg = {.msg_iov = &iov, - .msg_iovlen = 1, - .msg_control = control_buf.data(), - .msg_controllen = - static_cast(control_buf.size())}; + struct msghdr msg = {.msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = control_buf.data(), + .msg_controllen = + static_cast(control_buf.size())}; #if defined(__clang__) #pragma clang diagnostic pop #elif defined(__GNUC__) #pragma GCC diagnostic pop #endif - if (c != nullptr) { - int fd = c->Wait(fd_.Fd(), POLLIN); - if (fd != fd_.Fd()) { - return absl::InternalError("Interrupted"); - } - } - ssize_t n = ::recvmsg(fd_.Fd(), &msg, 0); - if (n == -1) { - return absl::InternalError(absl::StrFormat( - "Failed to read fds to unix socket: %s", strerror(errno))); - } - if (n == 0) { - return absl::InternalError( - absl::StrFormat("EOF from socket while reading fds\n")); - } - - if ((msg.msg_flags & MSG_CTRUNC) != 0) { - return absl::InternalError( - "Control data was truncated while reading fds from unix socket"); - } - - bool saw_rights = false; - int num_fds = 0; - for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; - cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) { - continue; + if (c != nullptr) { + int fd = c->Wait(fd_.Fd(), POLLIN); + if (fd != fd_.Fd()) { + return absl::InternalError("Interrupted"); + } } - saw_rights = true; - if (cmsg->cmsg_len < CMSG_LEN(0)) { + ssize_t n = ::recvmsg(fd_.Fd(), &msg, 0); + if (n == -1) { return absl::InternalError(absl::StrFormat( - "Invalid SCM_RIGHTS control length %zu while reading fds", - static_cast(cmsg->cmsg_len))); + "Failed to read fds to unix socket: %s", strerror(errno))); } - size_t data_len = cmsg->cmsg_len - CMSG_LEN(0); - if (data_len % sizeof(int) != 0) { - return absl::InternalError(absl::StrFormat( - "Misaligned SCM_RIGHTS control length %zu while reading fds", - static_cast(cmsg->cmsg_len))); + if (n == 0) { + return absl::InternalError( + absl::StrFormat("EOF from socket while reading fds\n")); } - int *fdptr = reinterpret_cast(CMSG_DATA(cmsg)); - int fds_in_message = static_cast(data_len / sizeof(int)); - for (int i = 0; i < fds_in_message; i++) { - fds.emplace_back(fdptr[i]); + + total_fds_bytes += static_cast(n); + + if ((msg.msg_flags & MSG_CTRUNC) != 0) { + return absl::InternalError( + "Control data was truncated while reading fds from unix socket"); + } + + for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; + cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) { + continue; + } + saw_rights = true; + if (cmsg->cmsg_len < CMSG_LEN(0)) { + return absl::InternalError(absl::StrFormat( + "Invalid SCM_RIGHTS control length %zu while reading fds", + static_cast(cmsg->cmsg_len))); + } + size_t data_len = cmsg->cmsg_len - CMSG_LEN(0); + if (data_len % sizeof(int) != 0) { + return absl::InternalError(absl::StrFormat( + "Misaligned SCM_RIGHTS control length %zu while reading fds", + static_cast(cmsg->cmsg_len))); + } + int *fdptr = reinterpret_cast(CMSG_DATA(cmsg)); + int fds_in_message = static_cast(data_len / sizeof(int)); + for (int i = 0; i < fds_in_message; i++) { + fds.emplace_back(fdptr[i]); + } + num_fds += fds_in_message; } - num_fds += fds_in_message; + } + + if (total_fds == 0) { + break; } if (!saw_rights && total_fds > 0) { return absl::InternalError(absl::StrFormat( diff --git a/toolbelt/sockets_test.cc b/toolbelt/sockets_test.cc index 24532ae..f766c8c 100644 --- a/toolbelt/sockets_test.cc +++ b/toolbelt/sockets_test.cc @@ -137,6 +137,105 @@ TEST(SocketsTest, UnixSocket) { remove(socket_name.c_str()); } +TEST(SocketsTest, UnixSocketZeroFds) { + char tmp[] = "/tmp/socketsXXXXXX"; + int fd = mkstemp(tmp); + ASSERT_NE(-1, fd); + std::string socket_name = tmp; + close(fd); + + unlink(socket_name.c_str()); + co::CoroutineScheduler scheduler; + + toolbelt::UnixSocket listener; + absl::Status status = listener.Bind(socket_name, true); + ASSERT_TRUE(status.ok()); + + co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_TRUE(s.ok()); + auto socket = s.value(); + + std::vector fds; + absl::Status s2 = socket.ReceiveFds(fds, c); + ASSERT_TRUE(s2.ok()); + ASSERT_TRUE(fds.empty()); + }); + + co::Coroutine outgoing(scheduler, [&socket_name](co::Coroutine* c) { + toolbelt::UnixSocket socket; + absl::Status s = socket.Connect(socket_name); + ASSERT_TRUE(s.ok()); + + std::vector fds; + absl::Status s2 = socket.SendFds(fds, c); + ASSERT_TRUE(s2.ok()); + }); + + scheduler.Run(); + remove(socket_name.c_str()); +} + +TEST(SocketsTest, UnixSocketShortFdCountRead) { + char tmp[] = "/tmp/socketsXXXXXX"; + int fd = mkstemp(tmp); + ASSERT_NE(-1, fd); + std::string socket_name = tmp; + close(fd); + + unlink(socket_name.c_str()); + co::CoroutineScheduler scheduler; + + toolbelt::UnixSocket listener; + absl::Status status = listener.Bind(socket_name, true); + ASSERT_TRUE(status.ok()); + + co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) { + absl::StatusOr s = listener.Accept(c); + ASSERT_TRUE(s.ok()); + auto socket = s.value(); + + std::vector fds; + absl::Status s2 = socket.ReceiveFds(fds, c); + ASSERT_TRUE(s2.ok()); + ASSERT_EQ(1, fds.size()); + }); + + co::Coroutine outgoing(scheduler, [&socket_name](co::Coroutine* c) { + toolbelt::UnixSocket socket; + absl::Status s = socket.Connect(socket_name); + ASSERT_TRUE(s.ok()); + + int32_t num_fds = 1; + char* num_fds_bytes = reinterpret_cast(&num_fds); + int fd_to_send = dup(0); + ASSERT_NE(-1, fd_to_send); + + char control_buf[CMSG_SPACE(sizeof(int))] = {}; + struct iovec iov = {.iov_base = num_fds_bytes, .iov_len = 1}; + struct msghdr msg = {.msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = control_buf, + .msg_controllen = sizeof(control_buf)}; + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(sizeof(int)); + *reinterpret_cast(CMSG_DATA(cmsg)) = fd_to_send; + + ssize_t n = sendmsg(socket.GetFileDescriptor().Fd(), &msg, 0); + ASSERT_EQ(1, n); + close(fd_to_send); + + n = send(socket.GetFileDescriptor().Fd(), num_fds_bytes + 1, + sizeof(num_fds) - 1, 0); + ASSERT_EQ(static_cast(sizeof(num_fds) - 1), n); + }); + + scheduler.Run(); + remove(socket_name.c_str()); +} + TEST(SocketsTest, UnixSocketErrors) { toolbelt::UnixSocket socket; // Socket is inValid, all will fail.