1
0
Fork 0
forked from wry/wry

Merge pull request #837 from mahkoh/jorth/optimize-wire

wayland: optimize parsing and formatting messages
This commit is contained in:
mahkoh 2026-03-29 13:57:04 +02:00 committed by GitHub
commit 59aedd2c27
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 369 additions and 214 deletions

View file

@ -6,7 +6,7 @@ use {
wire::parser::{Field, Lined, Message, Type, parse_messages, to_camel}, wire::parser::{Field, Lined, Message, Type, parse_messages, to_camel},
}, },
anyhow::{Context, Result}, anyhow::{Context, Result},
std::{fs::DirEntry, io::Write, os::unix::ffi::OsStrExt}, std::{fmt, fs::DirEntry, io::Write, os::unix::ffi::OsStrExt},
}; };
fn write_type<W: Write>(f: &mut W, ty: &Type) -> Result<()> { fn write_type<W: Write>(f: &mut W, ty: &Type) -> Result<()> {
@ -110,26 +110,81 @@ fn write_message<W: Write>(f: &mut W, obj: &str, message: &Message) -> Result<()
" fn parse({}: &mut MsgParser<'_, 'a>) -> Result<Self, MsgParserError> {{", " fn parse({}: &mut MsgParser<'_, 'a>) -> Result<Self, MsgParserError> {{",
parser parser
)?; )?;
writeln!(f, " Ok(Self {{")?; if message.is_fixed_size {
writeln!(f, " self_id: {}Id::NONE,", obj)?; writeln!(f, " let [")?;
for field in &message.fields { for (i, field) in message.fields.iter().enumerate() {
let p = match &field.val.ty.val { match &field.val.ty.val {
Type::Id(..) => "object", Type::U64 => {
Type::U32 => "uint", writeln!(f, " arg{i}_hi,")?;
Type::I32 => "int", writeln!(f, " arg{i}_lo,")?;
Type::U64 => "u64", }
Type::U64Rev => "u64_rev", Type::U64Rev => {
Type::OptStr => "optstr", writeln!(f, " arg{i}_lo,")?;
Type::Str => "str", writeln!(f, " arg{i}_hi,")?;
Type::Fixed => "fixed", }
Type::Fd => "fd", Type::Fd => {}
Type::BStr => "bstr", _ => {
Type::Array(_) => "binary_array", writeln!(f, " arg{i},")?;
Type::Pod(_) => "binary", }
}; }
writeln!(f, " {}: parser.{}()?,", field.val.name, p)?; }
writeln!(f, " ] = *{parser}.data() else {{")?;
writeln!(
f,
" return Err(MsgParserError::UnexpectedMessageSize);"
)?;
writeln!(f, " }};")?;
writeln!(f, " Ok(Self {{")?;
writeln!(f, " self_id: {}Id::NONE,", obj)?;
for (i, field) in message.fields.iter().enumerate() {
writeln!(
f,
" {}: {},",
field.val.name,
fmt::from_fn(|f| {
match &field.val.ty.val {
Type::Id(_, name) => write!(f, "{name}Id(arg{i})"),
Type::U32 => write!(f, "arg{i}"),
Type::I32 => write!(f, "arg{i} as i32"),
Type::U64 | Type::U64Rev => {
write!(f, "((arg{i}_hi as u64) << 32) | (arg{i}_lo as u64)")
}
Type::OptStr => unreachable!(),
Type::Str => unreachable!(),
Type::Fixed => write!(f, "Fixed(arg{i} as i32)"),
Type::Fd => write!(f, "parser.fd()?"),
Type::BStr => unreachable!(),
Type::Array(_) => unreachable!(),
Type::Pod(_) => unreachable!(),
}
})
)?;
}
writeln!(f, " }})")?;
} else {
writeln!(f, " let res = Ok(Self {{")?;
writeln!(f, " self_id: {}Id::NONE,", obj)?;
for field in &message.fields {
let p = match &field.val.ty.val {
Type::Id(..) => "object",
Type::U32 => "uint",
Type::I32 => "int",
Type::U64 => "u64",
Type::U64Rev => "u64_rev",
Type::OptStr => "optstr",
Type::Str => "str",
Type::Fixed => "fixed",
Type::Fd => "fd",
Type::BStr => "bstr",
Type::Array(_) => "binary_array",
Type::Pod(_) => "binary",
};
writeln!(f, " {}: parser.{}()?,", field.val.name, p)?;
}
writeln!(f, " }});")?;
writeln!(f, " parser.eof()?;")?;
writeln!(f, " res")?;
} }
writeln!(f, " }})")?;
writeln!(f, " }}")?; writeln!(f, " }}")?;
writeln!(f, " }}")?; writeln!(f, " }}")?;
writeln!( writeln!(
@ -138,35 +193,75 @@ fn write_message<W: Write>(f: &mut W, obj: &str, message: &Message) -> Result<()
lifetime, message.camel_name, lifetime lifetime, message.camel_name, lifetime
)?; )?;
writeln!(f, " fn format(self, fmt: &mut MsgFormatter<'_>) {{")?; writeln!(f, " fn format(self, fmt: &mut MsgFormatter<'_>) {{")?;
writeln!(f, " fmt.header(self.self_id, {});", uppercase)?; if message.is_fixed_size {
fn write_fmt_expr<W: Write>(f: &mut W, prefix: &str, ty: &Type, access: &str) -> Result<()> { writeln!(f, " fmt.data(&[")?;
let p = match ty { writeln!(f, " self.self_id.0,")?;
Type::Id(..) => "object", writeln!(f, " {uppercase},")?;
Type::U32 => "uint", for field in &message.fields {
Type::I32 => "int", let prefix = format!(" self.{}", field.val.name);
Type::U64 => "u64", match &field.val.ty.val {
Type::U64Rev => "u64_rev", Type::Id(_, _) => writeln!(f, "{prefix}.0,")?,
Type::OptStr => "optstr", Type::U32 => writeln!(f, "{prefix},")?,
Type::Str | Type::BStr => "string", Type::I32 => writeln!(f, "{prefix} as u32,")?,
Type::Fixed => "fixed", Type::U64 => {
Type::Fd => "fd", writeln!(f, " (self.{} >> 32) as u32,", field.val.name)?;
Type::Array(..) => "binary", writeln!(f, "{prefix} as u32,")?;
Type::Pod(..) => "binary", }
}; Type::U64Rev => {
let rf = match ty { writeln!(f, "{prefix} as u32,")?;
Type::Pod(..) => "&", writeln!(f, " (self.{} >> 32) as u32,", field.val.name)?;
_ => "", }
}; Type::Str => unreachable!(),
writeln!(f, " {}fmt.{}({}{});", prefix, p, rf, access)?; Type::OptStr => unreachable!(),
Ok(()) Type::BStr => unreachable!(),
} Type::Fixed => writeln!(f, "{prefix}.0 as u32,")?,
for field in &message.fields { Type::Fd => {}
write_fmt_expr( Type::Array(_) => unreachable!(),
f, Type::Pod(_) => unreachable!(),
"", }
&field.val.ty.val, }
&format!("self.{}", field.val.name), writeln!(f, " ]);")?;
)?; for field in &message.fields {
if let Type::Fd = &field.val.ty.val {
writeln!(f, " fmt.fd(self.{});", field.val.name)?;
}
}
} else {
writeln!(f, " fmt.header(self.self_id, {});", uppercase)?;
fn write_fmt_expr<W: Write>(
f: &mut W,
prefix: &str,
ty: &Type,
access: &str,
) -> Result<()> {
let p = match ty {
Type::Id(..) => "object",
Type::U32 => "uint",
Type::I32 => "int",
Type::U64 => "u64",
Type::U64Rev => "u64_rev",
Type::OptStr => "optstr",
Type::Str | Type::BStr => "string",
Type::Fixed => "fixed",
Type::Fd => "fd",
Type::Array(..) => "binary",
Type::Pod(..) => "binary",
};
let rf = match ty {
Type::Pod(..) => "&",
_ => "",
};
writeln!(f, " {}fmt.{}({}{});", prefix, p, rf, access)?;
Ok(())
}
for field in &message.fields {
write_fmt_expr(
f,
"",
&field.val.ty.val,
&format!("self.{}", field.val.name),
)?;
}
} }
writeln!(f, " }}")?; writeln!(f, " }}")?;
writeln!(f, " fn id(&self) -> ObjectId {{")?; writeln!(f, " fn id(&self) -> ObjectId {{")?;

