1
0
Fork 0
forked from wry/wry

autocommit 2022-01-02 15:13:33 CET

This commit is contained in:
Julian Orth 2022-01-02 15:13:33 +01:00
commit d6172b273f
50 changed files with 5807 additions and 0 deletions

105
src/utils/buffd/buf_in.rs Normal file
View file

@ -0,0 +1,105 @@
use crate::async_engine::AsyncFd;
use crate::utils::buffd::{BufFdError, BUF_SIZE, CMSG_BUF_SIZE, MAX_IN_FD};
use std::collections::VecDeque;
use std::mem::MaybeUninit;
use uapi::{c, Errno, OwnedFd, Pod};
pub struct BufFdIn {
fd: AsyncFd,
in_fd: VecDeque<OwnedFd>,
in_buf: Box<[MaybeUninit<u8>; BUF_SIZE]>,
in_cmsg_buf: Box<[MaybeUninit<u8>; CMSG_BUF_SIZE]>,
in_left: usize,
in_right: usize,
}
impl BufFdIn {
pub fn new(fd: AsyncFd) -> Self {
Self {
fd,
in_fd: Default::default(),
in_buf: Box::new([MaybeUninit::uninit(); BUF_SIZE]),
in_cmsg_buf: Box::new([MaybeUninit::uninit(); CMSG_BUF_SIZE]),
in_left: 0,
in_right: 0,
}
}
pub async fn read_full<T: Pod + ?Sized>(&mut self, buf: &mut T) -> Result<(), BufFdError> {
let bytes = unsafe { uapi::as_maybe_uninit_bytes_mut2(buf) };
let mut offset = 0;
while offset < bytes.len() {
if self.read_full_(bytes, &mut offset)? {
self.fd.readable().await?;
}
}
Ok(())
}
fn read_full_(
&mut self,
bytes: &mut [MaybeUninit<u8>],
offset: &mut usize,
) -> Result<bool, BufFdError> {
let num_bytes = (bytes.len() - *offset).min(self.in_right - self.in_left);
if num_bytes > 0 {
let left = self.in_left % BUF_SIZE;
let right = (self.in_left + num_bytes) % BUF_SIZE;
if left < right {
bytes[*offset..*offset + num_bytes].copy_from_slice(&self.in_buf[left..right]);
} else {
bytes[*offset..*offset + (BUF_SIZE - left)].copy_from_slice(&self.in_buf[left..]);
bytes[*offset + (BUF_SIZE - left)..*offset + num_bytes]
.copy_from_slice(&self.in_buf[..right]);
}
self.in_left += num_bytes;
*offset += num_bytes;
}
if *offset == bytes.len() {
return Ok(false);
}
let left = self.in_left % BUF_SIZE;
let right = self.in_right % BUF_SIZE;
let mut iov = if right < left {
[&mut self.in_buf[right..left], &mut []]
} else {
let (l, r) = self.in_buf.split_at_mut(right);
[r, &mut l[..left]]
};
let mut hdr = uapi::MsghdrMut {
iov: &mut iov[..],
control: Some(&mut self.in_cmsg_buf[..]),
name: uapi::sockaddr_none_mut(),
flags: 0,
};
let (iov, _, mut cmsg) = match uapi::recvmsg(self.fd.raw(), &mut hdr, c::MSG_DONTWAIT) {
Ok((iov, _, _)) if iov.is_empty() => return Err(BufFdError::Closed),
Ok(v) => v,
Err(Errno(c::EAGAIN)) => return Ok(true),
Err(e) => return Err(BufFdError::Io(e.into())),
};
self.in_right += iov.len();
while cmsg.len() > 0 {
let (_, hdr, data) = match uapi::cmsg_read(&mut cmsg) {
Ok(m) => m,
Err(e) => return Err(BufFdError::Io(e.into())),
};
if (hdr.cmsg_level, hdr.cmsg_type) == (c::SOL_SOCKET, c::SCM_RIGHTS) {
self.in_fd.extend(uapi::pod_iter(data).unwrap());
}
}
if self.in_fd.len() > MAX_IN_FD {
return Err(BufFdError::TooManyFds);
}
Ok(false)
}
pub fn get_fd(&mut self) -> Result<OwnedFd, BufFdError> {
match self.in_fd.pop_front() {
Some(f) => Ok(f),
None => Err(BufFdError::NoFd),
}
}
}

121
src/utils/buffd/buf_out.rs Normal file
View file

