Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module(
name = "cpp_toolbelt",
version = "2.1.2",
version = "2.1.3",
)

bazel_dep(name = "platforms", version = "1.0.0")
Expand Down
123 changes: 80 additions & 43 deletions toolbelt/sockets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,23 +485,23 @@ absl::Status UnixSocket::SendFds(const std::vector<FileDescriptor> &fds,
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#endif
struct msghdr msg = {.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = control_buf.data(),
.msg_controllen =
static_cast<socklen_t>(CMSG_SPACE(fds_size))};
struct msghdr msg = {.msg_iov = &iov, .msg_iovlen = 1};
#if defined(__clang__)
#pragma clang diagnostic pop
#elif defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(fds_size);
int *fdptr = reinterpret_cast<int *>(CMSG_DATA(cmsg));
for (size_t i = first_fd; i < first_fd + fds_to_send; i++) {
*fdptr++ = fds[i].Fd();
if (fds_to_send > 0) {
msg.msg_control = control_buf.data();
msg.msg_controllen = static_cast<socklen_t>(CMSG_SPACE(fds_size));
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(fds_size);
int *fdptr = reinterpret_cast<int *>(CMSG_DATA(cmsg));
for (size_t i = first_fd; i < first_fd + fds_to_send; i++) {
*fdptr++ = fds[i].Fd();
}
}

if (c != nullptr) {
Expand Down Expand Up @@ -531,14 +531,19 @@ absl::Status UnixSocket::ReceiveFds(std::vector<FileDescriptor> &fds,

int32_t num_fds_received = 0;
for (;;) {
std::fill(control_buf.begin(), control_buf.end(), 0);

// The total number of fds we need to see. This is
// sent in each message, but each message contains only portion
// of the total (there's a limit per message).
int32_t total_fds;
struct iovec iov = {.iov_base = reinterpret_cast<void *>(&total_fds),
.iov_len = sizeof(int32_t)};
int32_t total_fds = 0;
size_t total_fds_bytes = 0;
bool saw_rights = false;
int num_fds = 0;

while (total_fds_bytes < sizeof(total_fds)) {
std::fill(control_buf.begin(), control_buf.end(), 0);
struct iovec iov = {
.iov_base = reinterpret_cast<char *>(&total_fds) + total_fds_bytes,
.iov_len = sizeof(total_fds) - total_fds_bytes};

#if defined(__clang__)
#pragma clang diagnostic push
Expand All @@ -547,42 +552,74 @@ absl::Status UnixSocket::ReceiveFds(std::vector<FileDescriptor> &fds,
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#endif
struct msghdr msg = {.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = control_buf.data(),
.msg_controllen =
static_cast<socklen_t>(control_buf.size())};
struct msghdr msg = {.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = control_buf.data(),
.msg_controllen =
static_cast<socklen_t>(control_buf.size())};
#if defined(__clang__)
#pragma clang diagnostic pop
#elif defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
if (c != nullptr) {
int fd = c->Wait(fd_.Fd(), POLLIN);
if (fd != fd_.Fd()) {
return absl::InternalError("Interrupted");
if (c != nullptr) {
int fd = c->Wait(fd_.Fd(), POLLIN);
if (fd != fd_.Fd()) {
return absl::InternalError("Interrupted");
}
}
ssize_t n = ::recvmsg(fd_.Fd(), &msg, 0);
if (n == -1) {
return absl::InternalError(absl::StrFormat(
"Failed to read fds to unix socket: %s", strerror(errno)));
}
if (n == 0) {
return absl::InternalError(
absl::StrFormat("EOF from socket while reading fds\n"));
}

total_fds_bytes += static_cast<size_t>(n);

if ((msg.msg_flags & MSG_CTRUNC) != 0) {
return absl::InternalError(
"Control data was truncated while reading fds from unix socket");
}

for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) {
continue;
}
saw_rights = true;
if (cmsg->cmsg_len < CMSG_LEN(0)) {
return absl::InternalError(absl::StrFormat(
"Invalid SCM_RIGHTS control length %zu while reading fds",
static_cast<size_t>(cmsg->cmsg_len)));
}
size_t data_len = cmsg->cmsg_len - CMSG_LEN(0);
if (data_len % sizeof(int) != 0) {
return absl::InternalError(absl::StrFormat(
"Misaligned SCM_RIGHTS control length %zu while reading fds",
static_cast<size_t>(cmsg->cmsg_len)));
}
int *fdptr = reinterpret_cast<int *>(CMSG_DATA(cmsg));
int fds_in_message = static_cast<int>(data_len / sizeof(int));
for (int i = 0; i < fds_in_message; i++) {
fds.emplace_back(fdptr[i]);
}
num_fds += fds_in_message;
}
}
ssize_t n = ::recvmsg(fd_.Fd(), &msg, 0);
if (n == -1) {
return absl::InternalError(absl::StrFormat(
"Failed to read fds to unix socket: %s", strerror(errno)));
}
if (n == 0) {
return absl::InternalError(
absl::StrFormat("EOF from socket while reading fds\n"));
}

struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg == nullptr) {
// This can happen, apparently.
return absl::OkStatus();
if (total_fds == 0) {
break;
}
int *fdptr = reinterpret_cast<int *>(CMSG_DATA(cmsg));
int num_fds = (cmsg->cmsg_len - sizeof(struct cmsghdr)) / sizeof(int);
for (int i = 0; i < num_fds; i++) {
fds.emplace_back(fdptr[i]);
if (!saw_rights && total_fds > 0) {
return absl::InternalError(absl::StrFormat(
"Expected %d fds from unix socket but received no SCM_RIGHTS message",
total_fds));
}

// Add the number we received in this message to the total.
num_fds_received += num_fds;

Expand Down
99 changes: 99 additions & 0 deletions toolbelt/sockets_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,105 @@ TEST(SocketsTest, UnixSocket) {
remove(socket_name.c_str());
}

TEST(SocketsTest, UnixSocketZeroFds) {
char tmp[] = "/tmp/socketsXXXXXX";
int fd = mkstemp(tmp);
ASSERT_NE(-1, fd);
std::string socket_name = tmp;
close(fd);

unlink(socket_name.c_str());
co::CoroutineScheduler scheduler;

toolbelt::UnixSocket listener;
absl::Status status = listener.Bind(socket_name, true);
ASSERT_TRUE(status.ok());

co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) {
absl::StatusOr<toolbelt::UnixSocket> s = listener.Accept(c);
ASSERT_TRUE(s.ok());
auto socket = s.value();

std::vector<toolbelt::FileDescriptor> fds;
absl::Status s2 = socket.ReceiveFds(fds, c);
ASSERT_TRUE(s2.ok());
ASSERT_TRUE(fds.empty());
});

co::Coroutine outgoing(scheduler, [&socket_name](co::Coroutine* c) {
toolbelt::UnixSocket socket;
absl::Status s = socket.Connect(socket_name);
ASSERT_TRUE(s.ok());

std::vector<toolbelt::FileDescriptor> fds;
absl::Status s2 = socket.SendFds(fds, c);
ASSERT_TRUE(s2.ok());
});

scheduler.Run();
remove(socket_name.c_str());
}

TEST(SocketsTest, UnixSocketShortFdCountRead) {
char tmp[] = "/tmp/socketsXXXXXX";
int fd = mkstemp(tmp);
ASSERT_NE(-1, fd);
std::string socket_name = tmp;
close(fd);

unlink(socket_name.c_str());
co::CoroutineScheduler scheduler;

toolbelt::UnixSocket listener;
absl::Status status = listener.Bind(socket_name, true);
ASSERT_TRUE(status.ok());

co::Coroutine incoming(scheduler, [&listener](co::Coroutine* c) {
absl::StatusOr<toolbelt::UnixSocket> s = listener.Accept(c);
ASSERT_TRUE(s.ok());
auto socket = s.value();

std::vector<toolbelt::FileDescriptor> fds;
absl::Status s2 = socket.ReceiveFds(fds, c);
ASSERT_TRUE(s2.ok());
ASSERT_EQ(1, fds.size());
});

co::Coroutine outgoing(scheduler, [&socket_name](co::Coroutine* c) {
toolbelt::UnixSocket socket;
absl::Status s = socket.Connect(socket_name);
ASSERT_TRUE(s.ok());

int32_t num_fds = 1;
char* num_fds_bytes = reinterpret_cast<char*>(&num_fds);
int fd_to_send = dup(0);
ASSERT_NE(-1, fd_to_send);

char control_buf[CMSG_SPACE(sizeof(int))] = {};
struct iovec iov = {.iov_base = num_fds_bytes, .iov_len = 1};
struct msghdr msg = {.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = control_buf,
.msg_controllen = sizeof(control_buf)};
struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(sizeof(int));
*reinterpret_cast<int*>(CMSG_DATA(cmsg)) = fd_to_send;

ssize_t n = sendmsg(socket.GetFileDescriptor().Fd(), &msg, 0);
ASSERT_EQ(1, n);
close(fd_to_send);

n = send(socket.GetFileDescriptor().Fd(), num_fds_bytes + 1,
sizeof(num_fds) - 1, 0);
ASSERT_EQ(static_cast<ssize_t>(sizeof(num_fds) - 1), n);
});

scheduler.Run();
remove(socket_name.c_str());
}

TEST(SocketsTest, UnixSocketErrors) {
toolbelt::UnixSocket socket;
// Socket is inValid, all will fail.
Expand Down
Loading