blob: 8c08217e16f82ebf93e8e56dfec8e682fd265a42 [file] [log] [blame] [edit]
//! Wayland socket manipulation
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use nix::{
sys::{socket, uio},
Result as NixResult,
};
use crate::wire::{ArgumentType, Message, MessageParseError, MessageWriteError};
/// Maximum number of FD that can be sent in a single socket message
pub const MAX_FDS_OUT: usize = 28;
/// Maximum number of bytes that can be sent in a single socket message
pub const MAX_BYTES_OUT: usize = 4096;
/*
* Socket
*/
/// A wayland socket
#[derive(Debug)]
pub struct Socket {
fd: RawFd,
}
impl Socket {
/// Send a single message to the socket
///
/// A single socket message can contain several wayland messages
///
/// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
/// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
/// end may lose some data.
pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> NixResult<()> {
let iov = [uio::IoVec::from_slice(bytes)];
if !fds.is_empty() {
let cmsgs = [socket::ControlMessage::ScmRights(fds)];
socket::sendmsg(self.fd, &iov, &cmsgs, socket::MsgFlags::MSG_DONTWAIT, None)?;
} else {
socket::sendmsg(self.fd, &iov, &[], socket::MsgFlags::MSG_DONTWAIT, None)?;
};
Ok(())
}
/// Receive a single message from the socket
///
/// Return the number of bytes received and the number of Fds received.
///
/// Errors with `WouldBlock` is no message is available.
///
/// A single socket message can contain several wayland messages.
///
/// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
/// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
/// be lost.
pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> NixResult<(usize, usize)> {
let mut cmsg = cmsg_space!([RawFd; MAX_FDS_OUT]);
let iov = [uio::IoVec::from_mut_slice(buffer)];
let msg = socket::recvmsg(
self.fd,
&iov[..],
Some(&mut cmsg),
socket::MsgFlags::MSG_DONTWAIT | socket::MsgFlags::MSG_CMSG_CLOEXEC,
)?;
let mut fd_count = 0;
let received_fds = msg.cmsgs().flat_map(|cmsg| match cmsg {
socket::ControlMessageOwned::ScmRights(s) => s,
_ => Vec::new(),
});
for (fd, place) in received_fds.zip(fds.iter_mut()) {
fd_count += 1;
*place = fd;
}
Ok((msg.bytes, fd_count))
}
/// Retrieve the current value of the requested [socket::GetSockOpt]
pub fn opt<O: socket::GetSockOpt>(&self, opt: O) -> NixResult<O::Val> {
socket::getsockopt(self.fd, opt)
}
}
impl FromRawFd for Socket {
unsafe fn from_raw_fd(fd: RawFd) -> Socket {
Socket { fd }
}
}
impl AsRawFd for Socket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl IntoRawFd for Socket {
fn into_raw_fd(self) -> RawFd {
self.fd
}
}
impl Drop for Socket {
fn drop(&mut self) {
let _ = ::nix::unistd::close(self.fd);
}
}
/*
* BufferedSocket
*/
/// An adapter around a raw Socket that directly handles buffering and
/// conversion from/to wayland messages
#[derive(Debug)]
pub struct BufferedSocket {
socket: Socket,
in_data: Buffer<u32>,
in_fds: Buffer<RawFd>,
out_data: Buffer<u32>,
out_fds: Buffer<RawFd>,
}
impl BufferedSocket {
/// Wrap a Socket into a Buffered Socket
pub fn new(socket: Socket) -> BufferedSocket {
BufferedSocket {
socket,
in_data: Buffer::new(2 * MAX_BYTES_OUT / 4), // Incoming buffers are twice as big in order to be
in_fds: Buffer::new(2 * MAX_FDS_OUT), // able to store leftover data if needed
out_data: Buffer::new(MAX_BYTES_OUT / 4),
out_fds: Buffer::new(MAX_FDS_OUT),
}
}
/// Get direct access to the underlying socket
pub fn get_socket(&mut self) -> &mut Socket {
&mut self.socket
}
/// Retrieve ownership of the underlying Socket
///
/// Any leftover content in the internal buffers will be lost
pub fn into_socket(self) -> Socket {
self.socket
}
/// Flush the contents of the outgoing buffer into the socket
pub fn flush(&mut self) -> NixResult<()> {
{
let words = self.out_data.get_contents();
if words.is_empty() {
return Ok(());
}
let bytes = unsafe {
::std::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 4)
};
let fds = self.out_fds.get_contents();
self.socket.send_msg(bytes, fds)?;
for &fd in fds {
// once the fds are sent, we can close them
let _ = ::nix::unistd::close(fd);
}
}
self.out_data.clear();
self.out_fds.clear();
Ok(())
}
// internal method
//
// attempts to write a message in the internal out buffers,
// returns true if successful
//
// if false is returned, it means there is not enough space
// in the buffer
fn attempt_write_message(&mut self, msg: &Message) -> NixResult<bool> {
match msg.write_to_buffers(
self.out_data.get_writable_storage(),
self.out_fds.get_writable_storage(),
) {
Ok((bytes_out, fds_out)) => {
self.out_data.advance(bytes_out);
self.out_fds.advance(fds_out);
Ok(true)
}
Err(MessageWriteError::BufferTooSmall) => Ok(false),
Err(MessageWriteError::DupFdFailed(e)) => Err(e),
}
}
/// Write a message to the outgoing buffer
///
/// This method may flush the internal buffer if necessary (if it is full).
///
/// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)`
/// will be returned.
pub fn write_message(&mut self, msg: &Message) -> NixResult<()> {
if !self.attempt_write_message(msg)? {
// the attempt failed, there is not enough space in the buffer
// we need to flush it
self.flush()?;
if !self.attempt_write_message(msg)? {
// If this fails again, this means the message is too big
// to be transmitted at all
return Err(::nix::Error::E2BIG);
}
}
Ok(())
}
/// Try to fill the incoming buffers of this socket, to prepare
/// a new round of parsing.
pub fn fill_incoming_buffers(&mut self) -> NixResult<()> {
// clear the buffers if they have no content
if !self.in_data.has_content() {
self.in_data.clear();
}
if !self.in_fds.has_content() {
self.in_fds.clear();
}
// receive a message
let (in_bytes, in_fds) = {
let words = self.in_data.get_writable_storage();
let bytes = unsafe {
::std::slice::from_raw_parts_mut(words.as_ptr() as *mut u8, words.len() * 4)
};
let fds = self.in_fds.get_writable_storage();
self.socket.rcv_msg(bytes, fds)?
};
if in_bytes == 0 {
// the other end of the socket was closed
return Err(::nix::Error::EPIPE);
}
// advance the storage
self.in_data.advance(in_bytes / 4 + if in_bytes % 4 > 0 { 1 } else { 0 });
self.in_fds.advance(in_fds);
Ok(())
}
/// Read and deserialize a single message from the incoming buffers socket
///
/// This method requires one closure that given an object id and an opcode,
/// must provide the signature of the associated request/event, in the form of
/// a `&'static [ArgumentType]`. If it returns `None`, meaning that
/// the couple object/opcode does not exist, an error will be returned.
///
/// There are 3 possibilities of return value:
///
/// - `Ok(Ok(msg))`: no error occurred, this is the message
/// - `Ok(Err(e))`: either a malformed message was encountered or we need more data,
/// in the latter case you need to try calling `fill_incoming_buffers()`.
/// - `Err(e)`: an I/O error occurred reading from the socked, details are in `e`
/// (this can be a "wouldblock" error, which just means that no message is available
/// to read)
pub fn read_one_message<F>(&mut self, mut signature: F) -> Result<Message, MessageParseError>
where
F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
{
let (msg, read_data, read_fd) = {
let data = self.in_data.get_contents();
let fds = self.in_fds.get_contents();
if data.len() < 2 {
return Err(MessageParseError::MissingData);
}
let object_id = data[0];
let opcode = (data[1] & 0x0000_FFFF) as u16;
if let Some(sig) = signature(object_id, opcode) {
match Message::from_raw(data, sig, fds) {
Ok((msg, rest_data, rest_fds)) => {
(msg, data.len() - rest_data.len(), fds.len() - rest_fds.len())
}
// TODO: gracefully handle wayland messages split across unix messages ?
Err(e) => return Err(e),
}
} else {
// no signature found ?
return Err(MessageParseError::Malformed);
}
};
self.in_data.offset(read_data);
self.in_fds.offset(read_fd);
Ok(msg)
}
/// Read and deserialize messages from the socket
///
/// This method requires two closures:
///
/// - The first one, given an object id and an opcode, must provide
/// the signature of the associated request/event, in the form of
/// a `&'static [ArgumentType]`. If it returns `None`, meaning that
/// the couple object/opcode does not exist, the parsing will be
/// prematurely interrupted and this method will return a
/// `MessageParseError::Malformed` error.
/// - The second closure is charged to process the parsed message. If it
/// returns `false`, the iteration will be prematurely stopped.
///
/// In both cases of early stopping, the remaining unused data will be left
/// in the buffers, and will start to be processed at the next call of this
/// method.
///
/// There are 3 possibilities of return value:
///
/// - `Ok(Ok(n))`: no error occurred, `n` messages where processed
/// - `Ok(Err(MessageParseError::Malformed))`: a malformed message was encountered
/// (this is a protocol error and is supposed to be fatal to the connection).
/// - `Err(e)`: an I/O error occurred reading from the socked, details are in `e`
/// (this can be a "wouldblock" error, which just means that no message is available
/// to read)
pub fn read_messages<F1, F2>(
&mut self,
mut signature: F1,
mut callback: F2,
) -> NixResult<Result<usize, MessageParseError>>
where
F1: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
F2: FnMut(Message) -> bool,
{
// message parsing
let mut dispatched = 0;
loop {
let mut err = None;
// first parse any leftover messages
loop {
match self.read_one_message(&mut signature) {
Ok(msg) => {
let keep_going = callback(msg);
dispatched += 1;
if !keep_going {
break;
}
}
Err(e) => {
err = Some(e);
break;
}
}
}
// copy back any leftover content to the front of the buffer
self.in_data.move_to_front();
self.in_fds.move_to_front();
if let Some(MessageParseError::Malformed) = err {
// early stop here
return Ok(Err(MessageParseError::Malformed));
}
if err.is_none() && self.in_data.has_content() {
// we stopped reading without error while there is content? That means
// the user requested an early stopping
return Ok(Ok(dispatched));
}
// now, try to get more data
match self.fill_incoming_buffers() {
Ok(()) => (),
Err(e @ ::nix::Error::EAGAIN) => {
// stop looping, returning Ok() or EAGAIN depending on whether messages
// were dispatched
if dispatched == 0 {
return Err(e);
} else {
break;
}
}
Err(e) => return Err(e),
}
}
Ok(Ok(dispatched))
}
}
/*
* Buffer
*/
#[derive(Debug)]
struct Buffer<T: Copy> {
storage: Vec<T>,
occupied: usize,
offset: usize,
}
impl<T: Copy + Default> Buffer<T> {
fn new(size: usize) -> Buffer<T> {
Buffer { storage: vec![T::default(); size], occupied: 0, offset: 0 }
}
/// Check if this buffer has content to read
fn has_content(&self) -> bool {
self.occupied > self.offset
}
/// Advance the internal counter of occupied space
fn advance(&mut self, bytes: usize) {
self.occupied += bytes;
}
/// Advance the read offset of current occupied space
fn offset(&mut self, bytes: usize) {
self.offset += bytes;
}
/// Clears the contents of the buffer
///
/// This only sets the counter of occupied space back to zero,
/// allowing previous content to be overwritten.
fn clear(&mut self) {
self.occupied = 0;
self.offset = 0;
}
/// Get the current contents of the occupied space of the buffer
fn get_contents(&self) -> &[T] {
&self.storage[(self.offset)..(self.occupied)]
}
/// Get mutable access to the unoccupied space of the buffer
fn get_writable_storage(&mut self) -> &mut [T] {
&mut self.storage[(self.occupied)..]
}
/// Move the unread contents of the buffer to the front, to ensure
/// maximal write space availability
fn move_to_front(&mut self) {
unsafe {
::std::ptr::copy(
&self.storage[self.offset] as *const T,
&mut self.storage[0] as *mut T,
self.occupied - self.offset,
);
}
self.occupied -= self.offset;
self.offset = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::wire::{Argument, ArgumentType, Message};
use std::ffi::CString;
use smallvec::smallvec;
fn same_file(a: RawFd, b: RawFd) -> bool {
let stat1 = ::nix::sys::stat::fstat(a).unwrap();
let stat2 = ::nix::sys::stat::fstat(b).unwrap();
stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
}
// check if two messages are equal
//
// if arguments contain FDs, check that the fd point to
// the same file, rather than are the same number.
fn assert_eq_msgs(msg1: &Message, msg2: &Message) {
assert_eq!(msg1.sender_id, msg2.sender_id);
assert_eq!(msg1.opcode, msg2.opcode);
assert_eq!(msg1.args.len(), msg2.args.len());
for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
if let (&Argument::Fd(fd1), &Argument::Fd(fd2)) = (arg1, arg2) {
assert!(same_file(fd1, fd2));
} else {
assert_eq!(arg1, arg2);
}
}
}
#[test]
fn write_read_cycle() {
let msg = Message {
sender_id: 42,
opcode: 7,
args: smallvec![
Argument::Uint(3),
Argument::Fixed(-89),
Argument::Str(Box::new(CString::new(&b"I like trains!"[..]).unwrap())),
Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
Argument::Object(88),
Argument::NewId(56),
Argument::Int(-25),
],
};
let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });
client.write_message(&msg).unwrap();
client.flush().unwrap();
static SIGNATURE: &'static [ArgumentType] = &[
ArgumentType::Uint,
ArgumentType::Fixed,
ArgumentType::Str,
ArgumentType::Array,
ArgumentType::Object,
ArgumentType::NewId,
ArgumentType::Int,
];
let ret = server
.read_messages(
|sender_id, opcode| {
if sender_id == 42 && opcode == 7 {
Some(SIGNATURE)
} else {
None
}
},
|message| {
assert_eq_msgs(&message, &msg);
true
},
)
.unwrap()
.unwrap();
assert_eq!(ret, 1);
}
#[test]
fn write_read_cycle_fd() {
let msg = Message {
sender_id: 42,
opcode: 7,
args: smallvec![
Argument::Fd(1), // stdin
Argument::Fd(0), // stdout
],
};
let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });
client.write_message(&msg).unwrap();
client.flush().unwrap();
static SIGNATURE: &'static [ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];
let ret = server
.read_messages(
|sender_id, opcode| {
if sender_id == 42 && opcode == 7 {
Some(SIGNATURE)
} else {
None
}
},
|message| {
assert_eq_msgs(&message, &msg);
true
},
)
.unwrap()
.unwrap();
assert_eq!(ret, 1);
}
#[test]
fn write_read_cycle_multiple() {
let messages = [
Message {
sender_id: 42,
opcode: 0,
args: smallvec![
Argument::Int(42),
Argument::Str(Box::new(CString::new(&b"I like trains"[..]).unwrap())),
],
},
Message {
sender_id: 42,
opcode: 1,
args: smallvec![
Argument::Fd(1), // stdin
Argument::Fd(0), // stdout
],
},
Message {
sender_id: 42,
opcode: 2,
args: smallvec![
Argument::Uint(3),
Argument::Fd(2), // stderr
],
},
];
static SIGNATURES: &'static [&'static [ArgumentType]] = &[
&[ArgumentType::Int, ArgumentType::Str],
&[ArgumentType::Fd, ArgumentType::Fd],
&[ArgumentType::Uint, ArgumentType::Fd],
];
let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });
for msg in &messages {
client.write_message(msg).unwrap();
}
client.flush().unwrap();
let mut recv_msgs = Vec::new();
let ret = server
.read_messages(
|sender_id, opcode| {
if sender_id == 42 {
Some(SIGNATURES[opcode as usize])
} else {
None
}
},
|message| {
recv_msgs.push(message);
true
},
)
.unwrap()
.unwrap();
assert_eq!(ret, 3);
assert_eq!(recv_msgs.len(), 3);
for (msg1, msg2) in messages.iter().zip(recv_msgs.iter()) {
assert_eq_msgs(msg1, msg2);
}
}
#[test]
fn parse_with_string_len_multiple_of_4() {
let msg = Message {
sender_id: 2,
opcode: 0,
args: smallvec![
Argument::Uint(18),
Argument::Str(Box::new(CString::new(&b"wl_shell"[..]).unwrap())),
Argument::Uint(1),
],
};
let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
let mut client = BufferedSocket::new(unsafe { Socket::from_raw_fd(client.into_raw_fd()) });
let mut server = BufferedSocket::new(unsafe { Socket::from_raw_fd(server.into_raw_fd()) });
client.write_message(&msg).unwrap();
client.flush().unwrap();
static SIGNATURE: &'static [ArgumentType] =
&[ArgumentType::Uint, ArgumentType::Str, ArgumentType::Uint];
let ret = server
.read_messages(
|sender_id, opcode| {
if sender_id == 2 && opcode == 0 {
Some(SIGNATURE)
} else {
None
}
},
|message| {
assert_eq_msgs(&message, &msg);
true
},
)
.unwrap()
.unwrap();
assert_eq!(ret, 1);
}
}