1
0
Fork 0
forked from wry/wry

wire: move message buffers into workspace crates

This commit is contained in:
kossLAN 2026-05-29 11:07:43 -04:00
parent d8380b3dce
commit a1e4641e82
No known key found for this signature in database
18 changed files with 187 additions and 166 deletions

91
wire-buf/src/buf_in.rs Normal file
View file

@ -0,0 +1,91 @@
use {
crate::{BUF_SIZE, BufFdError, MAX_IN_FD},
jay_io_uring::IoUring,
jay_utils::buf::Buf,
smallvec::SmallVec,
std::{collections::VecDeque, mem::MaybeUninit, rc::Rc},
uapi::{OwnedFd, Pod},
};
pub struct BufFdIn {
fd: Rc<OwnedFd>,
ring: Rc<IoUring>,
in_fd: VecDeque<Rc<OwnedFd>>,
in_buf: Buf,
in_left: usize,
in_right: usize,
}
impl BufFdIn {
pub fn new(fd: &Rc<OwnedFd>, ring: &Rc<IoUring>) -> Self {
Self {
fd: fd.clone(),
ring: ring.clone(),
in_fd: Default::default(),
in_buf: Buf::new(BUF_SIZE),
in_left: 0,
in_right: 0,
}
}
pub async fn read_full<T: Pod + ?Sized>(&mut self, buf: &mut T) -> Result<(), BufFdError> {
let bytes = unsafe { uapi::as_maybe_uninit_bytes_mut2(buf) };
let mut offset = 0;
while offset < bytes.len() {
self.read_full_(bytes, &mut offset).await?;
}
Ok(())
}
async fn read_full_(
&mut self,
bytes: &mut [MaybeUninit<u8>],
offset: &mut usize,
) -> Result<(), BufFdError> {
let in_buf = uapi::as_maybe_uninit_bytes(&self.in_buf[..]);
let num_bytes = (bytes.len() - *offset).min(self.in_right - self.in_left);
if num_bytes > 0 {
let left = self.in_left % BUF_SIZE;
let right = (self.in_left + num_bytes) % BUF_SIZE;
if left < right {
bytes[*offset..*offset + num_bytes].copy_from_slice(&in_buf[left..right]);
} else {
bytes[*offset..*offset + (BUF_SIZE - left)].copy_from_slice(&in_buf[left..]);
bytes[*offset + (BUF_SIZE - left)..*offset + num_bytes]
.copy_from_slice(&in_buf[..right]);
}
self.in_left += num_bytes;
*offset += num_bytes;
}
if *offset == bytes.len() {
return Ok(());
}
let left = self.in_left % BUF_SIZE;
let right = self.in_right % BUF_SIZE;
let mut iov = SmallVec::<[_; 2]>::new();
if right < left {
iov.push(self.in_buf.slice(right..left));
} else {
iov.push(self.in_buf.slice(right..));
iov.push(self.in_buf.slice(..left));
}
match self.ring.recvmsg(&self.fd, &mut iov, &mut self.in_fd).await {
Ok(0) => return Err(BufFdError::Closed),
Ok(n) => self.in_right += n,
Err(e) => return Err(BufFdError::Ring(e)),
}
if self.in_fd.len() > MAX_IN_FD {
return Err(BufFdError::TooManyFds);
}
Ok(())
}
pub fn get_fd(&mut self) -> Result<Rc<OwnedFd>, BufFdError> {
match self.in_fd.pop_front() {
Some(f) => Ok(f),
None => Err(BufFdError::NoFd),
}
}
}

156
wire-buf/src/buf_out.rs Normal file
View file

