| // Copyright 2017 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| use std::fs::File; |
| use std::os::unix::io::{AsRawFd, RawFd, FromRawFd}; |
| use std::os::unix::net::{UnixDatagram, UnixStream}; |
| |
| use libc::{iovec, c_void}; |
| |
| use {Result, Error}; |
| |
| // These functions are implemented in C because each of them requires complicated setup with CMSG |
| // macros. These macros are part of the system headers and are required to be used for portability |
| // reasons. In practice, the control message ABI can't change but using them is much easier and less |
| // error prone than trying to port these macros to rust. |
| extern "C" { |
| fn scm_cmsg_buffer_len(fd_count: usize) -> usize; |
| fn scm_sendmsg(fd: RawFd, |
| outv: *const iovec, |
| outv_count: usize, |
| cmsg_buffer: *mut u8, |
| fds: *const RawFd, |
| fd_count: usize) |
| -> isize; |
| fn scm_recvmsg(fd: RawFd, |
| outv: *mut iovec, |
| outv_count: usize, |
| cmsg_buffer: *mut u8, |
| fds: *mut RawFd, |
| fd_count: *mut usize) |
| -> isize; |
| } |
| |
| fn cmsg_buffer_len(fd_count: usize) -> usize { |
| // Safe because this function has no side effects, touches no pointers, and never fails. |
| unsafe { scm_cmsg_buffer_len(fd_count) } |
| } |
| |
/// Trait for sockets that can send and receive socket control messages (for example, file
/// descriptors) via `sendmsg` and `recvmsg`.
pub trait ScmSocket {
    /// Returns the raw file descriptor of this socket.
    fn socket_fd(&self) -> RawFd;
}

/// Datagram-oriented Unix domain sockets support control messages.
impl ScmSocket for UnixDatagram {
    fn socket_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(self)
    }
}

