sys_util: add sock_ctrl_msg module for transmitting messages with FD

This CL also adds the `gcc` crate as a build-time dependency for
compiling the sock_ctrl_msg.c helper code.

TEST=cargo test
BUG=chromium:738638

Change-Id: I4adc2360b7fab4ed7d557603aa7bad2e738b69b4
Reviewed-on: https://chromium-review.googlesource.com/562574
Commit-Ready: Zach Reizner <zachr@chromium.org>
Tested-by: Zach Reizner <zachr@chromium.org>
Reviewed-by: Chirantan Ekbote <chirantan@chromium.org>
Reviewed-by: Dylan Reid <dgreid@chromium.org>
diff --git a/sys_util/Cargo.toml b/sys_util/Cargo.toml
index df54b13..bc9e1f9 100644
--- a/sys_util/Cargo.toml
+++ b/sys_util/Cargo.toml
@@ -2,8 +2,12 @@
 name = "sys_util"
 version = "0.1.0"
 authors = ["The Chromium OS Authors"]
+build = "build.rs"
 
 [dependencies]
 data_model = { path = "../data_model" }
 libc = "*"
-syscall_defines = { path = "../syscall_defines" }
\ No newline at end of file
+syscall_defines = { path = "../syscall_defines" }
+
+[build-dependencies]
+gcc = "0.3"
diff --git a/sys_util/build.rs b/sys_util/build.rs
new file mode 100644
index 0000000..ad96135
--- /dev/null
+++ b/sys_util/build.rs
@@ -0,0 +1,12 @@
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+extern crate gcc;
+
+// Compiles the C helpers declared in `src/sock_ctrl_msg.rs` into `libsock_ctrl_msg.a`; the `gcc`
+// crate also emits the cargo directives that link that static library into this crate.
+fn main() {
+    let sources = ["sock_ctrl_msg.c"];
+    gcc::compile_library("libsock_ctrl_msg.a", &sources);
+}
diff --git a/sys_util/sock_ctrl_msg.c b/sys_util/sock_ctrl_msg.c
new file mode 100644
index 0000000..15b7be3
--- /dev/null
+++ b/sys_util/sock_ctrl_msg.c
@@ -0,0 +1,130 @@
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <errno.h>
+#include <stdint.h>
+#include <string.h> // memcpy
+#include <sys/socket.h> // CMSG_*, sendmsg, recvmsg
+#include <unistd.h> // close
+
+/*
+ * Returns the number of bytes the `cmsg_buffer` must be for the functions that take a cmsg_buffer
+ * in this module.
+ * Arguments:
+ *    * `fd_count` - Maximum number of file descriptors to be sent or received via the cmsg.
+ */
+size_t scm_cmsg_buffer_len(size_t fd_count)
+{
+    return CMSG_SPACE(sizeof(int) * fd_count);
+}
+
+/*
+ * Convenience wrapper around `sendmsg` that builds up the `msghdr` structure for you given the
+ * array of fds.
+ * Arguments:
+ *   * `fd` - Unix domain socket to `sendmsg` on.
+ *   * `outv` - Array of `outv_count` length `iovec`s that contain the data to send.
+ *   * `outv_count` - Number of elements in `outv` array.
+ *   * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(fd_count)` bytes long.
+ *   * `fds` - Array of `fd_count` file descriptors to send along with data.
+ *   * `fd_count` - Number of elements in `fds` array.
+ * Returns:
+ * A non-negative number indicating how many bytes were sent on success or a negative errno on
+ * failure (`-EINVAL` for a negative `fd`, or for a NULL `cmsg_buffer`/`fds` with a non-zero
+ * `fd_count`).
+ */
+ssize_t scm_sendmsg(int fd, const struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer,
+                    const int *fds, size_t fd_count)
+{
+    if (fd < 0 || ((!cmsg_buffer || !fds) && fd_count > 0))
+        return -EINVAL;
+
+    struct msghdr msg = {0};
+    msg.msg_iov = (struct iovec *)outv; // discard const, sendmsg won't mutate it
+    msg.msg_iovlen = outv_count;
+
+    if (fd_count) {
+        msg.msg_control = cmsg_buffer;
+        msg.msg_controllen = scm_cmsg_buffer_len(fd_count);
+
+        struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+        cmsg->cmsg_level = SOL_SOCKET;
+        cmsg->cmsg_type = SCM_RIGHTS;
+        cmsg->cmsg_len = CMSG_LEN(fd_count * sizeof(int));
+        memcpy(CMSG_DATA(cmsg), fds, fd_count * sizeof(int));
+
+        // Only the used part of the buffer (without trailing padding) is passed to the kernel.
+        msg.msg_controllen = cmsg->cmsg_len;
+    }
+
+    // MSG_NOSIGNAL: don't raise SIGPIPE if a stream peer has closed; EPIPE is still returned.
+    ssize_t bytes_sent = sendmsg(fd, &msg, MSG_NOSIGNAL);
+    if (bytes_sent == -1)
+        return -errno;
+
+    return bytes_sent;
+}
+
+/*
+ * Convenience wrapper around `recvmsg` that builds up the `msghdr` structure and returns up to
+ * `*fd_count` file descriptors in the given `fds` array.
+ * Arguments:
+ *   * `fd` - Unix domain socket to `recvmsg` on.
+ *   * `outv` - Array of `outv_count` length `iovec`s that will contain the received data.
+ *   * `outv_count` - Number of elements in `outv` array.
+ *   * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(*fd_count)` bytes long.
+ *   * `fds` - Array of at least `*fd_count` elements that receives the file descriptors.
+ *   * `fd_count` - On input, the capacity of `fds`; on successful return, the number of file
+ *     descriptors stored at the front of `fds`. Descriptors received beyond that capacity are
+ *     closed.
+ * Returns:
+ * A non-negative number indicating how many bytes were received on success or a negative errno on
+ * failure. On failure `*fd_count` and `fds` are left unchanged.
+ */
+ssize_t scm_recvmsg(int fd, struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer, int *fds,
+                    size_t *fd_count)
+{
+    if (fd < 0 || !cmsg_buffer || !fds || !fd_count)
+        return -EINVAL;
+
+    struct msghdr msg = {0};
+    msg.msg_iov = outv;
+    msg.msg_iovlen = outv_count;
+    msg.msg_control = cmsg_buffer;
+    msg.msg_controllen = scm_cmsg_buffer_len(*fd_count);
+
+    ssize_t total_read = recvmsg(fd, &msg, 0);
+    if (total_read == -1)
+        return -errno;
+
+    if (total_read == 0 && CMSG_FIRSTHDR(&msg) == NULL) {
+        *fd_count = 0;
+        return 0;
+    }
+
+    size_t fd_idx = 0;
+    struct cmsghdr *cmsg;
+    for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+        if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS)
+            continue;
+
+        size_t cmsg_fd_count = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+
+        int *cmsg_fds = (int *)CMSG_DATA(cmsg);
+        size_t cmsg_fd_idx;
+        for (cmsg_fd_idx = 0; cmsg_fd_idx < cmsg_fd_count; cmsg_fd_idx++) {
+            if (fd_idx < *fd_count) {
+                fds[fd_idx] = cmsg_fds[cmsg_fd_idx];
+                fd_idx++;
+            } else {
+                // No room left in `fds`; close the extra descriptor so it doesn't leak.
+                close(cmsg_fds[cmsg_fd_idx]);
+            }
+        }
+    }
+
+    *fd_count = fd_idx;
+
+    return total_read;
+}
diff --git a/sys_util/src/lib.rs b/sys_util/src/lib.rs
index 04e35b7..af9db2f 100644
--- a/sys_util/src/lib.rs
+++ b/sys_util/src/lib.rs
@@ -27,6 +27,7 @@
 mod signal;
 mod fork;
 mod signalfd;