@ -0,0 +1,156 @@
use {
crate::{BUF_SIZE, BufFdError},
jay_io_uring::{IoUring, IoUringError},
jay_time::Time,
jay_utils::{buf::Buf, oserror::OsError},
std::{
collections::VecDeque,
mem::{self},
rc::Rc,
},
uapi::{OwnedFd, c},
};
pub(super) const OUT_BUF_SIZE: usize = 2 * BUF_SIZE;
pub(super) struct MsgFds {
pub(super) pos: usize,
pub(super) fds: Vec<Rc<OwnedFd>>,
}
pub(super) struct OutBufferMeta {
pub(super) read_pos: usize,
pub(super) write_pos: usize,
pub(super) fds: VecDeque<MsgFds>,
}
pub struct OutBuffer {
pub(super) meta: OutBufferMeta,
pub(super) buf: Buf,
}
impl Default for OutBuffer {
fn default() -> Self {
Self {
meta: OutBufferMeta {
read_pos: 0,
write_pos: 0,
fds: Default::default(),
},
buf: Buf::new(OUT_BUF_SIZE),
}
}
}
impl OutBuffer {
pub fn is_full(&self) -> bool {
self.meta.write_pos > BUF_SIZE
}
}
const LIMIT_PENDING: usize = 10;
#[derive(Default)]
pub struct OutBufferSwapchain {
pub cur: OutBuffer,
pub pending: VecDeque<OutBuffer>,
pub free: Vec<OutBuffer>,
}
impl OutBufferSwapchain {
pub fn exceeds_limit(&self) -> bool {
self.pending.len() > LIMIT_PENDING
}
pub fn commit(&mut self) {
if self.cur.meta.write_pos > 0 {
let new = self.free.pop().unwrap_or_default();
let old = mem::replace(&mut self.cur, new);
self.pending.push_back(old);
}
}
}
pub struct BufFdOut {
fd: Rc<OwnedFd>,
ring: Rc<IoUring>,
}
impl BufFdOut {
pub fn new(fd: &Rc<OwnedFd>, ring: &Rc<IoUring>) -> Self {
Self {
fd: fd.clone(),
ring: ring.clone(),
}
}
pub async fn flush(&mut self, buf: &mut OutBuffer, timeout: Time) -> Result<(), BufFdError> {
while buf.meta.read_pos < buf.meta.write_pos {
self.flush_buffer(buf, Some(timeout)).await?;
}
buf.meta.read_pos = 0;
buf.meta.write_pos = 0;
Ok(())
}
pub async fn flush_no_timeout(&mut self, buf: &mut OutBuffer) -> Result<(), BufFdError> {
while buf.meta.read_pos < buf.meta.write_pos {
self.flush_buffer(buf, None).await?;
}
buf.meta.read_pos = 0;
buf.meta.write_pos = 0;
Ok(())
}
async fn flush_buffer(
&mut self,
buffer: &mut OutBuffer,
timeout: Option<Time>,
) -> Result<(), BufFdError> {
let mut buf = buffer
.buf
.slice(buffer.meta.read_pos..buffer.meta.write_pos);
let mut fds = vec![];
{
let mut f = buffer.meta.fds.front().map(|f| f.pos);
if f == Some(buffer.meta.read_pos) {
fds = buffer.meta.fds.pop_front().unwrap().fds;
f = buffer.meta.fds.front().map(|f| f.pos)
}
if let Some(next_pos) = f {
buf = buffer.buf.slice(buffer.meta.read_pos..next_pos);
}
}
match self.ring.sendmsg_one(&self.fd, buf, fds, timeout).await {
Ok(n) => {
buffer.meta.read_pos += n;
Ok(())
}
Err(IoUringError::OsError(OsError(c::ECONNRESET))) => return Err(BufFdError::Closed),
Err(IoUringError::OsError(OsError(c::ECANCELED))) => return Err(BufFdError::Timeout),
Err(e) => return Err(BufFdError::Ring(e)),
}
}
pub async fn flush2(
&mut self,
mut buf: Buf,
mut fds: Vec<Rc<OwnedFd>>,
) -> Result<(), BufFdError> {
let mut read_pos = 0;
while read_pos < buf.len() {
let res = self
.ring
.sendmsg_one(&self.fd, buf.slice(read_pos..), mem::take(&mut fds), None)
.await;
match res {
Ok(n) => read_pos += n,
Err(IoUringError::OsError(OsError(c::ECONNRESET))) => {
return Err(BufFdError::Closed);
}
Err(e) => return Err(BufFdError::Io(e)),
}
}
Ok(())
}
}

View file