@ -0,0 +1,121 @@
use crate::async_engine::AsyncFd;
use crate::utils::buffd::{BufFdError, BUF_SIZE, CMSG_BUF_SIZE};
use futures::{select, FutureExt};
use std::collections::VecDeque;
use std::mem::MaybeUninit;
use std::slice;
use uapi::{c, Errno, OwnedFd};
pub(super) const OUT_BUF_SIZE: usize = 2 * BUF_SIZE;
pub(super) struct MsgFds {
pub(super) pos: usize,
pub(super) fds: Vec<OwnedFd>,
}
pub struct BufFdOut {
fd: AsyncFd,
pub(super) out_pos: usize,
pub(super) out_buf: *mut [MaybeUninit<u8>; OUT_BUF_SIZE],
pub(super) fds: VecDeque<MsgFds>,
cmsg_buf: Box<[MaybeUninit<u8>; CMSG_BUF_SIZE]>,
}
impl BufFdOut {
pub fn new(fd: AsyncFd) -> Self {
Self {
fd,
out_pos: 0,
out_buf: Box::into_raw(Box::new([MaybeUninit::<u32>::uninit(); OUT_BUF_SIZE / 4])) as _,
fds: Default::default(),
cmsg_buf: Box::new([MaybeUninit::uninit(); CMSG_BUF_SIZE]),
}
}
pub fn write(&mut self, bytes: &[MaybeUninit<u8>]) {
if bytes.len() > OUT_BUF_SIZE - self.out_pos {
panic!("Out buffer overflow");
}
unsafe {
(*self.out_buf)[self.out_pos..self.out_pos + bytes.len()].copy_from_slice(bytes);
}
self.out_pos += bytes.len();
}
pub fn needs_flush(&self) -> bool {
self.out_pos > BUF_SIZE
}
pub async fn flush(&mut self) -> Result<(), BufFdError> {
let mut timeout = None;
let mut pos = 0;
while pos < self.out_pos {
if self.flush_sync(&mut pos)? {
if timeout.is_none() {
timeout = Some(self.fd.eng().timeout(5000)?.fuse());
}
select! {
_ = timeout.as_mut().unwrap() => return Err(BufFdError::Timeout),
res = self.fd.writable().fuse() => res?,
}
}
}
self.out_pos = 0;
Ok(())
}
fn flush_sync(&mut self, pos: &mut usize) -> Result<bool, BufFdError> {
while *pos < self.out_pos {
let mut buf = unsafe { &(*self.out_buf)[*pos..self.out_pos] };
let mut cmsg_len = 0;
let mut fds_opt = None;
{
let mut f = self.fds.front().map(|f| f.pos);
if f == Some(*pos) {
let fds = self.fds.pop_front().unwrap();
let hdr = c::cmsghdr {
cmsg_len: 0,
cmsg_level: c::SOL_SOCKET,
cmsg_type: c::SCM_RIGHTS,
};
let mut cmsg_buf = &mut self.cmsg_buf[..];
cmsg_len = uapi::cmsg_write(&mut cmsg_buf, hdr, &fds.fds[..]).unwrap();
fds_opt = Some(fds);
f = self.fds.front().map(|f| f.pos)
}
if let Some(next_pos) = f {
buf = &buf[..next_pos - *pos];
}
}
let hdr = uapi::Msghdr {
iov: slice::from_ref(&buf),
control: Some(&self.cmsg_buf[..cmsg_len]),
name: uapi::sockaddr_none_ref(),
};
let bytes_sent =
match uapi::sendmsg(self.fd.raw(), &hdr, c::MSG_DONTWAIT | c::MSG_NOSIGNAL) {
Ok(b) => b,
Err(Errno(c::EAGAIN)) => {
if let Some(fds) = fds_opt {
self.fds.push_front(fds);
}
return Ok(true);
}
Err(Errno(c::ECONNRESET)) => return Err(BufFdError::Closed),
Err(e) => return Err(BufFdError::Io(e.into())),
};
*pos += bytes_sent;
}
Ok(false)
}
}
impl Drop for BufFdOut {
fn drop(&mut self) {
unsafe {
Box::from_raw(self.out_buf as *mut [MaybeUninit<u32>; OUT_BUF_SIZE / 4]);
}
}
}

31
src/utils/buffd/mod.rs Normal file
View file

