sys_util: add sock_ctrl_msg module for transmitting messages with FD
This CL also adds the gcc build-time dependency needed to compile the
sock_ctrl_msg.c helper code.
TEST=cargo test
BUG=chromium:738638
Change-Id: I4adc2360b7fab4ed7d557603aa7bad2e738b69b4
Reviewed-on: https://chromium-review.googlesource.com/562574
Commit-Ready: Zach Reizner <zachr@chromium.org>
Tested-by: Zach Reizner <zachr@chromium.org>
Reviewed-by: Chirantan Ekbote <chirantan@chromium.org>
Reviewed-by: Dylan Reid <dgreid@chromium.org>
diff --git a/sys_util/Cargo.toml b/sys_util/Cargo.toml
index df54b13..bc9e1f9 100644
--- a/sys_util/Cargo.toml
+++ b/sys_util/Cargo.toml
@@ -2,8 +2,12 @@
name = "sys_util"
version = "0.1.0"
authors = ["The Chromium OS Authors"]
+build = "build.rs"
[dependencies]
data_model = { path = "../data_model" }
libc = "*"
-syscall_defines = { path = "../syscall_defines" }
\ No newline at end of file
+syscall_defines = { path = "../syscall_defines" }
+
+[build-dependencies]
+gcc = "0.3"
diff --git a/sys_util/build.rs b/sys_util/build.rs
new file mode 100644
index 0000000..ad96135
--- /dev/null
+++ b/sys_util/build.rs
@@ -0,0 +1,9 @@
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+extern crate gcc;
+
+fn main() {
+    // Build the C helper that wraps the CMSG_* macros into a static library so the FFI
+    // declarations in src/sock_ctrl_msg.rs have something to resolve against.
+    let helper_sources = ["sock_ctrl_msg.c"];
+    gcc::compile_library("libsock_ctrl_msg.a", &helper_sources);
+}
diff --git a/sys_util/sock_ctrl_msg.c b/sys_util/sock_ctrl_msg.c
new file mode 100644
index 0000000..15b7be3
--- /dev/null
+++ b/sys_util/sock_ctrl_msg.c
@@ -0,0 +1,123 @@
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <errno.h>
+#include <stdint.h>
+#include <string.h> // memcpy
+#include <sys/errno.h>
+#include <sys/socket.h> // CMSG_*
+#include <unistd.h> // close
+
+/*
+ * Computes how large, in bytes, the `cmsg_buffer` handed to the other functions in this module
+ * has to be so that it can carry `fd_count` file descriptors.
+ * Arguments:
+ *   * `fd_count` - Largest number of file descriptors sent or received in a single cmsg.
+ */
+size_t scm_cmsg_buffer_len(size_t fd_count)
+{
+    size_t payload_len = fd_count * sizeof(int);
+    return CMSG_SPACE(payload_len);
+}
+
+/*
+ * Wraps `sendmsg`, filling in the `msghdr` on the caller's behalf and, when `fd_count` is
+ * non-zero, attaching a single SCM_RIGHTS control message that carries `fds`.
+ * Arguments:
+ *   * `fd` - Unix domain socket to send on.
+ *   * `outv` - `outv_count` iovecs describing the data to send.
+ *   * `outv_count` - Length of the `outv` array.
+ *   * `cmsg_buffer` - Scratch buffer of at least `scm_cmsg_buffer_len(fd_count)` bytes.
+ *   * `fds` - `fd_count` file descriptors to pass along with the data.
+ *   * `fd_count` - Length of the `fds` array.
+ * Returns:
+ *   The (non-negative) number of bytes sent on success, or a negated errno on failure.
+ */
+ssize_t scm_sendmsg(int fd, const struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer,
+                    const int *fds, size_t fd_count)
+{
+    int bad_fd_args = fd_count > 0 && (cmsg_buffer == NULL || fds == NULL);
+    if (fd < 0 || bad_fd_args)
+        return -EINVAL;
+
+    struct msghdr msg;
+    memset(&msg, 0, sizeof(msg));
+    /* sendmsg never writes through msg_iov, so casting away const is harmless. */
+    msg.msg_iov = (struct iovec *)outv;
+    msg.msg_iovlen = outv_count;
+
+    if (fd_count > 0) {
+        size_t fds_size = fd_count * sizeof(int);
+
+        msg.msg_control = cmsg_buffer;
+        /* CMSG_FIRSTHDR only yields a header when msg_controllen can hold one. */
+        msg.msg_controllen = scm_cmsg_buffer_len(fd_count);
+
+        struct cmsghdr *hdr = CMSG_FIRSTHDR(&msg);
+        hdr->cmsg_level = SOL_SOCKET;
+        hdr->cmsg_type = SCM_RIGHTS;
+        hdr->cmsg_len = CMSG_LEN(fds_size);
+        memcpy(CMSG_DATA(hdr), fds, fds_size);
+
+        /* Advertise only the bytes the control message actually occupies. */
+        msg.msg_controllen = hdr->cmsg_len;
+    }
+
+    ssize_t sent = sendmsg(fd, &msg, MSG_NOSIGNAL);
+    return sent == -1 ? -errno : sent;
+}
+
+/*
+ * Convenience wrapper around `recvmsg` that builds up the `msghdr` structure and returns up to
+ * `*fd_count` file descriptors in the given `fds` array.
+ * Arguments:
+ *   * `fd` - Unix domain socket to `recvmsg` on.
+ *   * `outv` - Array of `outv_count` length `iovec`s that will contain the received data.
+ *   * `outv_count` - Number of elements in `outv` array.
+ *   * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(*fd_count)` bytes long.
+ *   * `fds` - Array of `*fd_count` elements that receives the file descriptors.
+ *   * `fd_count` - On input, the capacity of `fds`. On success it is set to the number of file
+ *     descriptors actually stored in `fds`; descriptors received beyond that capacity are closed.
+ * Returns:
+ *   A non-negative number indicating how many bytes were received on success or a negative errno
+ *   on failure. On failure `*fd_count` is left unchanged.
+ */
+ssize_t scm_recvmsg(int fd, struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer, int *fds,
+                    size_t *fd_count)
+{
+    if (fd < 0 || !cmsg_buffer || !fds || !fd_count)
+        return -EINVAL;
+
+    struct msghdr msg = {0};
+    msg.msg_iov = outv;
+    msg.msg_iovlen = outv_count;
+    msg.msg_control = cmsg_buffer;
+    msg.msg_controllen = scm_cmsg_buffer_len(*fd_count);
+
+    ssize_t total_read = recvmsg(fd, &msg, 0);
+    if (total_read == -1)
+        return -errno;
+
+    /* A message without control data simply skips the loop and reports zero fds. */
+    size_t fd_idx = 0;
+    struct cmsghdr *cmsg;
+    for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+        if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS)
+            continue;
+        /* Guard the subtraction below against a header that does not even cover itself. */
+        if (cmsg->cmsg_len < CMSG_LEN(0))
+            continue;
+
+        size_t cmsg_fd_count = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+        const unsigned char *cmsg_data = CMSG_DATA(cmsg);
+        size_t cmsg_fd_idx;
+        for (cmsg_fd_idx = 0; cmsg_fd_idx < cmsg_fd_count; cmsg_fd_idx++) {
+            /*
+             * cmsg(3): CMSG_DATA may not be suitably aligned for the payload type, so copy each
+             * descriptor out with memcpy rather than dereferencing an int pointer.
+             */
+            int received_fd;
+            memcpy(&received_fd, cmsg_data + cmsg_fd_idx * sizeof(int), sizeof(int));
+
+            if (fd_idx < *fd_count) {
+                fds[fd_idx] = received_fd;
+                fd_idx++;
+            } else {
+                /* No room left in `fds`; close it so the descriptor does not leak. */
+                close(received_fd);
+            }
+        }
+    }
+
+    *fd_count = fd_idx;
+
+    return total_read;
+}
diff --git a/sys_util/src/lib.rs b/sys_util/src/lib.rs
index 04e35b7..af9db2f 100644
--- a/sys_util/src/lib.rs
+++ b/sys_util/src/lib.rs
@@ -27,6 +27,7 @@
mod signal;
mod fork;
mod signalfd;
+mod sock_ctrl_msg;
pub use mmap::*;
pub use shm::*;
@@ -43,6 +44,7 @@
pub use fork::*;
pub use signalfd::*;
pub use ioctl::*;
+pub use sock_ctrl_msg::*;
pub use guest_memory::Error as GuestMemoryError;
pub use signalfd::Error as SignalFdError;
diff --git a/sys_util/src/sock_ctrl_msg.rs b/sys_util/src/sock_ctrl_msg.rs
new file mode 100644
index 0000000..ddbfef8
--- /dev/null
+++ b/sys_util/src/sock_ctrl_msg.rs
@@ -0,0 +1,297 @@
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::fs::File;
+use std::os::unix::io::{AsRawFd, RawFd, FromRawFd};
+use std::os::unix::net::{UnixDatagram, UnixStream};
+
+use libc::{iovec, c_void};
+
+use {Result, Error};
+
+// These functions are implemented in C because each of them requires complicated setup with CMSG
+// macros. These macros are part of the system headers and are required to be used for portability
+// reasons. In practice, the control message ABI can't change but using them is much easier and less
+// error prone than trying to port these macros to Rust.
+//
+// These declarations must stay in sync with the definitions in sock_ctrl_msg.c: `RawFd`, `usize`
+// and `isize` stand in for the C `int`, `size_t` and `ssize_t` parameters/returns there.
+extern "C" {
+ /// Returns the number of bytes a `cmsg_buffer` needs to carry `fd_count` descriptors.
+ fn scm_cmsg_buffer_len(fd_count: usize) -> usize;
+ /// Sends the data described by `outv` plus `fds` on socket `fd`. Returns the number of bytes
+ /// sent, or a negated errno on failure.
+ fn scm_sendmsg(fd: RawFd,
+ outv: *const iovec,
+ outv_count: usize,
+ cmsg_buffer: *mut u8,
+ fds: *const RawFd,
+ fd_count: usize)
+ -> isize;
+ /// Receives into `outv` and up to `*fd_count` descriptors into `fds`, then sets `*fd_count`
+ /// to the number stored. Returns the number of bytes read, or a negated errno on failure.
+ fn scm_recvmsg(fd: RawFd,
+ outv: *mut iovec,
+ outv_count: usize,
+ cmsg_buffer: *mut u8,
+ fds: *mut RawFd,
+ fd_count: *mut usize)
+ -> isize;
+}
+
+/// Returns how many bytes a control-message buffer must hold to carry `fd_count` file
+/// descriptors (the C helper evaluates `CMSG_SPACE` for that many `int`s).
+fn cmsg_buffer_len(fd_count: usize) -> usize {
+    // Safe because the helper only performs arithmetic on its argument: no pointers are passed,
+    // it has no side effects, and it cannot fail.
+    unsafe { scm_cmsg_buffer_len(fd_count) }
+}
+
+/// Implemented by sockets that can send and receive socket control messages (such as
+/// `SCM_RIGHTS` file descriptors) via `sendmsg` and `recvmsg`.
+pub trait ScmSocket {
+    /// Returns the raw file descriptor backing this socket.
+    fn socket_fd(&self) -> RawFd;
+}
+
+impl ScmSocket for UnixDatagram {
+    fn socket_fd(&self) -> RawFd {
+        AsRawFd::as_raw_fd(self)
+    }
+}
+
+impl ScmSocket for UnixStream {
+    fn socket_fd(&self) -> RawFd {
+        AsRawFd::as_raw_fd(self)
+    }
+}
+
+/// Sends and receives messages that carry file descriptors on sockets that accept control
+/// messages (e.g. Unix domain sockets). Its scratch buffers are kept between calls for reuse.
+pub struct Scm {
+    /// Scratch space for the control message handed to `sendmsg`/`recvmsg`.
+    cmsg_buffer: Vec<u8>,
+    /// Scratch list of `iovec`s describing the caller's data buffers.
+    vecs: Vec<iovec>,
+    /// Landing area for received descriptors; its length caps how many one `recv` accepts.
+    fds: Vec<RawFd>,
+}
+
+impl Scm {
+    /// Constructs a new Scm object with pre-allocated structures.
+    ///
+    /// # Arguments
+    ///
+    /// * `fd_count` - The maximum number of files that can be received per `recv` call.
+    pub fn new(fd_count: usize) -> Scm {
+        Scm {
+            cmsg_buffer: Vec::with_capacity(cmsg_buffer_len(fd_count)),
+            vecs: Vec::new(),
+            fds: vec![-1; fd_count],
+        }
+    }
+
+    /// Makes `cmsg_buffer` exactly `len` zeroed bytes long, reusing its allocation when possible.
+    ///
+    /// Zero-filling keeps the buffer fully initialized, so no `set_len` over uninitialized memory
+    /// is needed before handing it to C.
+    fn prepare_cmsg_buffer(&mut self, len: usize) {
+        self.cmsg_buffer.clear();
+        self.cmsg_buffer.resize(len, 0);
+    }
+
+    /// Sends the given data and file descriptors over the given `socket`.
+    ///
+    /// On success, returns the number of bytes sent.
+    ///
+    /// # Arguments
+    ///
+    /// * `socket` - A socket that supports socket control messages.
+    /// * `bufs` - A list of buffers to send on the `socket`. These will not be copied before
+    ///            `sendmsg` is called.
+    /// * `fds` - A list of file descriptors to be sent.
+    ///
+    /// # Errors
+    ///
+    /// Returns the errno reported by `sendmsg`, or `EINVAL` if the arguments are invalid.
+    pub fn send<T: ScmSocket>(&mut self,
+                              socket: &T,
+                              bufs: &[&[u8]],
+                              fds: &[RawFd])
+                              -> Result<usize> {
+        self.prepare_cmsg_buffer(cmsg_buffer_len(fds.len()));
+        self.vecs.clear();
+        self.vecs.extend(bufs.iter().map(|buf| {
+            iovec {
+                iov_base: buf.as_ptr() as *mut c_void,
+                iov_len: buf.len(),
+            }
+        }));
+
+        // Safe because every pointer given to scm_sendmsg is valid for the length passed with it:
+        // `vecs` describes the live `bufs` slices, `cmsg_buffer` holds exactly
+        // `cmsg_buffer_len(fds.len())` bytes, and `fds` is a live slice. The return value is
+        // checked below.
+        let write_count = unsafe {
+            scm_sendmsg(socket.socket_fd(),
+                        self.vecs.as_ptr(),
+                        self.vecs.len(),
+                        self.cmsg_buffer.as_mut_ptr(),
+                        fds.as_ptr(),
+                        fds.len())
+        };
+
+        if write_count < 0 {
+            // The C helper returns a negated errno; recover the positive value.
+            // NOTE(review): assumes `Error::new` takes a positive errno like `errno` itself.
+            Err(Error::new(-write_count as i32))
+        } else {
+            Ok(write_count as usize)
+        }
+    }
+
+    /// Receives data and file descriptors from the given `socket` into the list of buffers.
+    ///
+    /// On success, returns the number of bytes received.
+    ///
+    /// # Arguments
+    ///
+    /// * `socket` - A socket that supports socket control messages.
+    /// * `bufs` - A list of buffers to receive data from the `socket`. The `recvmsg` call fills
+    ///            these directly.
+    /// * `files` - A vector of `File`s to put the received file descriptors into. This vector is
+    ///             not cleared and will have at most `fd_count` (specified in `Scm::new`) `File`s
+    ///             added to it. Descriptors received beyond that limit are closed.
+    ///
+    /// # Errors
+    ///
+    /// Returns the errno reported by `recvmsg`, or `EINVAL` if the arguments are invalid.
+    pub fn recv<T: ScmSocket>(&mut self,
+                              socket: &T,
+                              bufs: &mut [&mut [u8]],
+                              files: &mut Vec<File>)
+                              -> Result<usize> {
+        // Size the control buffer for the capacity of `self.fds`: that is the descriptor count
+        // scm_recvmsg uses to compute the control length it passes to `recvmsg`. (`files` is only
+        // an output vector and its length is unrelated.)
+        let fd_capacity = self.fds.len();
+        self.prepare_cmsg_buffer(cmsg_buffer_len(fd_capacity));
+        self.vecs.clear();
+        self.vecs.extend(bufs.iter_mut().map(|buf| {
+            iovec {
+                iov_base: buf.as_mut_ptr() as *mut c_void,
+                iov_len: buf.len(),
+            }
+        }));
+
+        let mut fd_count = fd_capacity;
+        // Safe because every pointer given to scm_recvmsg is valid for the length passed with it:
+        // `vecs` describes the live `bufs` slices, `cmsg_buffer` holds
+        // `cmsg_buffer_len(fd_count)` bytes, and `fds` has room for `fd_count` descriptors. The
+        // return value is checked below.
+        let read_count = unsafe {
+            scm_recvmsg(socket.socket_fd(),
+                        self.vecs.as_mut_ptr(),
+                        self.vecs.len(),
+                        self.cmsg_buffer.as_mut_ptr(),
+                        self.fds.as_mut_ptr(),
+                        &mut fd_count as *mut usize)
+        };
+
+        if read_count < 0 {
+            // The C helper returns a negated errno; recover the positive value.
+            // NOTE(review): assumes `Error::new` takes a positive errno like `errno` itself.
+            return Err(Error::new(-read_count as i32));
+        }
+
+        // Safe because scm_recvmsg hands over ownership of each of the first `fd_count`
+        // descriptors it stored in `self.fds`, so each one is wrapped by exactly one `File`.
+        files.extend(self.fds[..fd_count]
+                         .iter()
+                         .map(|&fd| unsafe { File::from_raw_fd(fd) }));
+        Ok(read_count as usize)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use std::io::Write;
+    use std::mem::size_of;
+    use std::os::raw::c_long;
+    use std::os::unix::net::UnixDatagram;
+    use std::slice::from_raw_parts;
+
+    use libc::cmsghdr;
+
+    use EventFd;
+
+    /// Rounds `len` up the way `CMSG_ALIGN` does: to the native word size, which `c_long`
+    /// matches on Linux.
+    fn cmsg_align(len: usize) -> usize {
+        let align = size_of::<c_long>();
+        (len + align - 1) / align * align
+    }
+
+    #[test]
+    fn buffer_len() {
+        // Derive the expectation from the CMSG_SPACE definition (aligned header + aligned payload)
+        // instead of hard-coding 64-bit sizes, so the test also holds on 32-bit targets.
+        for fd_count in 0..5 {
+            let expected = cmsg_align(size_of::<cmsghdr>()) +
+                           cmsg_align(fd_count * size_of::<RawFd>());
+            assert_eq!(cmsg_buffer_len(fd_count), expected, "fd_count = {}", fd_count);
+        }
+    }
+
+    #[test]
+    fn send_recv_no_fd() {
+        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
+
+        let mut scm = Scm::new(1);
+        let write_count = scm.send(&s1,
+                                   [[1u8, 1, 2].as_ref(), [21, 34, 55].as_ref()].as_ref(),
+                                   &[])
+            .expect("failed to send data");
+
+        assert_eq!(write_count, 6);
+
+        let mut buf1 = [0; 3];
+        let mut buf2 = [0; 3];
+        let mut bufs = [buf1.as_mut(), buf2.as_mut()];
+        let mut files = Vec::new();
+        let read_count = scm.recv(&s2, &mut bufs[..], &mut files)
+            .expect("failed to recv data");
+
+        assert_eq!(read_count, 6);
+        assert!(files.is_empty());
+        assert_eq!(bufs[0], [1, 1, 2]);
+        assert_eq!(bufs[1], [21, 34, 55]);
+    }
+
+    #[test]
+    fn send_recv_only_fd() {
+        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
+
+        let mut scm = Scm::new(1);
+        let evt = EventFd::new().expect("failed to create eventfd");
+        let write_count = scm.send(&s1, &[&[]], &[evt.as_raw_fd()])
+            .expect("failed to send fd");
+
+        assert_eq!(write_count, 0);
+
+        let mut files = Vec::new();
+        let read_count = scm.recv(&s2, &mut [&mut []], &mut files)
+            .expect("failed to recv fd");
+
+        assert_eq!(read_count, 0);
+        assert_eq!(files.len(), 1);
+        assert!(files[0].as_raw_fd() >= 0);
+        assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd());
+        assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd());
+        assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd());
+
+        files[0]
+            .write_all(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
+            .expect("failed to write to sent fd");
+
+        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
+    }
+
+    #[test]
+    fn send_recv_with_fd() {
+        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
+
+        let mut scm = Scm::new(1);
+        let evt = EventFd::new().expect("failed to create eventfd");
+        let write_count = scm.send(&s1, &[&[237]], &[evt.as_raw_fd()])
+            .expect("failed to send fd");
+
+        assert_eq!(write_count, 1);
+
+        let mut files = Vec::new();
+        let mut buf = [0u8];
+        let read_count = scm.recv(&s2, &mut [&mut buf], &mut files)
+            .expect("failed to recv fd");
+
+        assert_eq!(read_count, 1);
+        assert_eq!(buf[0], 237);
+        assert_eq!(files.len(), 1);
+        assert!(files[0].as_raw_fd() >= 0);
+        assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd());
+        assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd());
+        assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd());
+
+        files[0]
+            .write_all(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
+            .expect("failed to write to sent fd");
+
+        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
+    }
+
+    #[test]
+    fn recv_keeps_only_fd_capacity() {
+        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
+
+        let evt1 = EventFd::new().expect("failed to create eventfd");
+        let evt2 = EventFd::new().expect("failed to create eventfd");
+        let mut sender = Scm::new(0);
+        sender.send(&s1, &[&[7]], &[evt1.as_raw_fd(), evt2.as_raw_fd()])
+            .expect("failed to send fds");
+
+        // Room for a single descriptor: the second one must be dropped, not leaked into `files`.
+        let mut receiver = Scm::new(1);
+        let mut buf = [0u8];
+        let mut files = Vec::new();
+        let read_count = receiver.recv(&s2, &mut [&mut buf], &mut files)
+            .expect("failed to recv fds");
+
+        assert_eq!(read_count, 1);
+        assert_eq!(buf[0], 7);
+        assert_eq!(files.len(), 1);
+
+        files[0]
+            .write_all(unsafe { from_raw_parts(&5u64 as *const u64 as *const u8, 8) })
+            .expect("failed to write to sent fd");
+
+        assert_eq!(evt1.read().expect("failed to read from eventfd"), 5);
+    }
+}