@ -0,0 +1,103 @@
use {
crate::buf_out::{MsgFds, OUT_BUF_SIZE, OutBuffer, OutBufferMeta},
jay_wire_types::EiObjectId,
std::{mem, rc::Rc},
uapi::OwnedFd,
};
pub struct EiMsgFormatter<'a> {
buf: &'a mut [u8],
meta: &'a mut OutBufferMeta,
pos: usize,
fds: &'a mut Vec<Rc<OwnedFd>>,
}
impl<'a> EiMsgFormatter<'a> {
pub fn new(buf: &'a mut OutBuffer, fds: &'a mut Vec<Rc<OwnedFd>>) -> Self {
Self {
pos: buf.meta.write_pos,
buf: &mut buf.buf[..],
fds,
meta: &mut buf.meta,
}
}
fn write(&mut self, bytes: &[u8]) {
if bytes.len() > OUT_BUF_SIZE - self.meta.write_pos {
panic!("Out buffer overflow");
}
self.buf[self.meta.write_pos..self.meta.write_pos + bytes.len()].copy_from_slice(bytes);
self.meta.write_pos += bytes.len();
}
pub fn int(&mut self, int: i32) -> &mut Self {
self.write(uapi::as_bytes(&int));
self
}
pub fn uint(&mut self, int: u32) -> &mut Self {
self.write(uapi::as_bytes(&int));
self
}
pub fn long(&mut self, int: i64) -> &mut Self {
self.write(uapi::as_bytes(&int));
self
}
pub fn ulong(&mut self, int: u64) -> &mut Self {
self.write(uapi::as_bytes(&int));
self
}
pub fn float(&mut self, f: f32) -> &mut Self {
self.write(uapi::as_bytes(&f));
self
}
pub fn optstr<S: AsRef<[u8]> + ?Sized>(&mut self, s: Option<&S>) -> &mut Self {
match s {
Some(s) => self.string(s),
_ => self.uint(0),
}
}
pub fn string<S: AsRef<[u8]> + ?Sized>(&mut self, s: &S) -> &mut Self {
let s = s.as_ref();
let len = s.len() + 1;
let cap = (len + 3) & !3;
self.uint(len as u32);
self.write(uapi::as_bytes(s));
let none = [0; 4];
self.write(&none[..cap - len + 1]);
self
}
pub fn fd(&mut self, fd: Rc<OwnedFd>) -> &mut Self {
self.fds.push(fd);
self
}
pub fn object<T: Into<EiObjectId>>(&mut self, obj: T) -> &mut Self {
self.ulong(obj.into().raw())
}
pub fn header<T: Into<EiObjectId>>(&mut self, obj: T, event: u32) -> &mut Self {
self.object(obj).uint(0).uint(event)
}
pub fn write_len(self) {
assert!(self.meta.write_pos - self.pos >= 16);
assert_eq!(self.pos % 4, 0);
unsafe {
let second_ptr = self.buf.as_ptr().add(self.pos + 8) as *mut u32;
*second_ptr = (self.meta.write_pos - self.pos) as u32;
}
if self.fds.len() > 0 {
self.meta.fds.push_back(MsgFds {
pos: self.pos,
fds: mem::take(self.fds),
})
}
}
}

113
wire-buf/src/ei_parser.rs Normal file
View file