@ -0,0 +1,31 @@
use crate::async_engine::AsyncError;
pub use buf_in::BufFdIn;
pub use buf_out::BufFdOut;
use thiserror::Error;
pub use wl_formatter::WlFormatter;
pub use wl_parser::{WlParser, WlParserError};
mod buf_in;
mod buf_out;
mod wl_formatter;
mod wl_parser;
#[derive(Debug, Error)]
pub enum BufFdError {
#[error("An IO error occurred")]
Io(#[source] std::io::Error),
#[error("An async error occurred")]
Async(#[from] AsyncError),
#[error("The peer did not send a file descriptor")]
NoFd,
#[error("The peer sent too many file descriptors")]
TooManyFds,
#[error("The peer closed the connection")]
Closed,
#[error("The connection timed out")]
Timeout,
}
const BUF_SIZE: usize = 4096;
const CMSG_BUF_SIZE: usize = 4096;
const MAX_IN_FD: usize = 4;

View file

@ -0,0 +1,78 @@
use crate::objects::ObjectId;
use crate::utils::buffd::buf_out::{BufFdOut, MsgFds};
use std::mem;
use std::mem::MaybeUninit;
use uapi::OwnedFd;
pub struct WlFormatter<'a> {
buf: &'a mut BufFdOut,
pos: usize,
fds: Vec<OwnedFd>,
}
impl<'a> WlFormatter<'a> {
pub fn new(buf: &'a mut BufFdOut) -> Self {
Self {
pos: buf.out_pos,
buf,
fds: vec![],
}
}
pub fn int(&mut self, int: i32) -> &mut Self {
self.buf.write(uapi::as_maybe_uninit_bytes(&int));
self
}
pub fn uint(&mut self, int: u32) -> &mut Self {
self.buf.write(uapi::as_maybe_uninit_bytes(&int));
self
}
pub fn fixed(&mut self, fixed: f64) -> &mut Self {
let int = (fixed * 256.0) as i32;
self.buf.write(uapi::as_maybe_uninit_bytes(&int));
self
}
pub fn string(&mut self, s: &str) -> &mut Self {
let len = s.len() + 1;
let cap = (len + 3) & !3;
self.uint(len as u32);
self.buf.write(uapi::as_maybe_uninit_bytes(s.as_bytes()));
let none = [MaybeUninit::new(0); 4];
self.buf.write(&none[..cap - len + 1]);
self
}
pub fn fd(&mut self, fd: OwnedFd) -> &mut Self {
self.fds.push(fd);
self
}
pub fn object(&mut self, obj: ObjectId) -> &mut Self {
self.uint(obj.raw())
}
pub fn header(&mut self, obj: ObjectId, event: u32) -> &mut Self {
self.object(obj).uint(event)
}
}
impl<'a> Drop for WlFormatter<'a> {
fn drop(&mut self) {
assert!(self.buf.out_pos - self.pos >= 8);
assert_eq!(self.pos % 4, 0);
unsafe {
let second_ptr = (self.buf.out_buf as *mut u8).add(self.pos + 4) as *mut u32;
let len = ((self.buf.out_pos - self.pos) as u32) << 16;
*second_ptr |= len;
}
if self.fds.len() > 0 {
self.buf.fds.push_back(MsgFds {
pos: self.pos,
fds: mem::take(&mut self.fds),
})
}
}
}

View file

@ -0,0 +1,93 @@
use crate::globals::GlobalName;
use crate::objects::ObjectId;
use crate::utils::buffd::BufFdIn;
use thiserror::Error;
use uapi::OwnedFd;
#[derive(Debug, Error)]
pub enum WlParserError {
#[error("The message ended unexpectedly")]
UnexpectedEof,
#[error("The message contained a non-utf8 string")]
NonUtf8,
#[error("The message contained a string of size 0")]
EmptyString,
#[error("Message is missing a required file descriptor")]
MissingFd,
#[error("There is trailing data after the message")]
TrailingData,
}
pub struct WlParser<'a, 'b> {
buf: &'a mut BufFdIn,
pos: usize,
data: &'b [u8],
}
impl<'a, 'b> WlParser<'a, 'b> {
pub fn new(buf: &'a mut BufFdIn, data: &'b [u32]) -> Self {
Self {
buf,
pos: 0,
data: unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 4) },
}
}
pub fn int(&mut self) -> Result<i32, WlParserError> {
if self.data.len() - self.pos < 4 {
return Err(WlParserError::UnexpectedEof);
}
let res = unsafe { *(self.data.as_ptr().add(self.pos) as *const i32) };
self.pos += 4;
Ok(res)
}
pub fn uint(&mut self) -> Result<u32, WlParserError> {
self.int().map(|i| i as u32)
}
pub fn object(&mut self) -> Result<ObjectId, WlParserError> {
self.int().map(|i| ObjectId::from_raw(i as u32))
}
pub fn global(&mut self) -> Result<GlobalName, WlParserError> {
self.int().map(|i| GlobalName::from_raw(i as u32))
}
pub fn fixed(&mut self) -> Result<f64, WlParserError> {
self.int().map(|i| i as f64 / 256.0)
}
pub fn string(&mut self) -> Result<&'b str, WlParserError> {
let len = self.uint()? as usize;
if len == 0 {
return Err(WlParserError::EmptyString);
}
let cap = (len + 3) & !3;
if cap > self.data.len() - self.pos {
return Err(WlParserError::UnexpectedEof);
}
let s = &self.data[self.pos..self.pos + len - 1];
let s = match std::str::from_utf8(s) {
Ok(s) => s,
_ => return Err(WlParserError::NonUtf8),
};
self.pos += cap;
Ok(s)
}
pub fn fd(&mut self) -> Result<OwnedFd, WlParserError> {
match self.buf.get_fd() {
Ok(fd) => Ok(fd),
_ => Err(WlParserError::MissingFd),
}
}
pub fn eof(&self) -> Result<(), WlParserError> {
if self.pos == self.data.len() {
Ok(())
} else {
Err(WlParserError::TrailingData)
}
}
}