+mod sock_ctrl_msg;
 
 pub use mmap::*;
 pub use shm::*;
@@ -43,6 +44,7 @@
 pub use fork::*;
 pub use signalfd::*;
 pub use ioctl::*;
+pub use sock_ctrl_msg::*;
 
 pub use guest_memory::Error as GuestMemoryError;
 pub use signalfd::Error as SignalFdError;
diff --git a/sys_util/src/sock_ctrl_msg.rs b/sys_util/src/sock_ctrl_msg.rs
new file mode 100644
index 0000000..ddbfef8
--- /dev/null
+++ b/sys_util/src/sock_ctrl_msg.rs
@@ -0,0 +1,346 @@
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::fs::File;
+use std::os::unix::io::{AsRawFd, RawFd, FromRawFd};
+use std::os::unix::net::{UnixDatagram, UnixStream};
+
+use libc::{iovec, c_void};
+
+use {Result, Error};
+
+// These functions are implemented in C because each of them requires complicated setup with CMSG
+// macros. These macros are part of the system headers and are required to be used for portability
+// reasons. In practice, the control message ABI can't change but using them is much easier and less
+// error prone than trying to port these macros to rust.
+//
+// `scm_sendmsg` and `scm_recvmsg` return a non-negative byte count on success or a negative errno
+// on failure.
+extern "C" {
+    fn scm_cmsg_buffer_len(fd_count: usize) -> usize;
+    fn scm_sendmsg(fd: RawFd,
+                   outv: *const iovec,
+                   outv_count: usize,
+                   cmsg_buffer: *mut u8,
+                   fds: *const RawFd,
+                   fd_count: usize)
+                   -> isize;
+    fn scm_recvmsg(fd: RawFd,
+                   outv: *mut iovec,
+                   outv_count: usize,
+                   cmsg_buffer: *mut u8,
+                   fds: *mut RawFd,
+                   fd_count: *mut usize)
+                   -> isize;
+}
+
+/// Returns the number of bytes a control-message buffer needs in order to carry `fd_count` file
+/// descriptors (`CMSG_SPACE` of `fd_count` C ints, as computed by `sock_ctrl_msg.c`).
+fn cmsg_buffer_len(fd_count: usize) -> usize {
+    // Safe because this function has no side effects, touches no pointers, and never fails.
+    unsafe { scm_cmsg_buffer_len(fd_count) }
+}
+
+/// Converts the return value of `scm_sendmsg`/`scm_recvmsg` into a `Result`.
+///
+/// The C helpers report failure as a negative errno, so the sign is flipped before it is stored
+/// in `Error`. NOTE(review): this assumes `Error::new` expects the positive errno value (as
+/// `errno(3)` reports it); confirm against the `Error` definition.
+fn scm_result(ret: isize) -> Result<usize> {
+    if ret < 0 {
+        Err(Error::new((-ret) as i32))
+    } else {
+        Ok(ret as usize)
+    }
+}
+
+/// Trait for file descriptors that can send and receive socket control messages via `sendmsg`
+/// and `recvmsg`.
+pub trait ScmSocket {
+    /// Gets the file descriptor of this socket.
+    fn socket_fd(&self) -> RawFd;
+}
+
+impl ScmSocket for UnixDatagram {
+    fn socket_fd(&self) -> RawFd {
+        self.as_raw_fd()
+    }
+}
+
+impl ScmSocket for UnixStream {
+    fn socket_fd(&self) -> RawFd {
+        self.as_raw_fd()
+    }
+}
+
+/// Used to send and receive messages with file descriptors on sockets that accept control messages
+/// (e.g. Unix domain sockets).
+pub struct Scm {
+    // Scratch space for the control message handed to the C helpers; resized on every call.
+    cmsg_buffer: Vec<u8>,
+    // `iovec`s describing the caller's data buffers for the current call.
+    vecs: Vec<iovec>,
+    // Slots the received descriptors are written into; its length is the per-`recv` fd limit.
+    fds: Vec<RawFd>,
+}
+
+impl Scm {
+    /// Constructs a new Scm object with pre-allocated structures.
+    ///
+    /// # Arguments
+    ///
+    /// * `fd_count` - The maximum number of files that can be received per `recv` call.
+    pub fn new(fd_count: usize) -> Scm {
+        Scm {
+            cmsg_buffer: Vec::with_capacity(cmsg_buffer_len(fd_count)),
+            vecs: Vec::new(),
+            fds: vec![-1; fd_count],
+        }
+    }
+
+    /// Sends the given data and file descriptors over the given `socket`.
+    ///
+    /// On success, returns the number of bytes sent.
+    ///
+    /// # Arguments
+    ///
+    /// * `socket` - A socket that supports socket control messages.
+    /// * `bufs` - A list of buffers to send on the `socket`. These will not be copied before
+    ///            `sendmsg` is called.
+    /// * `fds` - A list of file descriptors to be sent.
+    ///
+    /// # Errors
+    ///
+    /// Returns an `Error` carrying the errno reported by `scm_sendmsg` (e.g. from `sendmsg`).
+    pub fn send<T: ScmSocket>(&mut self,
+                              socket: &T,
+                              bufs: &[&[u8]],
+                              fds: &[RawFd])
+                              -> Result<usize> {
+        // `resize` (rather than `reserve` + `set_len`) keeps every byte handed to C initialized.
+        self.cmsg_buffer.resize(cmsg_buffer_len(fds.len()), 0);
+        self.vecs.clear();
+        for buf in bufs {
+            self.vecs.push(iovec {
+                // sendmsg only reads through `iov_base`, so casting away const is sound.
+                iov_base: buf.as_ptr() as *mut c_void,
+                iov_len: buf.len(),
+            });
+        }
+        let write_count = unsafe {
+            // Safe because `vecs` describes the live `bufs` slices, `cmsg_buffer` holds
+            // `cmsg_buffer_len(fds.len())` bytes, `fds` is a valid slice of `fds.len()` elements,
+            // and the return value is checked below.
+            scm_sendmsg(socket.socket_fd(),
+                        self.vecs.as_ptr(),
+                        self.vecs.len(),
+                        self.cmsg_buffer.as_mut_ptr(),
+                        fds.as_ptr(),
+                        fds.len())
+        };
+        scm_result(write_count)
+    }
+
+    /// Receives data and file descriptors from the given `socket` into the list of buffers.
+    ///
+    /// On success, returns the number of bytes received.
+    ///
+    /// # Arguments
+    ///
+    /// * `socket` - A socket that supports socket control messages.
+    /// * `bufs` - A list of buffers to receive data from the `socket`. The `recvmsg` call fills
+    ///            these directly.
+    /// * `files` - A vector of `File`s to put the received file descriptors into. This vector is
+    ///             not cleared and will have at most `fd_count` (specified in `Scm::new`) `File`s
+    ///             added to it. Descriptors received beyond that limit are closed.
+    ///
+    /// # Errors
+    ///
+    /// Returns an `Error` carrying the errno reported by `scm_recvmsg` (e.g. from `recvmsg`). No
+    /// `File`s are added to `files` on error.
+    pub fn recv<T: ScmSocket>(&mut self,
+                              socket: &T,
+                              bufs: &mut [&mut [u8]],
+                              files: &mut Vec<File>)
+                              -> Result<usize> {
+        // The control buffer must hold as many fds as `self.fds` can, because that is the
+        // capacity `scm_recvmsg` advertises to the kernel. It must not depend on how many
+        // `File`s the caller already has in `files`.
+        self.cmsg_buffer.resize(cmsg_buffer_len(self.fds.len()), 0);
+        self.vecs.clear();
+        for buf in bufs {
+            self.vecs.push(iovec {
+                iov_base: buf.as_mut_ptr() as *mut c_void,
+                iov_len: buf.len(),
+            });
+        }
+        // In: the capacity of `self.fds`. Out: how many fds were stored at its front.
+        let mut fd_count = self.fds.len();
+        let ret = unsafe {
+            // Safe because `vecs` describes the live `bufs` slices, `cmsg_buffer` holds
+            // `cmsg_buffer_len(self.fds.len())` bytes, `fds` has room for `fd_count` entries, and
+            // the return value is checked below.
+            scm_recvmsg(socket.socket_fd(),
+                        self.vecs.as_mut_ptr(),
+                        self.vecs.len(),
+                        self.cmsg_buffer.as_mut_ptr(),
+                        self.fds.as_mut_ptr(),
+                        &mut fd_count as *mut usize)
+        };
+        let read_count = scm_result(ret)?;
+        // Safe because each fd in `self.fds[..fd_count]` was just installed by `recvmsg` and is
+        // referenced nowhere else, so we have unique ownership of each fd we wrap with File.
+        files.extend(self.fds[..fd_count]
+                         .iter()
+                         .map(|&fd| unsafe { File::from_raw_fd(fd) }));
+        Ok(read_count)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use std::io::Write;
+    use std::mem::size_of;
+    use std::os::raw::c_long;
+    use std::os::unix::net::UnixDatagram;
+    use std::slice::from_raw_parts;
+
+    use libc::cmsghdr;
+
+    use EventFd;
+
+    #[test]
+    fn buffer_len() {
+        assert_eq!(cmsg_buffer_len(0), size_of::<cmsghdr>());
+        assert_eq!(cmsg_buffer_len(1),
+                   size_of::<cmsghdr>() + size_of::<c_long>());
+        // CMSG_SPACE rounds the fd payload up to the CMSG alignment unit, which these assertions
+        // take to be the size of a C long. Whether two 4-byte fds share one unit therefore
+        // depends on the size of c_long, not on the size of RawFd (always a C int).
+        if size_of::<c_long>() == 8 {
+            assert_eq!(cmsg_buffer_len(2),
+                       size_of::<cmsghdr>() + size_of::<c_long>());
+            assert_eq!(cmsg_buffer_len(3),
+                       size_of::<cmsghdr>() + size_of::<c_long>() * 2);
+            assert_eq!(cmsg_buffer_len(4),
+                       size_of::<cmsghdr>() + size_of::<c_long>() * 2);
+        } else if size_of::<c_long>() == 4 {
+            assert_eq!(cmsg_buffer_len(2),
+                       size_of::<cmsghdr>() + size_of::<c_long>() * 2);
+            assert_eq!(cmsg_buffer_len(3),
+                       size_of::<cmsghdr>() + size_of::<c_long>() * 3);
+            assert_eq!(cmsg_buffer_len(4),
+                       size_of::<cmsghdr>() + size_of::<c_long>() * 4);
+        }
+    }
+
+    #[test]
+    fn send_recv_no_fd() {
+        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
+
+        let mut scm = Scm::new(1);
+        let write_count = scm.send(&s1,
+                                   [[1u8, 1, 2].as_ref(), [21, 34, 55].as_ref()].as_ref(),
+                                   &[])
+            .expect("failed to send data");
+
+        assert_eq!(write_count, 6);
+
+        let mut buf1 = [0; 3];
+        let mut buf2 = [0; 3];
+        let mut bufs = [buf1.as_mut(), buf2.as_mut()];
+        let mut files = Vec::new();
+        let read_count = scm.recv(&s2, &mut bufs[..], &mut files)
+            .expect("failed to recv data");
+
+        assert_eq!(read_count, 6);
+        assert!(files.is_empty());
+        assert_eq!(bufs[0], [1, 1, 2]);
+        assert_eq!(bufs[1], [21, 34, 55]);
+    }
+
+    #[test]
+    fn send_recv_only_fd() {
+        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
+
+        let mut scm = Scm::new(1);
+        let evt = EventFd::new().expect("failed to create eventfd");
+        let write_count = scm.send(&s1, &[&[]], &[evt.as_raw_fd()])
+            .expect("failed to send fd");
+
+        assert_eq!(write_count, 0);
+
+        let mut files = Vec::new();
+        let read_count = scm.recv(&s2, &mut [&mut []], &mut files)
+            .expect("failed to recv fd");
+
+        assert_eq!(read_count, 0);
+        assert_eq!(files.len(), 1);
+        assert!(files[0].as_raw_fd() >= 0);
+        assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd());
+        assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd());
+        assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd());
+
+        files[0]
+            .write_all(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
+            .expect("failed to write to sent fd");
+
+        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
+    }
+
+    #[test]
+    fn send_recv_with_fd() {
+        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
+
+        let mut scm = Scm::new(1);
+        let evt = EventFd::new().expect("failed to create eventfd");
+        let write_count = scm.send(&s1, &[&[237]], &[evt.as_raw_fd()])
+            .expect("failed to send fd");
+
+        assert_eq!(write_count, 1);
+
+        let mut files = Vec::new();
+        let mut buf = [0u8];
+        let read_count = scm.recv(&s2, &mut [&mut buf], &mut files)
+            .expect("failed to recv fd");
+
+        assert_eq!(read_count, 1);
+        assert_eq!(buf[0], 237);
+        assert_eq!(files.len(), 1);
+        assert!(files[0].as_raw_fd() >= 0);
+        assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd());
+        assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd());
+        assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd());
+
+        files[0]
+            .write_all(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
+            .expect("failed to write to sent fd");
+
+        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
+    }
+
+    #[test]
+    fn recv_keeps_at_most_fd_count_fds() {
+        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
+
+        let mut scm = Scm::new(1);
+        let evt1 = EventFd::new().expect("failed to create eventfd");
+        let evt2 = EventFd::new().expect("failed to create eventfd");
+        scm.send(&s1, &[&[]], &[evt1.as_raw_fd(), evt2.as_raw_fd()])
+            .expect("failed to send fds");
+
+        let mut files = Vec::new();
+        scm.recv(&s2, &mut [&mut []], &mut files)
+            .expect("failed to recv fds");
+
+        // Only the first fd fits within the limit given to `Scm::new`; any others are closed.
+        assert_eq!(files.len(), 1);
+        files[0]
+            .write_all(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
+            .expect("failed to write to sent fd");
+        assert_eq!(evt1.read().expect("failed to read from eventfd"), 1203);
+    }
+}