@ -0,0 +1,113 @@
use {
crate::BufFdIn,
jay_wire_types::EiObjectId,
std::{ptr, rc::Rc},
thiserror::Error,
uapi::OwnedFd,
};
#[derive(Debug, Error)]
pub enum EiMsgParserError {
#[error("The message ended unexpectedly")]
UnexpectedEof,
#[error("The message contained a string of size 0")]
EmptyString,
#[error("Message is missing a required file descriptor")]
MissingFd,
#[error("There is trailing data after the message")]
TrailingData,
#[error("String is not UTF-8")]
NonUtf8,
}
pub struct EiMsgParser<'a, 'b> {
buf: &'a mut BufFdIn,
pos: usize,
data: &'b [u8],
}
impl<'a, 'b> EiMsgParser<'a, 'b> {
pub fn new(buf: &'a mut BufFdIn, data: &'b [u32]) -> Self {
Self {
buf,
pos: 0,
data: uapi::as_bytes(data),
}
}
pub fn int(&mut self) -> Result<i32, EiMsgParserError> {
if self.data.len() - self.pos < 4 {
return Err(EiMsgParserError::UnexpectedEof);
}
let res = unsafe { *(self.data.as_ptr().add(self.pos) as *const i32) };
self.pos += 4;
Ok(res)
}
pub fn uint(&mut self) -> Result<u32, EiMsgParserError> {
self.int().map(|i| i as u32)
}
pub fn long(&mut self) -> Result<i64, EiMsgParserError> {
if self.data.len() - self.pos < 8 {
return Err(EiMsgParserError::UnexpectedEof);
}
let res = unsafe { ptr::read_unaligned(self.data.as_ptr().add(self.pos) as *const i64) };
self.pos += 8;
Ok(res)
}
pub fn ulong(&mut self) -> Result<u64, EiMsgParserError> {
self.long().map(|i| i as u64)
}
pub fn object<T>(&mut self) -> Result<T, EiMsgParserError>
where
EiObjectId: Into<T>,
{
self.ulong().map(|i| EiObjectId::from_raw(i).into())
}
pub fn float(&mut self) -> Result<f32, EiMsgParserError> {
Ok(f32::from_bits(self.uint()?))
}
pub fn optstr(&mut self) -> Result<Option<&'b str>, EiMsgParserError> {
let len = self.uint()? as usize;
if len == 0 {
return Ok(None);
}
let cap = (len + 3) & !3;
if cap > self.data.len() - self.pos {
return Err(EiMsgParserError::UnexpectedEof);
}
let pos = self.pos;
self.pos += cap;
match std::str::from_utf8(&self.data[pos..pos + len - 1]) {
Ok(s) => Ok(Some(s)),
Err(_) => Err(EiMsgParserError::NonUtf8),
}
}
pub fn str(&mut self) -> Result<&'b str, EiMsgParserError> {
match self.optstr()? {
Some(s) => Ok(s),
_ => Err(EiMsgParserError::EmptyString),
}
}
pub fn fd(&mut self) -> Result<Rc<OwnedFd>, EiMsgParserError> {
match self.buf.get_fd() {
Ok(fd) => Ok(fd),
_ => Err(EiMsgParserError::MissingFd),
}
}
pub fn eof(&self) -> Result<(), EiMsgParserError> {
if self.pos == self.data.len() {
Ok(())
} else {
Err(EiMsgParserError::TrailingData)
}
}
}

138
wire-buf/src/formatter.rs Normal file
View file

