diff --git a/build/wire.rs b/build/wire.rs index 2b952f64..e5ca40db 100644 --- a/build/wire.rs +++ b/build/wire.rs @@ -6,7 +6,7 @@ use { wire::parser::{Field, Lined, Message, Type, parse_messages, to_camel}, }, anyhow::{Context, Result}, - std::{fs::DirEntry, io::Write, os::unix::ffi::OsStrExt}, + std::{fmt, fs::DirEntry, io::Write, os::unix::ffi::OsStrExt}, }; fn write_type(f: &mut W, ty: &Type) -> Result<()> { @@ -110,26 +110,81 @@ fn write_message(f: &mut W, obj: &str, message: &Message) -> Result<() " fn parse({}: &mut MsgParser<'_, 'a>) -> Result {{", parser )?; - writeln!(f, " Ok(Self {{")?; - writeln!(f, " self_id: {}Id::NONE,", obj)?; - for field in &message.fields { - let p = match &field.val.ty.val { - Type::Id(..) => "object", - Type::U32 => "uint", - Type::I32 => "int", - Type::U64 => "u64", - Type::U64Rev => "u64_rev", - Type::OptStr => "optstr", - Type::Str => "str", - Type::Fixed => "fixed", - Type::Fd => "fd", - Type::BStr => "bstr", - Type::Array(_) => "binary_array", - Type::Pod(_) => "binary", - }; - writeln!(f, " {}: parser.{}()?,", field.val.name, p)?; + if message.is_fixed_size { + writeln!(f, " let [")?; + for (i, field) in message.fields.iter().enumerate() { + match &field.val.ty.val { + Type::U64 => { + writeln!(f, " arg{i}_hi,")?; + writeln!(f, " arg{i}_lo,")?; + } + Type::U64Rev => { + writeln!(f, " arg{i}_lo,")?; + writeln!(f, " arg{i}_hi,")?; + } + Type::Fd => {} + _ => { + writeln!(f, " arg{i},")?; + } + } + } + writeln!(f, " ] = *{parser}.data() else {{")?; + writeln!( + f, + " return Err(MsgParserError::UnexpectedMessageSize);" + )?; + writeln!(f, " }};")?; + writeln!(f, " Ok(Self {{")?; + writeln!(f, " self_id: {}Id::NONE,", obj)?; + for (i, field) in message.fields.iter().enumerate() { + writeln!( + f, + " {}: {},", + field.val.name, + fmt::from_fn(|f| { + match &field.val.ty.val { + Type::Id(_, name) => write!(f, "{name}Id(arg{i})"), + Type::U32 => write!(f, "arg{i}"), + Type::I32 => write!(f, "arg{i} as i32"), + Type::U64 | Type::U64Rev => { + write!(f, "((arg{i}_hi as u64) << 32) | (arg{i}_lo as u64)") + } + Type::OptStr => unreachable!(), + Type::Str => unreachable!(), + Type::Fixed => write!(f, "Fixed(arg{i} as i32)"), + Type::Fd => write!(f, "parser.fd()?"), + Type::BStr => unreachable!(), + Type::Array(_) => unreachable!(), + Type::Pod(_) => unreachable!(), + } + }) + )?; + } + writeln!(f, " }})")?; + } else { + writeln!(f, " let res = Ok(Self {{")?; + writeln!(f, " self_id: {}Id::NONE,", obj)?; + for field in &message.fields { + let p = match &field.val.ty.val { + Type::Id(..) => "object", + Type::U32 => "uint", + Type::I32 => "int", + Type::U64 => "u64", + Type::U64Rev => "u64_rev", + Type::OptStr => "optstr", + Type::Str => "str", + Type::Fixed => "fixed", + Type::Fd => "fd", + Type::BStr => "bstr", + Type::Array(_) => "binary_array", + Type::Pod(_) => "binary", + }; + writeln!(f, " {}: parser.{}()?,", field.val.name, p)?; + } + writeln!(f, " }});")?; + writeln!(f, " parser.eof()?;")?; + writeln!(f, " res")?; } - writeln!(f, " }})")?; writeln!(f, " }}")?; writeln!(f, " }}")?; writeln!( @@ -138,35 +193,75 @@ fn write_message(f: &mut W, obj: &str, message: &Message) -> Result<() lifetime, message.camel_name, lifetime )?; writeln!(f, " fn format(self, fmt: &mut MsgFormatter<'_>) {{")?; - writeln!(f, " fmt.header(self.self_id, {});", uppercase)?; - fn write_fmt_expr(f: &mut W, prefix: &str, ty: &Type, access: &str) -> Result<()> { - let p = match ty { - Type::Id(..) => "object", - Type::U32 => "uint", - Type::I32 => "int", - Type::U64 => "u64", - Type::U64Rev => "u64_rev", - Type::OptStr => "optstr", - Type::Str | Type::BStr => "string", - Type::Fixed => "fixed", - Type::Fd => "fd", - Type::Array(..) => "binary", - Type::Pod(..) => "binary", - }; - let rf = match ty { - Type::Pod(..) => "&", - _ => "", - }; - writeln!(f, " {}fmt.{}({}{});", prefix, p, rf, access)?; - Ok(()) - } - for field in &message.fields { - write_fmt_expr( - f, - "", - &field.val.ty.val, - &format!("self.{}", field.val.name), - )?; + if message.is_fixed_size { + writeln!(f, " fmt.data(&[")?; + writeln!(f, " self.self_id.0,")?; + writeln!(f, " {uppercase},")?; + for field in &message.fields { + let prefix = format!(" self.{}", field.val.name); + match &field.val.ty.val { + Type::Id(_, _) => writeln!(f, "{prefix}.0,")?, + Type::U32 => writeln!(f, "{prefix},")?, + Type::I32 => writeln!(f, "{prefix} as u32,")?, + Type::U64 => { + writeln!(f, " (self.{} >> 32) as u32,", field.val.name)?; + writeln!(f, "{prefix} as u32,")?; + } + Type::U64Rev => { + writeln!(f, "{prefix} as u32,")?; + writeln!(f, " (self.{} >> 32) as u32,", field.val.name)?; + } + Type::Str => unreachable!(), + Type::OptStr => unreachable!(), + Type::BStr => unreachable!(), + Type::Fixed => writeln!(f, "{prefix}.0 as u32,")?, + Type::Fd => {} + Type::Array(_) => unreachable!(), + Type::Pod(_) => unreachable!(), + } + } + writeln!(f, " ]);")?; + for field in &message.fields { + if let Type::Fd = &field.val.ty.val { + writeln!(f, " fmt.fd(self.{});", field.val.name)?; + } + } + } else { + writeln!(f, " fmt.header(self.self_id, {});", uppercase)?; + fn write_fmt_expr( + f: &mut W, + prefix: &str, + ty: &Type, + access: &str, + ) -> Result<()> { + let p = match ty { + Type::Id(..) => "object", + Type::U32 => "uint", + Type::I32 => "int", + Type::U64 => "u64", + Type::U64Rev => "u64_rev", + Type::OptStr => "optstr", + Type::Str | Type::BStr => "string", + Type::Fixed => "fixed", + Type::Fd => "fd", + Type::Array(..) => "binary", + Type::Pod(..) => "binary", + }; + let rf = match ty { + Type::Pod(..) => "&", + _ => "", + }; + writeln!(f, " {}fmt.{}({}{});", prefix, p, rf, access)?; + Ok(()) + } + for field in &message.fields { + write_fmt_expr( + f, + "", + &field.val.ty.val, + &format!("self.{}", field.val.name), + )?; + } } writeln!(f, " }}")?; writeln!(f, " fn id(&self) -> ObjectId {{")?; diff --git a/build/wire/parser.rs b/build/wire/parser.rs index e8e8b822..ee09a5dd 100644 --- a/build/wire/parser.rs +++ b/build/wire/parser.rs @@ -235,6 +235,7 @@ pub struct Message { pub fields: Vec>, pub attribs: MessageAttribs, pub has_reference_type: bool, + pub is_fixed_size: bool, } #[derive(Debug, Default)] @@ -344,6 +345,11 @@ impl<'a> Parser<'a> { Type::OptStr | Type::Str | Type::BStr | Type::Array(..) => true, _ => false, }); + let is_variable_size = fields.iter().any(|f| match &f.val.ty.val { + Type::OptStr | Type::Str | Type::BStr | Type::Array(..) | Type::Pod(..) => true, + _ => false, + }); + let is_fixed_size = !is_variable_size; let safe_name = match name { "move" => "move_", "type" => "type_", @@ -361,6 +367,7 @@ impl<'a> Parser<'a> { fields, attribs, has_reference_type, + is_fixed_size, }, }) })(); diff --git a/src/client.rs b/src/client.rs index 11c673a7..8ce8ee5d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -427,7 +427,6 @@ impl Client { mut parser: MsgParser<'_, 'a>, ) -> Result { let res = R::parse(&mut parser)?; - parser.eof()?; log::trace!( "Client {} -> {}@{}.{:?}", self.id, diff --git a/src/client/error.rs b/src/client/error.rs index 5f7402c8..17ccb22b 100644 --- a/src/client/error.rs +++ b/src/client/error.rs @@ -19,10 +19,6 @@ pub enum ClientError { InvalidMethod, #[error("Client tried to access non-existent object {0}")] InvalidObject(ObjectId), - #[error("The message size is < 8")] - MessageSizeTooSmall, - #[error("The size of the message is not a multiple of 4")] - UnalignedMessage, #[error("The requested client {0} does not exist")] ClientDoesNotExist(ClientId), #[error("Server tried to allocate more than 0x1_00_00_00 ids")] diff --git a/src/client/tasks.rs b/src/client/tasks.rs index cd550bf8..fe51a9ea 100644 --- a/src/client/tasks.rs +++ b/src/client/tasks.rs @@ -2,11 +2,9 @@ use { crate::{ async_engine::Phase, client::{Client, ClientError}, - object::ObjectId, utils::{ - buffd::{BufFdIn, BufFdOut, MsgParser}, + buffd::{BufFdOut, MsgParser, WlBufFdIn, WlMessage}, errorfmt::ErrorFmt, - vec_ext::VecExt, }, }, futures_util::{FutureExt, select}, @@ -49,14 +47,14 @@ async fn receive(data: Rc) { }); let display = data.display().unwrap(); let recv = async { - let mut buf = BufFdIn::new(&data.socket, &data.state.ring); - let mut data_buf = Vec::::new(); + let mut buf = WlBufFdIn::new(&data.socket, &data.state.ring); loop { - let mut hdr = [0u32, 0]; - buf.read_full(&mut hdr[..]).await?; - let obj_id = ObjectId::from_raw(hdr[0]); - let len = (hdr[1] >> 16) as usize; - let request = hdr[1] & 0xffff; + let WlMessage { + obj_id, + message, + body, + fds, + } = buf.read_message().await?; let obj = match data.objects.get_obj(obj_id) { Ok(obj) => obj, _ => { @@ -65,28 +63,12 @@ async fn receive(data: Rc) { return Err(ClientError::InvalidObject(obj_id)); } }; - // log::trace!("obj: {}, request: {}, len: {}", obj_id, request, len); - if len < 8 { - return Err(ClientError::MessageSizeTooSmall); - } - if len % 4 != 0 { - return Err(ClientError::UnalignedMessage); - } - let len = len / 4 - 2; - data_buf.clear(); - data_buf.reserve(len); - let unused = data_buf.split_at_spare_mut_ext().1; - buf.read_full(&mut unused[..len]).await?; - unsafe { - data_buf.set_len(len); - } - // log::trace!("{:x?}", data_buf); - let parser = MsgParser::new(&mut buf, &data_buf[..]); - if let Err(e) = obj.handle_request(&data, request, parser) { + let parser = MsgParser::new(fds, body); + if let Err(e) = obj.handle_request(&data, message, parser) { if let ClientError::InvalidMethod = e && let Ok(obj) = data.objects.get_obj(obj_id) { - data.invalid_request(&*obj, request); + data.invalid_request(&*obj, message); return Err(e); } return Err(ClientError::RequestError(Box::new(e))); diff --git a/src/it/test_transport.rs b/src/it/test_transport.rs index 9637d9f8..549613bf 100644 --- a/src/it/test_transport.rs +++ b/src/it/test_transport.rs @@ -13,7 +13,10 @@ use { utils::{ asyncevent::AsyncEvent, bitfield::Bitfield, - buffd::{BufFdIn, BufFdOut, MsgFormatter, MsgParser, OutBuffer, OutBufferSwapchain}, + buffd::{ + BufFdError, BufFdIn, BufFdOut, MsgFormatter, MsgParser, OutBuffer, + OutBufferSwapchain, WlBufFdIn, WlMessage, + }, copyhashmap::CopyHashMap, hash_map_ext::HashMapExt, stack::Stack, @@ -36,7 +39,6 @@ pub struct TestTransport { pub run: Rc, pub socket: Rc, pub client_id: Cell, - pub bufs: Stack>, pub swapchain: Rc>, pub flush_request: AsyncEvent, pub incoming: Cell>>, @@ -153,7 +155,7 @@ impl TestTransport { "", Incoming { tc: self.clone(), - buf: BufFdIn::new(&self.socket, &self.run.state.ring), + buf: WlBufFdIn::new(&self.socket, &self.run.state.ring), } .run(), ), @@ -246,7 +248,7 @@ impl Outgoing { struct Incoming { tc: Rc, - buf: BufFdIn, + buf: WlBufFdIn, } impl Incoming { @@ -267,30 +269,15 @@ impl Incoming { } async fn handle_msg(&mut self) -> Result<(), TestError> { - let mut hdr = [0u32, 0]; - if let Err(e) = self.buf.read_full(&mut hdr[..]).await { - return Err(e.with_context("Could not read from wayland socket")); - } - let obj_id = ObjectId::from_raw(hdr[0]); - let len = (hdr[1] >> 16) as usize; - let request = hdr[1] & 0xffff; - if len < 8 { - bail!("Message size is < 8"); - } - if len % 4 != 0 { - bail!("Message size is not a multiple of 4"); - } - let len = len / 4 - 2; - let mut data_buf = self.tc.bufs.pop().unwrap_or_default(); - data_buf.clear(); - data_buf.reserve(len); - let unused = data_buf.split_at_spare_mut_ext().1; - if let Err(e) = self.buf.read_full(&mut unused[..len]).await { - return Err(e.with_context("Could not read from wayland socket")); - } - unsafe { - data_buf.set_len(len); - } + let WlMessage { + obj_id, + message, + body, + fds, + } = match self.buf.read_message().await { + Ok(m) => m, + Err(e) => return Err(e.with_context("Could not read from wayland socket")), + }; let object = match self.tc.objects.get(&obj_id) { Some(obj) => obj, _ => bail!( @@ -298,11 +285,8 @@ impl Incoming { obj_id ), }; - let parser = MsgParser::new(&mut self.buf, &data_buf); - object.handle_request(request, parser)?; - if data_buf.capacity() > 0 { - self.tc.bufs.push(data_buf); - } + let parser = MsgParser::new(fds, body); + object.handle_request(message, parser)?; Ok(()) } } diff --git a/src/it/testrun.rs b/src/it/testrun.rs index ac14a734..0045a7bf 100644 --- a/src/it/testrun.rs +++ b/src/it/testrun.rs @@ -58,7 +58,6 @@ impl TestRun { run: self.clone(), socket, client_id: Cell::new(ClientId::from_raw(0)), - bufs: Default::default(), swapchain: Default::default(), flush_request: Default::default(), incoming: Default::default(), @@ -146,9 +145,7 @@ pub trait ParseFull<'a>: Sized { impl<'a, T: RequestParser<'a>> ParseFull<'a> for T { fn parse_full(mut parser: MsgParser<'_, 'a>) -> Result { - let res = T::parse(&mut parser)?; - parser.eof()?; - Ok(res) + T::parse(&mut parser).map_err(Into::into) } } diff --git a/src/tools/tool_client.rs b/src/tools/tool_client.rs index 38fe224d..e4215feb 100644 --- a/src/tools/tool_client.rs +++ b/src/tools/tool_client.rs @@ -10,15 +10,13 @@ use { asyncevent::AsyncEvent, bitfield::Bitfield, buffd::{ - BufFdError, BufFdIn, BufFdOut, MsgFormatter, MsgParser, MsgParserError, OutBuffer, - OutBufferSwapchain, + BufFdError, BufFdOut, MsgFormatter, MsgParser, MsgParserError, OutBuffer, + OutBufferSwapchain, WlBufFdIn, WlMessage, }, clonecell::CloneCell, errorfmt::ErrorFmt, numcell::NumCell, oserror::OsError, - stack::Stack, - vec_ext::VecExt, xrd::xrd, }, wheel::{Wheel, WheelError}, @@ -59,10 +57,6 @@ pub enum ToolClientError { SocketPathTooLong, #[error("Could not connect to the compositor")] Connect(#[source] IoUringError), - #[error("The message length is smaller than 8 bytes")] - MsgLenTooSmall, - #[error("The size of the message is not a multiple of 4")] - UnalignedMessage, #[error(transparent)] BufFdError(#[from] BufFdError), #[error("Could not parse a message of type {}", .0)] @@ -85,7 +79,6 @@ pub struct ToolClient { AHashMap Result<(), ToolClientError>>>, >, >, - bufs: Stack>, swapchain: Rc>, flush_request: AsyncEvent, pending_futures: RefCell>>, @@ -186,7 +179,6 @@ impl ToolClient { eng, obj_ids: RefCell::new(obj_ids), handlers: Default::default(), - bufs: Default::default(), swapchain: Default::default(), flush_request: Default::default(), pending_futures: Default::default(), @@ -209,7 +201,7 @@ impl ToolClient { "tool client incoming", Incoming { tc: slf.clone(), - buf: BufFdIn::new(&socket, &slf.ring), + buf: WlBufFdIn::new(&socket, &slf.ring), } .run(), ), @@ -528,7 +520,7 @@ impl Outgoing { struct Incoming { tc: Rc, - buf: BufFdIn, + buf: WlBufFdIn, } impl Incoming { @@ -541,44 +533,27 @@ impl Incoming { } async fn handle_msg(&mut self) -> Result<(), ToolClientError> { - let mut hdr = [0u32, 0]; - if let Err(e) = self.buf.read_full(&mut hdr[..]).await { - return Err(ToolClientError::Read(e)); - } - let obj_id = ObjectId::from_raw(hdr[0]); - let len = (hdr[1] >> 16) as usize; - let request = hdr[1] & 0xffff; - if len < 8 { - return Err(ToolClientError::MsgLenTooSmall); - } - if len % 4 != 0 { - return Err(ToolClientError::UnalignedMessage); - } - let len = len / 4 - 2; - let mut data_buf = self.tc.bufs.pop().unwrap_or_default(); - data_buf.clear(); - data_buf.reserve(len); - let unused = data_buf.split_at_spare_mut_ext().1; - if let Err(e) = self.buf.read_full(&mut unused[..len]).await { - return Err(ToolClientError::Read(e)); - } - unsafe { - data_buf.set_len(len); - } + let WlMessage { + obj_id, + message, + body, + fds, + } = self + .buf + .read_message() + .await + .map_err(ToolClientError::Read)?; let mut handler = None; { let handlers = self.tc.handlers.borrow_mut(); if let Some(handlers) = handlers.get(&obj_id) { - handler = handlers.get(&request).cloned(); + handler = handlers.get(&message).cloned(); } } if let Some(handler) = handler { - let mut parser = MsgParser::new(&mut self.buf, &data_buf); + let mut parser = MsgParser::new(fds, body); handler(&mut parser)?; } - if data_buf.capacity() > 0 { - self.tc.bufs.push(data_buf); - } Ok(()) } } diff --git a/src/utils/buffd.rs b/src/utils/buffd.rs index cf28640b..bbb8a67e 100644 --- a/src/utils/buffd.rs +++ b/src/utils/buffd.rs @@ -6,6 +6,7 @@ pub use { ei_parser::{EiMsgParser, EiMsgParserError}, formatter::MsgFormatter, parser::{MsgParser, MsgParserError}, + wl_buf_in::{WlBufFdIn, WlMessage}, }; mod buf_in; @@ -14,6 +15,7 @@ mod ei_formatter; mod ei_parser; mod formatter; mod parser; +mod wl_buf_in; #[derive(Debug, Error)] pub enum BufFdError { @@ -29,6 +31,12 @@ pub enum BufFdError { Closed, #[error("The connection timed out")] Timeout, + #[error("Message size is not a multiple of 4")] + UnalignedMessageSize, + #[error("Message size is larger than 4096")] + MessageTooLarge, + #[error("Message size is smaller than 8")] + MessageTooSmall, } const BUF_SIZE: usize = 4096; diff --git a/src/utils/buffd/formatter.rs b/src/utils/buffd/formatter.rs index 9e56dc06..2ad83618 100644 --- a/src/utils/buffd/formatter.rs +++ b/src/utils/buffd/formatter.rs @@ -33,6 +33,11 @@ impl<'a> MsgFormatter<'a> { self.meta.write_pos += bytes.len(); } + #[inline(always)] + pub fn data(&mut self, data: &[u32]) { + self.write(uapi::as_bytes(data)); + } + pub fn int(&mut self, int: i32) -> &mut Self { self.write(uapi::as_bytes(&int)); self @@ -43,11 +48,13 @@ impl<'a> MsgFormatter<'a> { self } + #[expect(dead_code)] pub fn u64(&mut self, int: u64) -> &mut Self { self.uint((int >> 32) as u32); self.uint(int as u32) } + #[expect(dead_code)] pub fn u64_rev(&mut self, int: u64) -> &mut Self { self.uint(int as u32); self.uint((int >> 32) as u32) diff --git a/src/utils/buffd/parser.rs b/src/utils/buffd/parser.rs index a8e92c67..7ae74118 100644 --- a/src/utils/buffd/parser.rs +++ b/src/utils/buffd/parser.rs @@ -1,7 +1,7 @@ use { - crate::{fixed::Fixed, globals::GlobalName, object::ObjectId, utils::buffd::BufFdIn}, + crate::{fixed::Fixed, globals::GlobalName, object::ObjectId}, bstr::{BStr, ByteSlice}, - std::{ptr, rc::Rc}, + std::{collections::VecDeque, ptr, rc::Rc}, thiserror::Error, uapi::{OwnedFd, Pod}, }; @@ -22,29 +22,32 @@ pub enum MsgParserError { TrailingData, #[error("String is not UTF-8")] NonUtf8, + #[error("The message has an unexpected size")] + UnexpectedMessageSize, } pub struct MsgParser<'a, 'b> { - buf: &'a mut BufFdIn, + fds: &'a mut VecDeque>, pos: usize, - data: &'b [u8], + data: &'b [u32], } impl<'a, 'b> MsgParser<'a, 'b> { - pub fn new(buf: &'a mut BufFdIn, data: &'b [u32]) -> Self { - Self { - buf, - pos: 0, - data: uapi::as_bytes(data), - } + pub fn new(fds: &'a mut VecDeque>, data: &'b [u32]) -> Self { + Self { fds, pos: 0, data } + } + + #[inline(always)] + pub fn data(&self) -> &[u32] { + self.data } pub fn int(&mut self) -> Result { - if self.data.len() - self.pos < 4 { + if self.pos >= self.data.len() { return Err(MsgParserError::UnexpectedEof); } let res = unsafe { *(self.data.as_ptr().add(self.pos) as *const i32) }; - self.pos += 4; + self.pos += 1; Ok(res) } @@ -52,12 +55,14 @@ impl<'a, 'b> MsgParser<'a, 'b> { self.int().map(|i| i as u32) } + #[expect(dead_code)] pub fn u64(&mut self) -> Result { let hi = self.uint()?; let lo = self.uint()?; Ok(((hi as u64) << 32) | lo as u64) } + #[expect(dead_code)] pub fn u64_rev(&mut self) -> Result { let lo = self.uint()?; let hi = self.uint()?; @@ -107,8 +112,8 @@ impl<'a, 'b> MsgParser<'a, 'b> { } pub fn fd(&mut self) -> Result, MsgParserError> { - match self.buf.get_fd() { - Ok(fd) => Ok(fd), + match self.fds.pop_front() { + Some(fd) => Ok(fd), _ => Err(MsgParserError::MissingFd), } } @@ -123,13 +128,13 @@ impl<'a, 'b> MsgParser<'a, 'b> { pub fn array(&mut self) -> Result<&'b [u8], MsgParserError> { let len = self.uint()? as usize; - let cap = (len + 3) & !3; + let cap = (len + 3) >> 2; if cap > self.data.len() - self.pos { return Err(MsgParserError::UnexpectedEof); } let pos = self.pos; self.pos += cap; - Ok(&self.data[pos..pos + len]) + Ok(&uapi::as_bytes(&self.data[pos..])[..len]) } pub fn binary(&mut self) -> Result { diff --git a/src/utils/buffd/wl_buf_in.rs b/src/utils/buffd/wl_buf_in.rs new file mode 100644 index 00000000..8bca5956 --- /dev/null +++ b/src/utils/buffd/wl_buf_in.rs @@ -0,0 +1,125 @@ +use { + crate::{ + io_uring::IoUring, + object::ObjectId, + utils::{ + buf::Buf, + buffd::{BufFdError, MAX_IN_FD}, + }, + }, + std::{collections::VecDeque, ptr, rc::Rc, slice}, + uapi::OwnedFd, +}; + +const WORD_SIZE: usize = 4; +const WORD_ALIGN: usize = 4; +const HEADER_WORDS: usize = 2; +const HEADER_SIZE: usize = HEADER_WORDS * WORD_SIZE; +const MAX_MESSAGE_SIZE: usize = 4096; +const BUF_SIZE: usize = 2 * MAX_MESSAGE_SIZE; + +pub struct WlBufFdIn { + fd: Rc, + ring: Rc, + fds: VecDeque>, + buf: Buf, + lo: usize, + len: usize, +} + +pub struct WlMessage<'a> { + pub obj_id: ObjectId, + pub message: u32, + pub body: &'a [u32], + pub fds: &'a mut VecDeque>, +} + +impl WlBufFdIn { + pub fn new(fd: &Rc, ring: &Rc) -> Self { + let buf = Buf::new(BUF_SIZE); + assert_eq!(buf.as_ptr() as usize % WORD_ALIGN, 0); + Self { + fd: fd.clone(), + ring: ring.clone(), + fds: Default::default(), + buf, + lo: Default::default(), + len: Default::default(), + } + } + + pub async fn read_message(&mut self) -> Result, BufFdError> { + if self.len == 0 { + self.lo = 0; + } + if self.len < HEADER_SIZE { + if self.lo > 0 { + self.compact(); + } + while self.len < HEADER_SIZE { + self.recvmsg().await?; + } + } + let hdr: &[u32] = + unsafe { slice::from_raw_parts(self.buf[self.lo..].as_ptr().cast(), HEADER_WORDS) }; + let obj_id = ObjectId::from_raw(hdr[0]); + let len = (hdr[1] >> 16) as usize; + let message = hdr[1] & 0xffff; + if len & 3 != 0 { + return Err(BufFdError::UnalignedMessageSize); + } + if len > MAX_MESSAGE_SIZE { + return Err(BufFdError::MessageTooLarge); + } + if len < HEADER_SIZE { + return Err(BufFdError::MessageTooSmall); + } + if len > self.len { + if self.lo + self.len >= MAX_MESSAGE_SIZE { + self.compact(); + } + while len > self.len { + self.recvmsg().await?; + } + } + let body: &[u32] = unsafe { + let words = (len - HEADER_SIZE) >> 2; + slice::from_raw_parts(self.buf[self.lo + HEADER_SIZE..].as_ptr().cast(), words) + }; + self.lo += len; + self.len -= len; + Ok(WlMessage { + obj_id, + message, + body, + fds: &mut self.fds, + }) + } + + #[inline(always)] + fn compact(&mut self) { + unsafe { + let dst = self.buf.as_mut_ptr(); + let src = dst.add(self.lo); + ptr::copy(src, dst, self.len); + self.lo = 0; + } + } + + async fn recvmsg(&mut self) -> Result<(), BufFdError> { + let mut buf = self.buf.slice(self.lo + self.len..); + match self + .ring + .recvmsg(&self.fd, slice::from_mut(&mut buf), &mut self.fds) + .await + { + Ok(0) => return Err(BufFdError::Closed), + Ok(n) => self.len += n, + Err(e) => return Err(BufFdError::Ring(e)), + } + if self.fds.len() > MAX_IN_FD { + return Err(BufFdError::TooManyFds); + } + Ok(()) + } +} diff --git a/src/wl_usr.rs b/src/wl_usr.rs index dcaa79bf..cbb6dee9 100644 --- a/src/wl_usr.rs +++ b/src/wl_usr.rs @@ -11,15 +11,14 @@ use { asyncevent::AsyncEvent, bitfield::Bitfield, buffd::{ - BufFdError, BufFdIn, BufFdOut, MsgFormatter, MsgParser, MsgParserError, OutBuffer, - OutBufferSwapchain, + BufFdError, BufFdOut, MsgFormatter, MsgParser, MsgParserError, OutBuffer, + OutBufferSwapchain, WlBufFdIn, WlMessage, }, clonecell::CloneCell, copyhashmap::CopyHashMap, errorfmt::ErrorFmt, hash_map_ext::HashMapExt, oserror::OsError, - vec_ext::VecExt, }, video::dmabuf::DmaBufIds, wheel::Wheel, @@ -51,10 +50,6 @@ pub enum UsrConError { SocketPathTooLong, #[error("Could not connect to the compositor")] Connect(#[source] IoUringError), - #[error("The message length is smaller than 8 bytes")] - MsgLenTooSmall, - #[error("The size of the message is not a multiple of 4")] - UnalignedMessage, #[error(transparent)] BufFdError(#[from] BufFdError), #[error("Could not read from the compositor")] @@ -168,8 +163,7 @@ impl UsrCon { "wl_usr incoming", Incoming { con: slf.clone(), - buf: BufFdIn::new(socket, &slf.ring), - data: vec![], + buf: WlBufFdIn::new(socket, &slf.ring), } .run(), ), @@ -257,7 +251,6 @@ impl UsrCon { mut parser: MsgParser<'_, 'a>, ) -> Result { let res = R::parse(&mut parser)?; - parser.eof()?; log::trace!( "Server {} -> {}@{}.{:?}", self.server_id, @@ -338,8 +331,7 @@ impl Outgoing { struct Incoming { con: Rc, - buf: BufFdIn, - data: Vec, + buf: WlBufFdIn, } impl Incoming { @@ -358,33 +350,16 @@ impl Incoming { } async fn handle_msg(&mut self) -> Result<(), UsrConError> { - let mut hdr = [0u32, 0]; - if let Err(e) = self.buf.read_full(&mut hdr[..]).await { - return Err(UsrConError::Read(e)); - } - let obj_id = ObjectId::from_raw(hdr[0]); - let len = (hdr[1] >> 16) as usize; - let event = hdr[1] & 0xffff; - if len < 8 { - return Err(UsrConError::MsgLenTooSmall); - } - if len % 4 != 0 { - return Err(UsrConError::UnalignedMessage); - } - let len = len / 4 - 2; - self.data.clear(); - self.data.reserve(len); - let unused = self.data.split_at_spare_mut_ext().1; - if let Err(e) = self.buf.read_full(&mut unused[..len]).await { - return Err(UsrConError::Read(e)); - } - unsafe { - self.data.set_len(len); - } + let WlMessage { + obj_id, + message, + body, + fds, + } = self.buf.read_message().await.map_err(UsrConError::Read)?; if let Some(obj) = self.con.objects.get(&obj_id) { if let Some(obj) = obj { - let parser = MsgParser::new(&mut self.buf, &self.data); - obj.handle_event(&self.con, event, parser)?; + let parser = MsgParser::new(fds, body); + obj.handle_event(&self.con, message, parser)?; } } else if obj_id.raw() < MIN_SERVER_ID { return Err(UsrConError::MissingObject(obj_id));