diff --git a/src/compositor.rs b/src/compositor.rs index ad5d5720..1a93f071 100644 --- a/src/compositor.rs +++ b/src/compositor.rs @@ -13,6 +13,7 @@ use { client::{ClientId, Clients}, clientmem::{self, ClientMemError}, config::ConfigProxy, + cpu_worker::{CpuWorker, CpuWorkerError}, damage::{visualize_damage, DamageVisualizer}, dbus::Dbus, ei::ei_client::EiClients, @@ -107,6 +108,8 @@ pub enum CompositorError { WheelError(#[from] WheelError), #[error("Could not create an io-uring")] IoUringError(#[from] IoUringError), + #[error("Could not create cpu worker")] + CpuWorkerError(#[from] CpuWorkerError), } pub const WAYLAND_DISPLAY: &str = "WAYLAND_DISPLAY"; @@ -143,6 +146,7 @@ fn start_compositor2( let node_ids = NodeIds::default(); let scales = RefCounted::default(); scales.add(Scale::from_int(1)); + let cpu_worker = Rc::new(CpuWorker::new(&ring, &engine)?); let state = Rc::new(State { xkb_ctx, backend: CloneCell::new(Rc::new(DummyBackend)), @@ -258,6 +262,7 @@ fn start_compositor2( enable_ei_acceptor: Default::default(), ei_clients: EiClients::new(), slow_ei_clients: Default::default(), + cpu_worker, }); state.tracker.register(ClientId::from_raw(0)); create_dummy_output(&state); diff --git a/src/cpu_worker.rs b/src/cpu_worker.rs new file mode 100644 index 00000000..d3d6adca --- /dev/null +++ b/src/cpu_worker.rs @@ -0,0 +1,432 @@ +pub mod jobs; +#[cfg(test)] +mod tests; + +use { + crate::{ + async_engine::{AsyncEngine, SpawnedFuture}, + io_uring::IoUring, + utils::{ + buf::TypedBuf, copyhashmap::CopyHashMap, errorfmt::ErrorFmt, oserror::OsError, + ptr_ext::MutPtrExt, queue::AsyncQueue, stack::Stack, + }, + }, + parking_lot::Mutex, + std::{ + any::Any, + cell::{Cell, RefCell}, + collections::VecDeque, + mem, + ptr::NonNull, + rc::Rc, + sync::Arc, + thread, + }, + thiserror::Error, + uapi::{c, OwnedFd}, +}; + +pub trait CpuJob { + fn work(&mut self) -> &mut dyn CpuWork; + fn completed(self: Box); +} + +pub trait CpuWork: Send { + fn run(&mut self) -> Option>; + + fn cancel_async(&mut self, ring: &Rc) { + let _ = ring; + unreachable!(); + } + + fn async_work_done(&mut self, work: Box) { + let _ = work; + unreachable!(); + } +} + +pub trait AsyncCpuWork { + fn run( + self: Box, + eng: &Rc, + ring: &Rc, + completion: WorkCompletion, + ) -> SpawnedFuture; + + fn into_any(self: Box) -> Box; +} + +pub struct WorkCompletion { + worker: Rc, + id: CpuJobId, +} + +pub struct CompletedWork(()); + +impl WorkCompletion { + pub fn complete(self, work: Box) -> CompletedWork { + let job = self.worker.async_jobs.remove(&self.id).unwrap(); + unsafe { + job.work.deref_mut().async_work_done(work); + } + self.worker.send_completion(self.id); + CompletedWork(()) + } +} + +pub struct CpuWorker { + data: Rc, + _completions_listener: SpawnedFuture<()>, + _job_enqueuer: SpawnedFuture<()>, +} + +#[must_use] +pub struct PendingJob { + id: CpuJobId, + thread_data: Rc, + job_data: Rc, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)] +enum PendingJobState { + #[default] + Waiting, + Abandoned, + Completed, +} + +#[derive(Default)] +struct PendingJobData { + job: Cell>>, + state: Cell, +} + +enum Job { + New { + id: CpuJobId, + work: *mut dyn CpuWork, + }, + Cancel { + id: CpuJobId, + }, +} + +unsafe impl Send for Job {} + +struct CpuWorkerData { + next: CpuJobIds, + jobs_to_enqueue: AsyncQueue, + new_jobs: Arc>>, + have_new_jobs: Rc, + completed_jobs_remote: Arc>>, + completed_jobs_local: RefCell>, + have_completed_jobs: Rc, + pending_jobs: CopyHashMap>, + ring: Rc, + _stop: OwnedFd, + pending_job_data_cache: Stack>, +} + +linear_ids!(CpuJobIds, CpuJobId, u64); + +#[derive(Debug, Error)] +pub enum CpuWorkerError { + #[error("Could not create a pipe")] + Pipe(#[source] OsError), + #[error("Could not create an eventfd")] + EventFd(#[source] OsError), + #[error("Could not dup an eventfd")] + Dup(#[source] OsError), +} + +impl PendingJob { + #[expect(dead_code)] + pub fn detach(self) { + match self.job_data.state.get() { + PendingJobState::Waiting => { + self.job_data.state.set(PendingJobState::Abandoned); + } + PendingJobState::Abandoned => { + unreachable!(); + } + PendingJobState::Completed => {} + } + } +} + +impl Drop for CpuWorker { + fn drop(&mut self) { + self.data.do_equeue_jobs(); + if self.data.pending_jobs.is_not_empty() { + log::warn!("CpuWorker dropped with pending jobs. Completed jobs will not be triggered.") + } + } +} + +impl Drop for PendingJob { + fn drop(&mut self) { + match self.job_data.state.get() { + PendingJobState::Waiting => { + log::warn!("PendingJob dropped before completion. Blocking."); + let data = &self.thread_data; + let id = self.id; + self.job_data.state.set(PendingJobState::Abandoned); + data.jobs_to_enqueue.push(Job::Cancel { id }); + data.do_equeue_jobs(); + let mut buf = 0u64; + while data.pending_jobs.contains(&id) { + if let Err(e) = uapi::read(data.have_completed_jobs.raw(), &mut buf) { + panic!("Could not wait for job completions: {}", ErrorFmt(e)); + } + data.dispatch_completions(); + } + } + PendingJobState::Abandoned => {} + PendingJobState::Completed => { + self.thread_data + .pending_job_data_cache + .push(self.job_data.clone()); + } + } + } +} + +impl CpuWorkerData { + async fn wait_for_completions(self: Rc) { + let mut buf = TypedBuf::::new(); + loop { + if let Err(e) = self.ring.read(&self.have_completed_jobs, buf.buf()).await { + log::error!("Could not wait for job completions: {}", ErrorFmt(e)); + return; + } + self.dispatch_completions(); + } + } + + fn dispatch_completions(&self) { + let completions = &mut *self.completed_jobs_local.borrow_mut(); + mem::swap(completions, &mut *self.completed_jobs_remote.lock()); + while let Some(id) = completions.pop_front() { + let job_data = self.pending_jobs.remove(&id).unwrap(); + let job = job_data.job.take().unwrap(); + let job = unsafe { Box::from_raw(job.as_ptr()) }; + match job_data.state.get() { + PendingJobState::Waiting => { + job_data.state.set(PendingJobState::Completed); + job.completed(); + } + PendingJobState::Abandoned => { + self.pending_job_data_cache.push(job_data); + } + PendingJobState::Completed => { + unreachable!(); + } + } + } + } + + async fn equeue_jobs(self: Rc) { + loop { + self.jobs_to_enqueue.non_empty().await; + self.do_equeue_jobs(); + } + } + + fn do_equeue_jobs(&self) { + self.jobs_to_enqueue.move_to(&mut self.new_jobs.lock()); + if let Err(e) = uapi::eventfd_write(self.have_new_jobs.raw(), 1) { + panic!("Could not signal eventfd: {}", ErrorFmt(e)); + } + } +} + +impl CpuWorker { + pub fn new(ring: &Rc, eng: &Rc) -> Result { + let new_jobs: Arc>> = Default::default(); + let completed_jobs: Arc>> = Default::default(); + let (stop_read, stop_write) = + uapi::pipe2(c::O_CLOEXEC).map_err(|e| CpuWorkerError::Pipe(e.into()))?; + let have_new_jobs = + uapi::eventfd(0, c::EFD_CLOEXEC).map_err(|e| CpuWorkerError::EventFd(e.into()))?; + let have_completed_jobs = + uapi::eventfd(0, c::EFD_CLOEXEC).map_err(|e| CpuWorkerError::EventFd(e.into()))?; + thread::Builder::new() + .name("cpu worker".to_string()) + .spawn({ + let new_jobs = new_jobs.clone(); + let completed_jobs = completed_jobs.clone(); + let have_new_jobs = uapi::fcntl_dupfd_cloexec(have_new_jobs.raw(), 0) + .map_err(|e| CpuWorkerError::Dup(e.into()))?; + let have_completed_jobs = uapi::fcntl_dupfd_cloexec(have_completed_jobs.raw(), 0) + .map_err(|e| CpuWorkerError::Dup(e.into()))?; + move || { + work( + new_jobs, + completed_jobs, + stop_write, + have_new_jobs, + have_completed_jobs, + ) + } + }) + .unwrap(); + let data = Rc::new(CpuWorkerData { + next: Default::default(), + jobs_to_enqueue: Default::default(), + new_jobs, + have_new_jobs: Rc::new(have_new_jobs), + completed_jobs_remote: completed_jobs, + completed_jobs_local: Default::default(), + have_completed_jobs: Rc::new(have_completed_jobs), + pending_jobs: Default::default(), + ring: ring.clone(), + _stop: stop_read, + pending_job_data_cache: Default::default(), + }); + Ok(Self { + _completions_listener: eng.spawn(data.clone().wait_for_completions()), + _job_enqueuer: eng.spawn(data.clone().equeue_jobs()), + data, + }) + } + + #[expect(dead_code)] + pub fn submit(&self, job: Box) -> PendingJob { + let mut job = NonNull::from(Box::leak(job)); + let id = self.data.next.next(); + self.data.jobs_to_enqueue.push(Job::New { + id, + work: unsafe { job.as_mut().work() }, + }); + let job_data = self.data.pending_job_data_cache.pop().unwrap_or_default(); + job_data.job.set(Some(job)); + job_data.state.set(PendingJobState::Waiting); + self.data.pending_jobs.set(id, job_data.clone()); + PendingJob { + id, + thread_data: self.data.clone(), + job_data, + } + } +} + +fn work( + new_jobs: Arc>>, + completed_jobs: Arc>>, + stop: OwnedFd, + have_new_jobs: OwnedFd, + have_completed_jobs: OwnedFd, +) { + let eng = AsyncEngine::new(); + let ring = IoUring::new(&eng, 32).unwrap(); + let worker = Rc::new(Worker { + eng, + ring, + completed_jobs, + have_completed_jobs, + async_jobs: Default::default(), + stopped: Cell::new(false), + }); + let _stop_listener = worker.eng.spawn(worker.clone().handle_stop(stop)); + let _new_job_listener = worker + .eng + .spawn(worker.clone().handle_new_jobs(new_jobs, have_new_jobs)); + if let Err(e) = worker.ring.run() { + panic!("io_uring failed: {}", ErrorFmt(e)); + } +} + +struct Worker { + eng: Rc, + ring: Rc, + completed_jobs: Arc>>, + have_completed_jobs: OwnedFd, + async_jobs: CopyHashMap, + stopped: Cell, +} + +struct AsyncJob { + _future: SpawnedFuture, + work: *mut dyn CpuWork, +} + +impl Worker { + async fn handle_stop(self: Rc, stop: OwnedFd) { + let stop = Rc::new(stop); + if let Err(e) = self.ring.poll(&stop, 0).await { + log::error!( + "Could not wait for stop fd to become readable: {}", + ErrorFmt(e) + ); + } else { + assert!(self.async_jobs.is_empty()); + self.stopped.set(true); + self.ring.stop(); + } + } + + async fn handle_new_jobs( + self: Rc, + jobs_remote: Arc>>, + new_jobs: OwnedFd, + ) { + let mut buf = TypedBuf::::new(); + let new_jobs = Rc::new(new_jobs); + let mut jobs = VecDeque::new(); + loop { + if let Err(e) = self.ring.read(&new_jobs, buf.buf()).await { + if self.stopped.get() { + return; + } + panic!( + "Could not wait for new jobs fd to be signaled: {}", + ErrorFmt(e), + ); + } + mem::swap(&mut jobs, &mut *jobs_remote.lock()); + while let Some(job) = jobs.pop_front() { + self.handle_new_job(job); + } + } + } + + fn handle_new_job(self: &Rc, job: Job) { + match job { + Job::Cancel { id } => { + let mut jobs = self.async_jobs.lock(); + if let Some(job) = jobs.get_mut(&id) { + unsafe { + job.work.deref_mut().cancel_async(&self.ring); + } + } + } + Job::New { id, work } => match unsafe { work.deref_mut() }.run() { + None => { + self.send_completion(id); + return; + } + Some(w) => { + let completion = WorkCompletion { + worker: self.clone(), + id, + }; + let future = w.run(&self.eng, &self.ring, completion); + self.async_jobs.set( + id, + AsyncJob { + _future: future, + work, + }, + ); + } + }, + } + } + + fn send_completion(&self, id: CpuJobId) { + self.completed_jobs.lock().push_back(id); + if let Err(e) = uapi::eventfd_write(self.have_completed_jobs.raw(), 1) { + panic!("Could not signal job completion: {}", ErrorFmt(e)); + } + } +} diff --git a/src/cpu_worker/jobs.rs b/src/cpu_worker/jobs.rs new file mode 100644 index 00000000..aadb998f --- /dev/null +++ b/src/cpu_worker/jobs.rs @@ -0,0 +1 @@ +pub mod read_write; diff --git a/src/cpu_worker/jobs/read_write.rs b/src/cpu_worker/jobs/read_write.rs new file mode 100644 index 00000000..c8790529 --- /dev/null +++ b/src/cpu_worker/jobs/read_write.rs @@ -0,0 +1,153 @@ +use { + crate::{ + async_engine::{AsyncEngine, SpawnedFuture}, + cpu_worker::{AsyncCpuWork, CompletedWork, CpuWork, WorkCompletion}, + io_uring::{IoUring, IoUringError, IoUringTaskId}, + }, + std::{ + any::Any, + ptr, + rc::Rc, + slice, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering::Relaxed}, + Arc, + }, + }, + thiserror::Error, + uapi::{c, Fd}, +}; + +#[derive(Debug, Error)] +pub enum ReadWriteJobError { + #[error("An io_uring error occurred")] + IoUring(#[source] IoUringError), + #[error("The job was cancelled")] + Cancelled, + #[error("Tried to operate outside the bounds of the file descriptor")] + OutOfBounds, +} + +pub struct ReadWriteWork { + cancel: Arc, + config: Option>, +} + +unsafe impl Send for ReadWriteWork {} + +impl ReadWriteWork { + #[expect(dead_code)] + pub unsafe fn new() -> Self { + let cancel = Arc::new(CancelState::default()); + ReadWriteWork { + cancel: cancel.clone(), + config: Some(Box::new(ReadWriteWorkConfig { + fd: -1, + offset: 0, + ptr: ptr::null_mut(), + len: 0, + write: false, + cancel, + result: None, + })), + } + } + + #[expect(dead_code)] + pub fn config(&mut self) -> &mut ReadWriteWorkConfig { + self.config.as_mut().unwrap() + } +} + +pub struct ReadWriteWorkConfig { + pub fd: c::c_int, + pub offset: usize, + pub ptr: *mut u8, + pub len: usize, + pub write: bool, + pub result: Option>, + cancel: Arc, +} + +#[derive(Default)] +struct CancelState { + cancelled: AtomicBool, + cancel_id: AtomicU64, +} + +impl CpuWork for ReadWriteWork { + fn run(&mut self) -> Option> { + self.cancel.cancelled.store(false, Relaxed); + self.cancel.cancel_id.store(0, Relaxed); + self.config.take().map(|b| b as _) + } + + fn cancel_async(&mut self, ring: &Rc) { + self.cancel.cancelled.store(true, Relaxed); + let id = self.cancel.cancel_id.load(Relaxed); + if id != 0 { + ring.cancel(IoUringTaskId::from_raw(id)); + } + } + + fn async_work_done(&mut self, work: Box) { + let work = work.into_any().downcast().unwrap(); + self.config = Some(work); + } +} + +impl AsyncCpuWork for ReadWriteWorkConfig { + fn run( + mut self: Box, + eng: &Rc, + ring: &Rc, + completion: WorkCompletion, + ) -> SpawnedFuture { + let ring = ring.clone(); + eng.spawn(async move { + let res = loop { + if self.cancel.cancelled.load(Relaxed) { + break Err(ReadWriteJobError::Cancelled); + } + if self.len == 0 { + break Ok(()); + }; + let res = if self.write { + ring.write_no_cancel( + Fd::new(self.fd), + self.offset, + unsafe { slice::from_raw_parts(self.ptr, self.len) }, + None, + |id| self.cancel.cancel_id.store(id.raw(), Relaxed), + ) + .await + } else { + ring.read_no_cancel( + Fd::new(self.fd), + self.offset, + unsafe { slice::from_raw_parts_mut(self.ptr, self.len) }, + |id| self.cancel.cancel_id.store(id.raw(), Relaxed), + ) + .await + }; + match res { + Ok(0) => break Err(ReadWriteJobError::OutOfBounds), + Ok(n) => { + self.len -= n; + self.offset += n; + unsafe { + self.ptr = self.ptr.add(n); + } + } + Err(e) => break Err(ReadWriteJobError::IoUring(e)), + } + }; + self.result = Some(res); + completion.complete(self) + }) + } + + fn into_any(self: Box) -> Box { + self + } +} diff --git a/src/cpu_worker/tests.rs b/src/cpu_worker/tests.rs new file mode 100644 index 00000000..6b0bb826 --- /dev/null +++ b/src/cpu_worker/tests.rs @@ -0,0 +1,117 @@ +use { + crate::{ + async_engine::{AsyncEngine, SpawnedFuture}, + cpu_worker::{AsyncCpuWork, CompletedWork, CpuJob, CpuWork, CpuWorker, WorkCompletion}, + io_uring::IoUring, + utils::asyncevent::AsyncEvent, + wheel::Wheel, + }, + std::{any::Any, future::pending, rc::Rc, sync::Arc}, + uapi::{c::EFD_CLOEXEC, OwnedFd}, +}; + +struct Job { + ae: Rc, + work: Work, + cancel: bool, +} +struct Work(Arc); +struct AsyncWork(Arc); + +impl CpuJob for Job { + fn work(&mut self) -> &mut dyn CpuWork { + &mut self.work + } + + fn completed(self: Box) { + if self.cancel { + unreachable!(); + } else { + self.ae.trigger(); + } + } +} + +impl Drop for Job { + fn drop(&mut self) { + if self.cancel { + self.ae.trigger(); + } + } +} + +impl CpuWork for Work { + fn run(&mut self) -> Option> { + Some(Box::new(AsyncWork(self.0.clone()))) + } + + fn cancel_async(&mut self, _ring: &Rc) { + uapi::eventfd_write(self.0.raw(), 1).unwrap(); + } + + fn async_work_done(&mut self, work: Box) { + let _ = work; + } +} + +impl AsyncCpuWork for AsyncWork { + fn run( + self: Box, + eng: &Rc, + ring: &Rc, + completion: WorkCompletion, + ) -> SpawnedFuture { + let ring = ring.clone(); + eng.spawn(async move { + let mut buf = [0; 8]; + let res = ring + .read_no_cancel(self.0.borrow(), 0, &mut buf, |_| ()) + .await; + res.unwrap(); + completion.complete(self) + }) + } + + fn into_any(self: Box) -> Box { + self + } +} + +fn run(cancel: bool) { + let eng = AsyncEngine::new(); + let ring = IoUring::new(&eng, 32).unwrap(); + let ring2 = ring.clone(); + let wheel = Wheel::new(&eng, &ring).unwrap(); + let cpu = Rc::new(CpuWorker::new(&ring, &eng).unwrap()); + let ae = Rc::new(AsyncEvent::default()); + let eventfd = Arc::new(uapi::eventfd(0, EFD_CLOEXEC).unwrap()); + let pending_job = cpu.submit(Box::new(Job { + ae: ae.clone(), + work: Work(eventfd.clone()), + cancel, + })); + let _fut1 = eng.spawn(async move { + wheel.timeout(1).await.unwrap(); + if cancel { + drop(pending_job); + } else { + uapi::eventfd_write(eventfd.raw(), 1).unwrap(); + pending::<()>().await; + } + }); + let _fut2 = eng.spawn(async move { + ae.triggered().await; + ring2.stop(); + }); + ring.run().unwrap(); +} + +#[test] +fn cancel() { + run(true); +} + +#[test] +fn complete() { + run(false); +} diff --git a/src/io_uring.rs b/src/io_uring.rs index 865634e4..60c5291a 100644 --- a/src/io_uring.rs +++ b/src/io_uring.rs @@ -231,7 +231,6 @@ impl IoUring { res } - #[expect(dead_code)] pub fn cancel(&self, id: IoUringTaskId) { self.ring.cancel_task(id); } diff --git a/src/io_uring/ops/read_write_no_cancel.rs b/src/io_uring/ops/read_write_no_cancel.rs index 16729de7..02cc2107 100644 --- a/src/io_uring/ops/read_write_no_cancel.rs +++ b/src/io_uring/ops/read_write_no_cancel.rs @@ -15,7 +15,6 @@ use { }; impl IoUring { - #[expect(dead_code)] pub async fn read_no_cancel( &self, fd: Fd, @@ -35,7 +34,6 @@ impl IoUring { .await } - #[expect(dead_code)] pub async fn write_no_cancel( &self, fd: Fd, diff --git a/src/main.rs b/src/main.rs index 82655149..4e47f8dd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -54,6 +54,7 @@ mod client; mod clientmem; mod compositor; mod config; +mod cpu_worker; mod cursor; mod cursor_user; mod damage; diff --git a/src/state.rs b/src/state.rs index 98636a3d..fb44f9f6 100644 --- a/src/state.rs +++ b/src/state.rs @@ -13,6 +13,7 @@ use { clientmem::ClientMemOffset, compositor::LIBEI_SOCKET, config::ConfigProxy, + cpu_worker::CpuWorker, cursor::{Cursor, ServerCursors}, cursor_user::{CursorUserGroup, CursorUserGroupId, CursorUserGroupIds, CursorUserIds}, damage::DamageVisualizer, @@ -214,6 +215,8 @@ pub struct State { pub enable_ei_acceptor: Cell, pub ei_clients: EiClients, pub slow_ei_clients: AsyncQueue>, + #[expect(dead_code)] + pub cpu_worker: Rc, } // impl Drop for State { diff --git a/src/utils/queue.rs b/src/utils/queue.rs index 8e55ba6b..c39a28d5 100644 --- a/src/utils/queue.rs +++ b/src/utils/queue.rs @@ -64,6 +64,12 @@ impl AsyncQueue { } self.waiter.take(); } + + pub fn move_to(&self, other: &mut VecDeque) { + unsafe { + other.append(self.data.get().deref_mut()); + } + } } pub struct AsyncQueuePop<'a, T> {