@ -0,0 +1,138 @@
use {
crate::buf_out::{MsgFds, OUT_BUF_SIZE, OutBuffer, OutBufferMeta},
jay_units::Fixed,
jay_wire_types::ObjectId,
std::{mem, rc::Rc},
uapi::{OwnedFd, Packed},
};
pub struct MsgFormatter<'a> {
buf: &'a mut [u8],
meta: &'a mut OutBufferMeta,
pos: usize,
fds: &'a mut Vec<Rc<OwnedFd>>,
}
impl<'a> MsgFormatter<'a> {
pub fn new(buf: &'a mut OutBuffer, fds: &'a mut Vec<Rc<OwnedFd>>) -> Self {
Self {
pos: buf.meta.write_pos,
buf: &mut buf.buf[..],
fds,
meta: &mut buf.meta,
}
}
fn write(&mut self, bytes: &[u8]) {
if bytes.len() > OUT_BUF_SIZE - self.meta.write_pos {
panic!("Out buffer overflow");
}
self.buf[self.meta.write_pos..self.meta.write_pos + bytes.len()].copy_from_slice(bytes);
self.meta.write_pos += bytes.len();
}
#[inline(always)]
pub fn data(&mut self, data: &[u32]) {
self.write(uapi::as_bytes(data));
}
pub fn int(&mut self, int: i32) -> &mut Self {
self.write(uapi::as_bytes(&int));
self
}
pub fn uint(&mut self, int: u32) -> &mut Self {
self.write(uapi::as_bytes(&int));
self
}
pub fn u64(&mut self, int: u64) -> &mut Self {
self.uint((int >> 32) as u32);
self.uint(int as u32)
}
pub fn u64_rev(&mut self, int: u64) -> &mut Self {
self.uint(int as u32);
self.uint((int >> 32) as u32)
}
pub fn fixed(&mut self, fixed: Fixed) -> &mut Self {
self.write(uapi::as_bytes(&fixed.0));
self
}
pub fn optstr<S: AsRef<[u8]> + ?Sized>(&mut self, s: Option<&S>) -> &mut Self {
match s {
Some(s) => self.string(s),
_ => self.uint(0),
}
}
pub fn string<S: AsRef<[u8]> + ?Sized>(&mut self, s: &S) -> &mut Self {
let s = s.as_ref();
let len = s.len() + 1;
let cap = (len + 3) & !3;
self.uint(len as u32);
self.write(uapi::as_bytes(s));
let none = [0; 4];
self.write(&none[..cap - len + 1]);
self
}
pub fn fd(&mut self, fd: Rc<OwnedFd>) -> &mut Self {
self.fds.push(fd);
self
}
pub fn object<T: Into<ObjectId>>(&mut self, obj: T) -> &mut Self {
self.uint(obj.into().raw())
}
pub fn header<T: Into<ObjectId>>(&mut self, obj: T, event: u32) -> &mut Self {
self.object(obj).uint(event)
}
pub fn array<F: FnOnce(&mut MsgFormatter<'_>)>(&mut self, f: F) -> &mut Self {
let pos = self.meta.write_pos;
self.uint(0);
let len = {
let mut fmt = MsgFormatter {
buf: self.buf,
meta: self.meta,
pos,
fds: self.fds,
};
f(&mut fmt);
let len = self.meta.write_pos - pos - 4;
let none = [0; 4];
self.write(&none[..self.meta.write_pos.wrapping_neg() & 3]);
len as u32
};
self.buf[pos..pos + 4].copy_from_slice(uapi::as_bytes(&len));
self
}
pub fn binary<T: ?Sized + Packed>(&mut self, t: &T) -> &mut Self {
self.uint(size_of_val(t) as u32);
self.write(uapi::as_bytes(t));
let none = [0; 4];
self.write(&none[..self.meta.write_pos.wrapping_neg() & 3]);
self
}
pub fn write_len(self) {
assert!(self.meta.write_pos - self.pos >= 8);
assert_eq!(self.pos % 4, 0);
unsafe {
let second_ptr = self.buf.as_ptr().add(self.pos + 4) as *mut u32;
let len = ((self.meta.write_pos - self.pos) as u32) << 16;
*second_ptr |= len;
}
if self.fds.len() > 0 {
self.meta.fds.push_back(MsgFds {
pos: self.pos,
fds: mem::take(self.fds),
})
}
}
}

43
wire-buf/src/lib.rs Normal file
View file