View file

@ -235,6 +235,7 @@ pub struct Message {
pub fields: Vec<Lined<Field>>, pub fields: Vec<Lined<Field>>,
pub attribs: MessageAttribs, pub attribs: MessageAttribs,
pub has_reference_type: bool, pub has_reference_type: bool,
pub is_fixed_size: bool,
} }
#[derive(Debug, Default)] #[derive(Debug, Default)]
@ -344,6 +345,11 @@ impl<'a> Parser<'a> {
Type::OptStr | Type::Str | Type::BStr | Type::Array(..) => true, Type::OptStr | Type::Str | Type::BStr | Type::Array(..) => true,
_ => false, _ => false,
}); });
let is_variable_size = fields.iter().any(|f| match &f.val.ty.val {
Type::OptStr | Type::Str | Type::BStr | Type::Array(..) | Type::Pod(..) => true,
_ => false,
});
let is_fixed_size = !is_variable_size;
let safe_name = match name { let safe_name = match name {
"move" => "move_", "move" => "move_",
"type" => "type_", "type" => "type_",
@ -361,6 +367,7 @@ impl<'a> Parser<'a> {
fields, fields,
attribs, attribs,
has_reference_type, has_reference_type,
is_fixed_size,
}, },
}) })
})(); })();

View file

@ -427,7 +427,6 @@ impl Client {
mut parser: MsgParser<'_, 'a>, mut parser: MsgParser<'_, 'a>,
) -> Result<R, MsgParserError> { ) -> Result<R, MsgParserError> {
let res = R::parse(&mut parser)?; let res = R::parse(&mut parser)?;
parser.eof()?;
log::trace!( log::trace!(
"Client {} -> {}@{}.{:?}", "Client {} -> {}@{}.{:?}",
self.id, self.id,

View file

@ -19,10 +19,6 @@ pub enum ClientError {
InvalidMethod, InvalidMethod,
#[error("Client tried to access non-existent object {0}")] #[error("Client tried to access non-existent object {0}")]
InvalidObject(ObjectId), InvalidObject(ObjectId),
#[error("The message size is < 8")]
MessageSizeTooSmall,
#[error("The size of the message is not a multiple of 4")]
UnalignedMessage,
#[error("The requested client {0} does not exist")] #[error("The requested client {0} does not exist")]
ClientDoesNotExist(ClientId), ClientDoesNotExist(ClientId),
#[error("Server tried to allocate more than 0x1_00_00_00 ids")] #[error("Server tried to allocate more than 0x1_00_00_00 ids")]

View file

@ -2,11 +2,9 @@ use {
crate::{ crate::{
async_engine::Phase, async_engine::Phase,
client::{Client, ClientError}, client::{Client, ClientError},
object::ObjectId,
utils::{ utils::{
buffd::{BufFdIn, BufFdOut, MsgParser}, buffd::{BufFdOut, MsgParser, WlBufFdIn, WlMessage},
errorfmt::ErrorFmt, errorfmt::ErrorFmt,
vec_ext::VecExt,
}, },
}, },
futures_util::{FutureExt, select}, futures_util::{FutureExt, select},
@ -49,14 +47,14 @@ async fn receive(data: Rc<Client>) {
}); });
let display = data.display().unwrap(); let display = data.display().unwrap();
let recv = async { let recv = async {
let mut buf = BufFdIn::new(&data.socket, &data.state.ring); let mut buf = WlBufFdIn::new(&data.socket, &data.state.ring);
let mut data_buf = Vec::<u32>::new();
loop { loop {
let mut hdr = [0u32, 0]; let WlMessage {
buf.read_full(&mut hdr[..]).await?; obj_id,
let obj_id = ObjectId::from_raw(hdr[0]); message,
let len = (hdr[1] >> 16) as usize; body,
let request = hdr[1] & 0xffff; fds,
} = buf.read_message().await?;
let obj = match data.objects.get_obj(obj_id) { let obj = match data.objects.get_obj(obj_id) {
Ok(obj) => obj, Ok(obj) => obj,
_ => { _ => {
@ -65,28 +63,12 @@ async fn receive(data: Rc<Client>) {
return Err(ClientError::InvalidObject(obj_id)); return Err(ClientError::InvalidObject(obj_id));
} }
}; };
// log::trace!("obj: {}, request: {}, len: {}", obj_id, request, len); let parser = MsgParser::new(fds, body);
if len < 8 { if let Err(e) = obj.handle_request(&data, message, parser) {
return Err(ClientError::MessageSizeTooSmall);
}
if len % 4 != 0 {
return Err(ClientError::UnalignedMessage);
}
let len = len / 4 - 2;
data_buf.clear();
data_buf.reserve(len);
let unused = data_buf.split_at_spare_mut_ext().1;
buf.read_full(&mut unused[..len]).await?;
unsafe {
data_buf.set_len(len);
}
// log::trace!("{:x?}", data_buf);
let parser = MsgParser::new(&mut buf, &data_buf[..]);
if let Err(e) = obj.handle_request(&data, request, parser) {
if let ClientError::InvalidMethod = e if let ClientError::InvalidMethod = e
&& let Ok(obj) = data.objects.get_obj(obj_id) && let Ok(obj) = data.objects.get_obj(obj_id)
{ {
data.invalid_request(&*obj, request); data.invalid_request(&*obj, message);
return Err(e); return Err(e);
} }
return Err(ClientError::RequestError(Box::new(e))); return Err(ClientError::RequestError(Box::new(e)));

View file

@ -13,7 +13,10 @@ use {
utils::{ utils::{
asyncevent::AsyncEvent, asyncevent::AsyncEvent,
bitfield::Bitfield, bitfield::Bitfield,
buffd::{BufFdIn, BufFdOut, MsgFormatter, MsgParser, OutBuffer, OutBufferSwapchain}, buffd::{
BufFdError, BufFdIn, BufFdOut, MsgFormatter, MsgParser, OutBuffer,
OutBufferSwapchain, WlBufFdIn, WlMessage,
},
copyhashmap::CopyHashMap, copyhashmap::CopyHashMap,
hash_map_ext::HashMapExt, hash_map_ext::HashMapExt,
stack::Stack, stack::Stack,
@ -36,7 +39,6 @@ pub struct TestTransport {
pub run: Rc<TestRun>, pub run: Rc<TestRun>,
pub socket: Rc<OwnedFd>, pub socket: Rc<OwnedFd>,
pub client_id: Cell<ClientId>, pub client_id: Cell<ClientId>,
pub bufs: Stack<Vec<u32>>,
pub swapchain: Rc<RefCell<OutBufferSwapchain>>, pub swapchain: Rc<RefCell<OutBufferSwapchain>>,
pub flush_request: AsyncEvent, pub flush_request: AsyncEvent,
pub incoming: Cell<Option<SpawnedFuture<()>>>, pub incoming: Cell<Option<SpawnedFuture<()>>>,
@ -153,7 +155,7 @@ impl TestTransport {
"", "",
Incoming { Incoming {
tc: self.clone(), tc: self.clone(),
buf: BufFdIn::new(&self.socket, &self.run.state.ring), buf: WlBufFdIn::new(&self.socket, &self.run.state.ring),
} }
.run(), .run(),
), ),
@ -246,7 +248,7 @@ impl Outgoing {
struct Incoming { struct Incoming {
tc: Rc<TestTransport>, tc: Rc<TestTransport>,
buf: BufFdIn, buf: WlBufFdIn,
} }
impl Incoming { impl Incoming {
@ -267,30 +269,15 @@ impl Incoming {
} }
async fn handle_msg(&mut self) -> Result<(), TestError> { async fn handle_msg(&mut self) -> Result<(), TestError> {
let mut hdr = [0u32, 0]; let WlMessage {
if let Err(e) = self.buf.read_full(&mut hdr[..]).await { obj_id,
return Err(e.with_context("Could not read from wayland socket")); message,
} body,
let obj_id = ObjectId::from_raw(hdr[0]); fds,
let len = (hdr[1] >> 16) as usize; } = match self.buf.read_message().await {
let request = hdr[1] & 0xffff; Ok(m) => m,
if len < 8 { Err(e) => return Err(e.with_context("Could not read from wayland socket")),
bail!("Message size is < 8"); };
}
if len % 4 != 0 {
bail!("Message size is not a multiple of 4");
}
let len = len / 4 - 2;
let mut data_buf = self.tc.bufs.pop().unwrap_or_default();
data_buf.clear();
data_buf.reserve(len);
let unused = data_buf.split_at_spare_mut_ext().1;
if let Err(e) = self.buf.read_full(&mut unused[..len]).await {
return Err(e.with_context("Could not read from wayland socket"));
}
unsafe {
data_buf.set_len(len);
}
let object = match self.tc.objects.get(&obj_id) { let object = match self.tc.objects.get(&obj_id) {
Some(obj) => obj, Some(obj) => obj,
_ => bail!( _ => bail!(
@ -298,11 +285,8 @@ impl Incoming {
obj_id obj_id
), ),
}; };
let parser = MsgParser::new(&mut self.buf, &data_buf); let parser = MsgParser::new(fds, body);
object.handle_request(request, parser)?; object.handle_request(message, parser)?;
if data_buf.capacity() > 0 {
self.tc.bufs.push(data_buf);
}
Ok(()) Ok(())
} }
} }

View file

@ -58,7 +58,6 @@ impl TestRun {
run: self.clone(), run: self.clone(),
socket, socket,
client_id: Cell::new(ClientId::from_raw(0)), client_id: Cell::new(ClientId::from_raw(0)),
bufs: Default::default(),
swapchain: Default::default(), swapchain: Default::default(),
flush_request: Default::default(), flush_request: Default::default(),
incoming: Default::default(), incoming: Default::default(),
@ -146,9 +145,7 @@ pub trait ParseFull<'a>: Sized {
impl<'a, T: RequestParser<'a>> ParseFull<'a> for T { impl<'a, T: RequestParser<'a>> ParseFull<'a> for T {
fn parse_full(mut parser: MsgParser<'_, 'a>) -> Result<Self, TestError> { fn parse_full(mut parser: MsgParser<'_, 'a>) -> Result<Self, TestError> {
let res = T::parse(&mut parser)?; T::parse(&mut parser).map_err(Into::into)
parser.eof()?;
Ok(res)
} }
} }

View file

@ -10,15 +10,13 @@ use {
asyncevent::AsyncEvent, asyncevent::AsyncEvent,
bitfield::Bitfield, bitfield::Bitfield,
buffd::{ buffd::{
BufFdError, BufFdIn, BufFdOut, MsgFormatter, MsgParser, MsgParserError, OutBuffer, BufFdError, BufFdOut, MsgFormatter, MsgParser, MsgParserError, OutBuffer,
OutBufferSwapchain, OutBufferSwapchain, WlBufFdIn, WlMessage,
}, },
clonecell::CloneCell, clonecell::CloneCell,
errorfmt::ErrorFmt, errorfmt::ErrorFmt,
numcell::NumCell, numcell::NumCell,
oserror::OsError, oserror::OsError,
stack::Stack,
vec_ext::VecExt,
xrd::xrd, xrd::xrd,
}, },
wheel::{Wheel, WheelError}, wheel::{Wheel, WheelError},
@ -59,10 +57,6 @@ pub enum ToolClientError {
SocketPathTooLong, SocketPathTooLong,
#[error("Could not connect to the compositor")] #[error("Could not connect to the compositor")]
Connect(#[source] IoUringError), Connect(#[source] IoUringError),
#[error("The message length is smaller than 8 bytes")]
MsgLenTooSmall,
#[error("The size of the message is not a multiple of 4")]
UnalignedMessage,
#[error(transparent)] #[error(transparent)]
BufFdError(#[from] BufFdError), BufFdError(#[from] BufFdError),
#[error("Could not parse a message of type {}", .0)] #[error("Could not parse a message of type {}", .0)]
@ -85,7 +79,6 @@ pub struct ToolClient {
AHashMap<u32, Rc<dyn Fn(&mut MsgParser) -> Result<(), ToolClientError>>>, AHashMap<u32, Rc<dyn Fn(&mut MsgParser) -> Result<(), ToolClientError>>>,
>, >,
>, >,
bufs: Stack<Vec<u32>>,
swapchain: Rc<RefCell<OutBufferSwapchain>>, swapchain: Rc<RefCell<OutBufferSwapchain>>,
flush_request: AsyncEvent, flush_request: AsyncEvent,
pending_futures: RefCell<AHashMap<u32, SpawnedFuture<()>>>, pending_futures: RefCell<AHashMap<u32, SpawnedFuture<()>>>,
@ -186,7 +179,6 @@ impl ToolClient {
eng, eng,
obj_ids: RefCell::new(obj_ids), obj_ids: RefCell::new(obj_ids),
handlers: Default::default(), handlers: Default::default(),
bufs: Default::default(),
swapchain: Default::default(), swapchain: Default::default(),
flush_request: Default::default(), flush_request: Default::default(),
pending_futures: Default::default(), pending_futures: Default::default(),
@ -209,7 +201,7 @@ impl ToolClient {
"tool client incoming", "tool client incoming",
Incoming { Incoming {
tc: slf.clone(), tc: slf.clone(),
buf: BufFdIn::new(&socket, &slf.ring), buf: WlBufFdIn::new(&socket, &slf.ring),
} }
.run(), .run(),
), ),
@ -528,7 +520,7 @@ impl Outgoing {
struct Incoming { struct Incoming {
tc: Rc<ToolClient>, tc: Rc<ToolClient>,
buf: BufFdIn, buf: WlBufFdIn,
} }
impl Incoming { impl Incoming {
@ -541,44 +533,27 @@ impl Incoming {
} }
async fn handle_msg(&mut self) -> Result<(), ToolClientError> { async fn handle_msg(&mut self) -> Result<(), ToolClientError> {
let mut hdr = [0u32, 0]; let WlMessage {
if let Err(e) = self.buf.read_full(&mut hdr[..]).await { obj_id,
return Err(ToolClientError::Read(e)); message,
} body,
let obj_id = ObjectId::from_raw(hdr[0]); fds,
let len = (hdr[1] >> 16) as usize; } = self
let request = hdr[1] & 0xffff; .buf
if len < 8 { .read_message()
return Err(ToolClientError::MsgLenTooSmall); .await
} .map_err(ToolClientError::Read)?;
if len % 4 != 0 {
return Err(ToolClientError::UnalignedMessage);
}
let len = len / 4 - 2;
let mut data_buf = self.tc.bufs.pop().unwrap_or_default();
data_buf.clear();
data_buf.reserve(len);
let unused = data_buf.split_at_spare_mut_ext().1;
if let Err(e) = self.buf.read_full(&mut unused[..len]).await {
return Err(ToolClientError::Read(e));
}
unsafe {
data_buf.set_len(len);
}
let mut handler = None; let mut handler = None;
{ {
let handlers = self.tc.handlers.borrow_mut(); let handlers = self.tc.handlers.borrow_mut();
if let Some(handlers) = handlers.get(&obj_id) { if let Some(handlers) = handlers.get(&obj_id) {
handler = handlers.get(&request).cloned(); handler = handlers.get(&message).cloned();
} }
} }
if let Some(handler) = handler { if let Some(handler) = handler {
let mut parser = MsgParser::new(&mut self.buf, &data_buf); let mut parser = MsgParser::new(fds, body);
handler(&mut parser)?; handler(&mut parser)?;
} }
if data_buf.capacity() > 0 {
self.tc.bufs.push(data_buf);
}
Ok(()) Ok(())
} }
} }

View file

@ -6,6 +6,7 @@ pub use {
ei_parser::{EiMsgParser, EiMsgParserError}, ei_parser::{EiMsgParser, EiMsgParserError},
formatter::MsgFormatter, formatter::MsgFormatter,
parser::{MsgParser, MsgParserError}, parser::{MsgParser, MsgParserError},
wl_buf_in::{WlBufFdIn, WlMessage},
}; };
mod buf_in; mod buf_in;
@ -14,6 +15,7 @@ mod ei_formatter;
mod ei_parser; mod ei_parser;
mod formatter; mod formatter;
mod parser; mod parser;
mod wl_buf_in;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum BufFdError { pub enum BufFdError {
@ -29,6 +31,12 @@ pub enum BufFdError {
Closed, Closed,
#[error("The connection timed out")] #[error("The connection timed out")]
Timeout, Timeout,
#[error("Message size is not a multiple of 4")]
UnalignedMessageSize,
#[error("Message size is larger than 4096")]
MessageTooLarge,
#[error("Message size is smaller than 8")]
MessageTooSmall,
} }
const BUF_SIZE: usize = 4096; const BUF_SIZE: usize = 4096;

View file

@ -33,6 +33,11 @@ impl<'a> MsgFormatter<'a> {
self.meta.write_pos += bytes.len(); self.meta.write_pos += bytes.len();
} }
#[inline(always)]
pub fn data(&mut self, data: &[u32]) {
self.write(uapi::as_bytes(data));
}
pub fn int(&mut self, int: i32) -> &mut Self { pub fn int(&mut self, int: i32) -> &mut Self {
self.write(uapi::as_bytes(&int)); self.write(uapi::as_bytes(&int));
self self
@ -43,11 +48,13 @@ impl<'a> MsgFormatter<'a> {
self self
} }
#[expect(dead_code)]
pub fn u64(&mut self, int: u64) -> &mut Self { pub fn u64(&mut self, int: u64) -> &mut Self {
self.uint((int >> 32) as u32); self.uint((int >> 32) as u32);
self.uint(int as u32) self.uint(int as u32)
} }
#[expect(dead_code)]
pub fn u64_rev(&mut self, int: u64) -> &mut Self { pub fn u64_rev(&mut self, int: u64) -> &mut Self {
self.uint(int as u32); self.uint(int as u32);
self.uint((int >> 32) as u32) self.uint((int >> 32) as u32)

View file

@ -1,7 +1,7 @@
use { use {
crate::{fixed::Fixed, globals::GlobalName, object::ObjectId, utils::buffd::BufFdIn}, crate::{fixed::Fixed, globals::GlobalName, object::ObjectId},
bstr::{BStr, ByteSlice}, bstr::{BStr, ByteSlice},
std::{ptr, rc::Rc}, std::{collections::VecDeque, ptr, rc::Rc},
thiserror::Error, thiserror::Error,
uapi::{OwnedFd, Pod}, uapi::{OwnedFd, Pod},
}; };
@ -22,29 +22,32 @@ pub enum MsgParserError {
TrailingData, TrailingData,
#[error("String is not UTF-8")] #[error("String is not UTF-8")]
NonUtf8, NonUtf8,
#[error("The message has an unexpected size")]
UnexpectedMessageSize,
} }
pub struct MsgParser<'a, 'b> { pub struct MsgParser<'a, 'b> {
buf: &'a mut BufFdIn, fds: &'a mut VecDeque<Rc<OwnedFd>>,
pos: usize, pos: usize,
data: &'b [u8], data: &'b [u32],
} }
impl<'a, 'b> MsgParser<'a, 'b> { impl<'a, 'b> MsgParser<'a, 'b> {
pub fn new(buf: &'a mut BufFdIn, data: &'b [u32]) -> Self { pub fn new(fds: &'a mut VecDeque<Rc<OwnedFd>>, data: &'b [u32]) -> Self {
Self { Self { fds, pos: 0, data }
buf, }
pos: 0,
data: uapi::as_bytes(data), #[inline(always)]
} pub fn data(&self) -> &[u32] {
self.data
} }
pub fn int(&mut self) -> Result<i32, MsgParserError> { pub fn int(&mut self) -> Result<i32, MsgParserError> {
if self.data.len() - self.pos < 4 { if self.pos >= self.data.len() {
return Err(MsgParserError::UnexpectedEof); return Err(MsgParserError::UnexpectedEof);
} }
let res = unsafe { *(self.data.as_ptr().add(self.pos) as *const i32) }; let res = unsafe { *(self.data.as_ptr().add(self.pos) as *const i32) };
self.pos += 4; self.pos += 1;
Ok(res) Ok(res)
} }
@ -52,12 +55,14 @@ impl<'a, 'b> MsgParser<'a, 'b> {
self.int().map(|i| i as u32) self.int().map(|i| i as u32)
} }
#[expect(dead_code)]
pub fn u64(&mut self) -> Result<u64, MsgParserError> { pub fn u64(&mut self) -> Result<u64, MsgParserError> {
let hi = self.uint()?; let hi = self.uint()?;
let lo = self.uint()?; let lo = self.uint()?;
Ok(((hi as u64) << 32) | lo as u64) Ok(((hi as u64) << 32) | lo as u64)
} }
#[expect(dead_code)]
pub fn u64_rev(&mut self) -> Result<u64, MsgParserError> { pub fn u64_rev(&mut self) -> Result<u64, MsgParserError> {
let lo = self.uint()?; let lo = self.uint()?;
let hi = self.uint()?; let hi = self.uint()?;
@ -107,8 +112,8 @@ impl<'a, 'b> MsgParser<'a, 'b> {
} }
pub fn fd(&mut self) -> Result<Rc<OwnedFd>, MsgParserError> { pub fn fd(&mut self) -> Result<Rc<OwnedFd>, MsgParserError> {
match self.buf.get_fd() { match self.fds.pop_front() {
Ok(fd) => Ok(fd), Some(fd) => Ok(fd),
_ => Err(MsgParserError::MissingFd), _ => Err(MsgParserError::MissingFd),
} }
} }
@ -123,13 +128,13 @@ impl<'a, 'b> MsgParser<'a, 'b> {
pub fn array(&mut self) -> Result<&'b [u8], MsgParserError> { pub fn array(&mut self) -> Result<&'b [u8], MsgParserError> {
let len = self.uint()? as usize; let len = self.uint()? as usize;
let cap = (len + 3) & !3; let cap = (len + 3) >> 2;
if cap > self.data.len() - self.pos { if cap > self.data.len() - self.pos {
return Err(MsgParserError::UnexpectedEof); return Err(MsgParserError::UnexpectedEof);
} }
let pos = self.pos; let pos = self.pos;
self.pos += cap; self.pos += cap;
Ok(&self.data[pos..pos + len]) Ok(&uapi::as_bytes(&self.data[pos..])[..len])
} }
pub fn binary<T: Pod>(&mut self) -> Result<T, MsgParserError> { pub fn binary<T: Pod>(&mut self) -> Result<T, MsgParserError> {

View file

@ -0,0 +1,125 @@
use {
crate::{
io_uring::IoUring,
object::ObjectId,
utils::{
buf::Buf,
buffd::{BufFdError, MAX_IN_FD},
},
},
std::{collections::VecDeque, ptr, rc::Rc, slice},
uapi::OwnedFd,
};
const WORD_SIZE: usize = 4;
const WORD_ALIGN: usize = 4;
const HEADER_WORDS: usize = 2;
const HEADER_SIZE: usize = HEADER_WORDS * WORD_SIZE;
const MAX_MESSAGE_SIZE: usize = 4096;
const BUF_SIZE: usize = 2 * MAX_MESSAGE_SIZE;
pub struct WlBufFdIn {
fd: Rc<OwnedFd>,
ring: Rc<IoUring>,
fds: VecDeque<Rc<OwnedFd>>,
buf: Buf,
lo: usize,
len: usize,
}
pub struct WlMessage<'a> {
pub obj_id: ObjectId,
pub message: u32,
pub body: &'a [u32],
pub fds: &'a mut VecDeque<Rc<OwnedFd>>,
}
impl WlBufFdIn {
pub fn new(fd: &Rc<OwnedFd>, ring: &Rc<IoUring>) -> Self {
let buf = Buf::new(BUF_SIZE);
assert_eq!(buf.as_ptr() as usize % WORD_ALIGN, 0);
Self {
fd: fd.clone(),
ring: ring.clone(),
fds: Default::default(),
buf,
lo: Default::default(),
len: Default::default(),
}
}
pub async fn read_message(&mut self) -> Result<WlMessage<'_>, BufFdError> {
if self.len == 0 {
self.lo = 0;
}
if self.len < HEADER_SIZE {
if self.lo > 0 {
self.compact();
}
while self.len < HEADER_SIZE {
self.recvmsg().await?;
}
}
let hdr: &[u32] =
unsafe { slice::from_raw_parts(self.buf[self.lo..].as_ptr().cast(), HEADER_WORDS) };
let obj_id = ObjectId::from_raw(hdr[0]);
let len = (hdr[1] >> 16) as usize;
let message = hdr[1] & 0xffff;
if len & 3 != 0 {
return Err(BufFdError::UnalignedMessageSize);
}
if len > MAX_MESSAGE_SIZE {
return Err(BufFdError::MessageTooLarge);
}
if len < HEADER_SIZE {
return Err(BufFdError::MessageTooSmall);
}
if len > self.len {
if self.lo + self.len >= MAX_MESSAGE_SIZE {
self.compact();
}
while len > self.len {
self.recvmsg().await?;
}
}
let body: &[u32] = unsafe {
let words = (len - HEADER_SIZE) >> 2;
slice::from_raw_parts(self.buf[self.lo + HEADER_SIZE..].as_ptr().cast(), words)
};
self.lo += len;
self.len -= len;
Ok(WlMessage {
obj_id,
message,
body,
fds: &mut self.fds,
})
}
#[inline(always)]
fn compact(&mut self) {
unsafe {
let dst = self.buf.as_mut_ptr();
let src = dst.add(self.lo);
ptr::copy(src, dst, self.len);
self.lo = 0;
}
}
async fn recvmsg(&mut self) -> Result<(), BufFdError> {
let mut buf = self.buf.slice(self.lo + self.len..);
match self
.ring
.recvmsg(&self.fd, slice::from_mut(&mut buf), &mut self.fds)
.await
{
Ok(0) => return Err(BufFdError::Closed),
Ok(n) => self.len += n,
Err(e) => return Err(BufFdError::Ring(e)),
}
if self.fds.len() > MAX_IN_FD {
return Err(BufFdError::TooManyFds);
}
Ok(())
}
}

View file

@ -11,15 +11,14 @@ use {
asyncevent::AsyncEvent, asyncevent::AsyncEvent,
bitfield::Bitfield, bitfield::Bitfield,
buffd::{ buffd::{
BufFdError, BufFdIn, BufFdOut, MsgFormatter, MsgParser, MsgParserError, OutBuffer, BufFdError, BufFdOut, MsgFormatter, MsgParser, MsgParserError, OutBuffer,
OutBufferSwapchain, OutBufferSwapchain, WlBufFdIn, WlMessage,
}, },
clonecell::CloneCell, clonecell::CloneCell,
copyhashmap::CopyHashMap, copyhashmap::CopyHashMap,
errorfmt::ErrorFmt, errorfmt::ErrorFmt,
hash_map_ext::HashMapExt, hash_map_ext::HashMapExt,
oserror::OsError, oserror::OsError,
vec_ext::VecExt,
}, },
video::dmabuf::DmaBufIds, video::dmabuf::DmaBufIds,
wheel::Wheel, wheel::Wheel,
@ -51,10 +50,6 @@ pub enum UsrConError {
SocketPathTooLong, SocketPathTooLong,
#[error("Could not connect to the compositor")] #[error("Could not connect to the compositor")]
Connect(#[source] IoUringError), Connect(#[source] IoUringError),
#[error("The message length is smaller than 8 bytes")]
MsgLenTooSmall,
#[error("The size of the message is not a multiple of 4")]
UnalignedMessage,
#[error(transparent)] #[error(transparent)]
BufFdError(#[from] BufFdError), BufFdError(#[from] BufFdError),
#[error("Could not read from the compositor")] #[error("Could not read from the compositor")]
@ -168,8 +163,7 @@ impl UsrCon {
"wl_usr incoming", "wl_usr incoming",
Incoming { Incoming {
con: slf.clone(), con: slf.clone(),
buf: BufFdIn::new(socket, &slf.ring), buf: WlBufFdIn::new(socket, &slf.ring),
data: vec![],
} }
.run(), .run(),
), ),
@ -257,7 +251,6 @@ impl UsrCon {
mut parser: MsgParser<'_, 'a>, mut parser: MsgParser<'_, 'a>,
) -> Result<R, MsgParserError> { ) -> Result<R, MsgParserError> {
let res = R::parse(&mut parser)?; let res = R::parse(&mut parser)?;
parser.eof()?;
log::trace!( log::trace!(
"Server {} -> {}@{}.{:?}", "Server {} -> {}@{}.{:?}",
self.server_id, self.server_id,
@ -338,8 +331,7 @@ impl Outgoing {
struct Incoming { struct Incoming {
con: Rc<UsrCon>, con: Rc<UsrCon>,
buf: BufFdIn, buf: WlBufFdIn,
data: Vec<u32>,
} }
impl Incoming { impl Incoming {
@ -358,33 +350,16 @@ impl Incoming {
} }
async fn handle_msg(&mut self) -> Result<(), UsrConError> { async fn handle_msg(&mut self) -> Result<(), UsrConError> {
let mut hdr = [0u32, 0]; let WlMessage {
if let Err(e) = self.buf.read_full(&mut hdr[..]).await { obj_id,
return Err(UsrConError::Read(e)); message,
} body,
let obj_id = ObjectId::from_raw(hdr[0]); fds,
let len = (hdr[1] >> 16) as usize; } = self.buf.read_message().await.map_err(UsrConError::Read)?;
let event = hdr[1] & 0xffff;
if len < 8 {
return Err(UsrConError::MsgLenTooSmall);
}
if len % 4 != 0 {
return Err(UsrConError::UnalignedMessage);
}
let len = len / 4 - 2;
self.data.clear();
self.data.reserve(len);
let unused = self.data.split_at_spare_mut_ext().1;
if let Err(e) = self.buf.read_full(&mut unused[..len]).await {
return Err(UsrConError::Read(e));
}
unsafe {
self.data.set_len(len);
}
if let Some(obj) = self.con.objects.get(&obj_id) { if let Some(obj) = self.con.objects.get(&obj_id) {
if let Some(obj) = obj { if let Some(obj) = obj {
let parser = MsgParser::new(&mut self.buf, &self.data); let parser = MsgParser::new(fds, body);
obj.handle_event(&self.con, event, parser)?; obj.handle_event(&self.con, message, parser)?;
} }
} else if obj_id.raw() < MIN_SERVER_ID { } else if obj_id.raw() < MIN_SERVER_ID {
return Err(UsrConError::MissingObject(obj_id)); return Err(UsrConError::MissingObject(obj_id));