/// Stream-oriented Unix domain sockets support control messages.
impl ScmSocket for UnixStream {
    fn socket_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(self)
    }
}
| |
| /// Used to send and receive messages with file descriptors on sockets that accept control messages |
| /// (e.g. Unix domain sockets). |
| pub struct Scm { |
| cmsg_buffer: Vec<u8>, |
| vecs: Vec<iovec>, |
| fds: Vec<RawFd>, |
| } |
| |
| impl Scm { |
| /// Constructs a new Scm object with pre-allocated structures. |
| /// |
| /// # Arguments |
| /// |
| /// * `fd_count` - The maximum number of files that can be received per `recv` call. |
| pub fn new(fd_count: usize) -> Scm { |
| Scm { |
| cmsg_buffer: Vec::with_capacity(cmsg_buffer_len(fd_count)), |
| vecs: Vec::new(), |
| fds: vec![-1; fd_count], |
| } |
| } |
| |
| /// Sends the given data and file descriptors over the given `socket`. |
| /// |
| /// On success, returns the number of bytes sent. |
| /// |
| /// # Arguments |
| /// |
| /// * `socket` - A socket that supports socket control messages. |
| /// * `bufs` - A list of buffers to send on the `socket`. These will not be copied before |
| /// `sendmsg` is called. |
| /// * `fds` - A list of file descriptors to be sent. |
| pub fn send<T: ScmSocket>(&mut self, |
| socket: &T, |
| bufs: &[&[u8]], |
| fds: &[RawFd]) |
| -> Result<usize> { |
| let cmsg_buf_len = cmsg_buffer_len(fds.len()); |
| self.cmsg_buffer.reserve(cmsg_buf_len); |
| self.vecs.clear(); |
| for buf in bufs { |
| self.vecs |
| .push(iovec { |
| iov_base: buf.as_ptr() as *mut c_void, |
| iov_len: buf.len(), |
| }); |
| } |
| let write_count = unsafe { |
| // Safe because we are giving scm_sendmsg only valid pointers and lengths and we check |
| // the return value. |
| self.cmsg_buffer.set_len(cmsg_buf_len); |
| scm_sendmsg(socket.socket_fd(), |
| self.vecs.as_ptr(), |
| self.vecs.len(), |
| self.cmsg_buffer.as_mut_ptr(), |
| fds.as_ptr(), |
| fds.len()) |
| }; |
| |
| if write_count < 0 { |
| Err(Error::new(write_count as i32)) |
| } else { |
| Ok(write_count as usize) |
| } |
| } |
| |
| /// Receives data and file descriptors from the given `socket` into the list of buffers. |
| /// |
| /// On success, returns the number of bytes received. |
| /// |
| /// # Arguments |
| /// |
| /// * `socket` - A socket that supports socket control messages. |
| /// * `bufs` - A list of buffers to receive data from the `socket`. The `recvmsg` call fills |
| /// these directly. |
| /// * `files` - A vector of `File`s to put the received file descriptors into. This vector is |
| /// not cleared and will have at most `fd_count` (specified in `Scm::new`) `File`s |
| /// added to it. |
| pub fn recv<T: ScmSocket>(&mut self, |
| socket: &T, |
| bufs: &mut [&mut [u8]], |
| files: &mut Vec<File>) |
| -> Result<usize> { |
| let cmsg_buf_len = cmsg_buffer_len(files.len()); |
| self.cmsg_buffer.reserve(cmsg_buf_len); |
| self.vecs.clear(); |
| for buf in bufs { |
| self.vecs |
| .push(iovec { |
| iov_base: buf.as_mut_ptr() as *mut c_void, |
| iov_len: buf.len(), |
| }); |
| } |
| let mut fd_count = self.fds.len(); |
| let read_count = unsafe { |
| // Safe because we are giving scm_recvmsg only valid pointers and lengths and we check |
| // the return value. |
| self.cmsg_buffer.set_len(cmsg_buf_len); |
| scm_recvmsg(socket.socket_fd(), |
| self.vecs.as_mut_ptr(), |
| self.vecs.len(), |
| self.cmsg_buffer.as_mut_ptr(), |
| self.fds.as_mut_ptr(), |
| &mut fd_count as *mut usize) |
| }; |
| |
| if read_count < 0 { |
| Err(Error::new(read_count as i32)) |
| } else { |
| // Safe because we have unqiue ownership of each fd we wrap with File. |
| for &fd in &self.fds[0..fd_count] { |
| files.push(unsafe { File::from_raw_fd(fd) }); |
| } |
| Ok(read_count as usize) |
| } |
| } |
| } |
| |
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::Write;
    use std::mem::size_of;
    use std::os::raw::c_long;
    use std::os::unix::net::UnixDatagram;
    use std::slice::from_raw_parts;

    use libc::cmsghdr;

    use EventFd;

    /// Writes `value` to `file` as 8 native-endian bytes (the format an eventfd expects), failing
    /// the test unless every byte is written.
    fn write_u64(file: &mut File, value: u64) {
        // Safe because `value` outlives the slice and is exactly `size_of::<u64>()` bytes long.
        let bytes = unsafe { from_raw_parts(&value as *const u64 as *const u8, size_of::<u64>()) };
        file.write_all(bytes).expect("failed to write to sent fd");
    }

    #[test]
    fn buffer_len() {
        assert_eq!(cmsg_buffer_len(0), size_of::<cmsghdr>());
        assert_eq!(cmsg_buffer_len(1),
                   size_of::<cmsghdr>() + size_of::<c_long>());
        if size_of::<RawFd>() == 4 {
            assert_eq!(cmsg_buffer_len(2),
                       size_of::<cmsghdr>() + size_of::<c_long>());
            assert_eq!(cmsg_buffer_len(3),
                       size_of::<cmsghdr>() + size_of::<c_long>() * 2);
            assert_eq!(cmsg_buffer_len(4),
                       size_of::<cmsghdr>() + size_of::<c_long>() * 2);
        } else if size_of::<RawFd>() == 8 {
            assert_eq!(cmsg_buffer_len(2),
                       size_of::<cmsghdr>() + size_of::<c_long>() * 2);
            assert_eq!(cmsg_buffer_len(3),
                       size_of::<cmsghdr>() + size_of::<c_long>() * 3);
            assert_eq!(cmsg_buffer_len(4),
                       size_of::<cmsghdr>() + size_of::<c_long>() * 4);
        }
    }

    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let mut scm = Scm::new(1);
        let write_count = scm.send(&s1,
                                   [[1u8, 1, 2].as_ref(), [21, 34, 55].as_ref()].as_ref(),
                                   &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf1 = [0; 3];
        let mut buf2 = [0; 3];
        let mut bufs = [buf1.as_mut(), buf2.as_mut()];
        let mut files = Vec::new();
        let read_count = scm.recv(&s2, &mut bufs[..], &mut files)
            .expect("failed to recv data");

        assert_eq!(read_count, 6);
        assert!(files.is_empty());
        assert_eq!(bufs[0], [1, 1, 2]);
        assert_eq!(bufs[1], [21, 34, 55]);
    }

    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let mut scm = Scm::new(1);
        let evt = EventFd::new().expect("failed to create eventfd");
        let write_count = scm.send(&s1, &[&[]], &[evt.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let mut files = Vec::new();
        let read_count = scm.recv(&s2, &mut [&mut []], &mut files)
            .expect("failed to recv fd");

        assert_eq!(read_count, 0);
        assert_eq!(files.len(), 1);
        assert!(files[0].as_raw_fd() >= 0);
        assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd());
        assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd());
        assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd());

        write_u64(&mut files[0], 1203);

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }

    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let mut scm = Scm::new(1);
        let evt = EventFd::new().expect("failed to create eventfd");
        let write_count = scm.send(&s1, &[&[237]], &[evt.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = Vec::new();
        let mut buf = [0u8];
        let read_count = scm.recv(&s2, &mut [&mut buf], &mut files)
            .expect("failed to recv fd");

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(files.len(), 1);
        assert!(files[0].as_raw_fd() >= 0);
        assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd());
        assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd());
        assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd());

        write_u64(&mut files[0], 1203);

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }

    #[test]
    fn send_recv_multiple_fds() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        // Capacity > 1 with an empty output vector exercises control-buffer sizing in `recv`.
        let mut scm = Scm::new(2);
        let evt1 = EventFd::new().expect("failed to create eventfd");
        let evt2 = EventFd::new().expect("failed to create eventfd");
        let write_count = scm.send(&s1, &[&[7]], &[evt1.as_raw_fd(), evt2.as_raw_fd()])
            .expect("failed to send fds");

        assert_eq!(write_count, 1);

        let mut files = Vec::new();
        let mut buf = [0u8];
        let read_count = scm.recv(&s2, &mut [&mut buf], &mut files)
            .expect("failed to recv fds");

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 7);
        assert_eq!(files.len(), 2);

        write_u64(&mut files[0], 11);
        write_u64(&mut files[1], 22);

        assert_eq!(evt1.read().expect("failed to read from eventfd"), 11);
        assert_eq!(evt2.read().expect("failed to read from eventfd"), 22);
    }
}