@ -0,0 +1,43 @@
use {jay_io_uring::IoUringError, thiserror::Error};
pub use {
buf_in::BufFdIn,
buf_out::{BufFdOut, OutBuffer, OutBufferSwapchain},
ei_formatter::EiMsgFormatter,
ei_parser::{EiMsgParser, EiMsgParserError},
formatter::MsgFormatter,
parser::{MsgParser, MsgParserError},
wl_buf_in::{WlBufFdIn, WlMessage},
};
mod buf_in;
mod buf_out;
mod ei_formatter;
mod ei_parser;
mod formatter;
mod parser;
mod wl_buf_in;
#[derive(Debug, Error)]
pub enum BufFdError {
#[error("An IO error occurred")]
Io(#[source] IoUringError),
#[error("An io-uring error occurred")]
Ring(#[from] IoUringError),
#[error("The peer did not send a file descriptor")]
NoFd,
#[error("The peer sent too many file descriptors")]
TooManyFds,
#[error("The peer closed the connection")]
Closed,
#[error("The connection timed out")]
Timeout,
#[error("Message size is not a multiple of 4")]
UnalignedMessageSize,
#[error("Message size is larger than 4096")]
MessageTooLarge,
#[error("Message size is smaller than 8")]
MessageTooSmall,
}
const BUF_SIZE: usize = 4096;
const MAX_IN_FD: usize = 32;

167
wire-buf/src/parser.rs Normal file
View file

@ -0,0 +1,167 @@
use {
bstr::{BStr, ByteSlice},
jay_units::Fixed,
jay_wire_types::{GlobalName, ObjectId},
std::{collections::VecDeque, ptr, rc::Rc},
thiserror::Error,
uapi::{OwnedFd, Pod},
};
#[derive(Debug, Error)]
pub enum MsgParserError {
#[error("The message ended unexpectedly")]
UnexpectedEof,
#[error("The binary array contains more than the required number of bytes")]
BinaryArrayTooLarge,
#[error("The size of the binary array is not a multiple of the element size")]
BinaryArraySize,
#[error("The message contained a string of size 0")]
EmptyString,
#[error("Message is missing a required file descriptor")]
MissingFd,
#[error("There is trailing data after the message")]
TrailingData,
#[error("String is not UTF-8")]
NonUtf8,
#[error("The message has an unexpected size")]
UnexpectedMessageSize,
}
pub struct MsgParser<'a, 'b> {
fds: &'a mut VecDeque<Rc<OwnedFd>>,
pos: usize,
data: &'b [u32],
}
impl<'a, 'b> MsgParser<'a, 'b> {
pub fn new(fds: &'a mut VecDeque<Rc<OwnedFd>>, data: &'b [u32]) -> Self {
Self { fds, pos: 0, data }
}
#[inline(always)]
pub fn data(&self) -> &[u32] {
self.data
}
pub fn int(&mut self) -> Result<i32, MsgParserError> {
if self.pos >= self.data.len() {
return Err(MsgParserError::UnexpectedEof);
}
let res = unsafe { *(self.data.as_ptr().add(self.pos) as *const i32) };
self.pos += 1;
Ok(res)
}
pub fn uint(&mut self) -> Result<u32, MsgParserError> {
self.int().map(|i| i as u32)
}
pub fn u64(&mut self) -> Result<u64, MsgParserError> {
let hi = self.uint()?;
let lo = self.uint()?;
Ok(((hi as u64) << 32) | lo as u64)
}
pub fn u64_rev(&mut self) -> Result<u64, MsgParserError> {
let lo = self.uint()?;
let hi = self.uint()?;
Ok(((hi as u64) << 32) | lo as u64)
}
pub fn object<T>(&mut self) -> Result<T, MsgParserError>
where
ObjectId: Into<T>,
{
self.int().map(|i| ObjectId::from_raw(i as u32).into())
}
pub fn global(&mut self) -> Result<GlobalName, MsgParserError> {
self.int().map(|i| GlobalName::from_raw(i as u32))
}
pub fn fixed(&mut self) -> Result<Fixed, MsgParserError> {
self.int().map(Fixed)
}
pub fn bstr(&mut self) -> Result<&'b BStr, MsgParserError> {
let s = self.array()?;
if s.len() == 0 {
return Err(MsgParserError::EmptyString);
}
Ok(s[..s.len() - 1].as_bstr())
}
pub fn optstr(&mut self) -> Result<Option<&'b str>, MsgParserError> {
let s = self.array()?;
if s.len() == 0 {
return Ok(None);
}
match s[..s.len() - 1].as_bstr().to_str() {
Ok(s) => Ok(Some(s)),
_ => Err(MsgParserError::NonUtf8),
}
}
pub fn str(&mut self) -> Result<&'b str, MsgParserError> {
match self.bstr()?.to_str() {
Ok(s) => Ok(s),
_ => Err(MsgParserError::NonUtf8),
}
}
pub fn fd(&mut self) -> Result<Rc<OwnedFd>, MsgParserError> {
match self.fds.pop_front() {
Some(fd) => Ok(fd),
_ => Err(MsgParserError::MissingFd),
}
}
pub fn eof(&self) -> Result<(), MsgParserError> {
if self.pos == self.data.len() {
Ok(())
} else {
Err(MsgParserError::TrailingData)
}
}
pub fn array(&mut self) -> Result<&'b [u8], MsgParserError> {
let len = self.uint()? as usize;
let cap = (len + 3) >> 2;
if cap > self.data.len() - self.pos {
return Err(MsgParserError::UnexpectedEof);
}
let pos = self.pos;
self.pos += cap;
Ok(&uapi::as_bytes(&self.data[pos..])[..len])
}
pub fn binary<T: Pod>(&mut self) -> Result<T, MsgParserError> {
let array = self.array()?;
if array.len() < size_of::<T>() {
return Err(MsgParserError::UnexpectedEof);
}
if array.len() > size_of::<T>() {
return Err(MsgParserError::BinaryArrayTooLarge);
}
unsafe { Ok(ptr::read_unaligned(array.as_ptr() as _)) }
}
pub fn binary_array<T: Pod>(&mut self) -> Result<&'b [T], MsgParserError> {
if align_of::<T>() > 4 {
panic!("Alignment of binary array element is too large");
};
if size_of::<T>() == 0 {
panic!("Size of binary array element is 0");
};
let array = self.array()?;
if array.len() % size_of::<T>() != 0 {
return Err(MsgParserError::BinaryArraySize);
}
unsafe {
Ok(std::slice::from_raw_parts(
array.as_ptr() as _,
array.len() / size_of::<T>(),
))
}
}
}

121
wire-buf/src/wl_buf_in.rs Normal file
View file

@ -0,0 +1,121 @@
use {
crate::{BufFdError, MAX_IN_FD},
jay_io_uring::IoUring,
jay_utils::buf::Buf,
jay_wire_types::ObjectId,
std::{collections::VecDeque, ptr, rc::Rc, slice},
uapi::OwnedFd,
};
const WORD_SIZE: usize = 4;
const WORD_ALIGN: usize = 4;
const HEADER_WORDS: usize = 2;
const HEADER_SIZE: usize = HEADER_WORDS * WORD_SIZE;
const MAX_MESSAGE_SIZE: usize = 4096;
const BUF_SIZE: usize = 2 * MAX_MESSAGE_SIZE;
pub struct WlBufFdIn {
fd: Rc<OwnedFd>,
ring: Rc<IoUring>,
fds: VecDeque<Rc<OwnedFd>>,
buf: Buf,
lo: usize,
len: usize,
}
pub struct WlMessage<'a> {
pub obj_id: ObjectId,
pub message: u32,
pub body: &'a [u32],
pub fds: &'a mut VecDeque<Rc<OwnedFd>>,
}
impl WlBufFdIn {
pub fn new(fd: &Rc<OwnedFd>, ring: &Rc<IoUring>) -> Self {
let buf = Buf::new(BUF_SIZE);
assert_eq!(buf.as_ptr() as usize % WORD_ALIGN, 0);
Self {
fd: fd.clone(),
ring: ring.clone(),
fds: Default::default(),
buf,
lo: Default::default(),
len: Default::default(),
}
}
pub async fn read_message(&mut self) -> Result<WlMessage<'_>, BufFdError> {
if self.len == 0 {
self.lo = 0;
}
if self.len < HEADER_SIZE {
if self.lo > 0 {
self.compact();
}
while self.len < HEADER_SIZE {
self.recvmsg().await?;
}
}
let hdr: &[u32] =
unsafe { slice::from_raw_parts(self.buf[self.lo..].as_ptr().cast(), HEADER_WORDS) };
let obj_id = ObjectId::from_raw(hdr[0]);
let len = (hdr[1] >> 16) as usize;
let message = hdr[1] & 0xffff;
if len & 3 != 0 {
return Err(BufFdError::UnalignedMessageSize);
}
if len > MAX_MESSAGE_SIZE {
return Err(BufFdError::MessageTooLarge);
}
if len < HEADER_SIZE {
return Err(BufFdError::MessageTooSmall);
}
if len > self.len {
if self.lo + self.len >= MAX_MESSAGE_SIZE {
self.compact();
}
while len > self.len {
self.recvmsg().await?;
}
}
let body: &[u32] = unsafe {
let words = (len - HEADER_SIZE) >> 2;
slice::from_raw_parts(self.buf[self.lo + HEADER_SIZE..].as_ptr().cast(), words)
};
self.lo += len;
self.len -= len;
Ok(WlMessage {
obj_id,
message,
body,
fds: &mut self.fds,
})
}
#[inline(always)]
fn compact(&mut self) {
unsafe {
let dst = self.buf.as_mut_ptr();
let src = dst.add(self.lo);
ptr::copy(src, dst, self.len);
self.lo = 0;
}
}
async fn recvmsg(&mut self) -> Result<(), BufFdError> {
let mut buf = self.buf.slice(self.lo + self.len..);
match self
.ring
.recvmsg(&self.fd, slice::from_mut(&mut buf), &mut self.fds)
.await
{
Ok(0) => return Err(BufFdError::Closed),
Ok(n) => self.len += n,
Err(e) => return Err(BufFdError::Ring(e)),
}
if self.fds.len() > MAX_IN_FD {
return Err(BufFdError::TooManyFds);
}
Ok(())
}
}