diff --git a/Cargo.lock b/Cargo.lock index 31602b57..15a9d18b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -709,6 +709,7 @@ dependencies = [ "jay-tree-types", "jay-units", "jay-utils", + "jay-wheel", "jay-wire-buf", "jay-wire-types", "jay-xcon", @@ -909,6 +910,19 @@ dependencies = [ "uapi", ] +[[package]] +name = "jay-wheel" +version = "0.1.0" +dependencies = [ + "jay-async-engine", + "jay-io-uring", + "jay-time", + "jay-utils", + "log", + "thiserror", + "uapi", +] + [[package]] name = "jay-wire-buf" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index e7a10432..711e9cd7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ members = [ "wire-buf", "tree-types", "eventfd-cache", + "wheel", "toml-config", "algorithms", "toml-spec", @@ -73,6 +74,7 @@ jay-wire-types = { version = "0.1.0", path = "wire-types" } jay-wire-buf = { version = "0.1.0", path = "wire-buf" } jay-tree-types = { version = "0.1.0", path = "tree-types" } jay-eventfd-cache = { version = "0.1.0", path = "eventfd-cache" } +jay-wheel = { version = "0.1.0", path = "wheel" } uapi = "0.2.13" thiserror = "2.0.11" diff --git a/src/wheel.rs b/src/wheel.rs index c0d10c97..eb67bf95 100644 --- a/src/wheel.rs +++ b/src/wheel.rs @@ -1,255 +1 @@ -use { - crate::{ - async_engine::{AsyncEngine, SpawnedFuture}, - io_uring::{IoUring, IoUringError}, - time::Time, - utils::{ - buf::TypedBuf, - copyhashmap::CopyHashMap, - errorfmt::ErrorFmt, - hash_map_ext::HashMapExt, - numcell::NumCell, - oserror::{OsError, OsErrorExt, OsErrorExt2}, - stack::Stack, - }, - }, - std::{ - cell::{Cell, RefCell}, - cmp::Reverse, - collections::BinaryHeap, - future::Future, - pin::Pin, - rc::Rc, - task::{Context, Poll, Waker}, - time::Duration, - }, - thiserror::Error, - uapi::{OwnedFd, c}, -}; - -#[derive(Debug, Error)] -pub enum WheelError { - #[error("Could not create the timerfd")] - CreateFailed(#[source] OsError), - #[error("Could not set the timerfd")] - SetFailed(#[source] OsError), - #[error("The timer wheel is already destroyed")] - Destroyed, - #[error("Could not read from the timerfd")] - Read(#[source] IoUringError), -} - -#[derive(Debug, Eq, PartialEq, Ord, PartialOrd)] -struct WheelEntry { - expiration: Time, - id: u64, -} - -pub struct Wheel { - data: Rc, -} - -impl Drop for Wheel { - fn drop(&mut self) { - self.data.kill(); - } -} - -struct WheelTimeoutData { - id: Cell, - expired: Cell>>, - wheel: Rc, - waker: Cell>, -} - -impl WheelTimeoutData { - fn complete(&self, res: Result<(), WheelError>) { - self.expired.set(Some(res)); - if let Some(waker) = self.waker.take() { - waker.wake(); - } - } -} - -pub struct WheelTimeoutFuture { - data: Rc, -} - -impl Drop for WheelTimeoutFuture { - fn drop(&mut self) { - self.data.wheel.dispatchers.remove(&self.data.id.get()); - self.data.waker.set(None); - if !self.data.wheel.destroyed.get() { - self.data.expired.take(); - self.data.wheel.cached_futures.push(self.data.clone()); - } - } -} - -impl Future for WheelTimeoutFuture { - type Output = Result<(), WheelError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if let Some(res) = self.data.expired.take() { - Poll::Ready(res) - } else { - self.data.waker.set(Some(cx.waker().clone())); - Poll::Pending - } - } -} - -pub struct WheelData { - destroyed: Cell, - ring: Rc, - eng: Rc, - fd: Rc, - next_id: NumCell, - start: Time, - current_expiration: Cell>, - dispatchers: CopyHashMap>, - expirations: RefCell>>, - dispatcher: Cell>>, - cached_futures: Stack>, -} - -impl Wheel { - pub fn new(eng: &Rc, ring: &Rc) -> Result, WheelError> { - let fd = uapi::timerfd_create(c::CLOCK_MONOTONIC, c::TFD_CLOEXEC) - .map(Rc::new) - .map_os_err(WheelError::CreateFailed)?; - let data = Rc::new(WheelData { - destroyed: Cell::new(false), - ring: ring.clone(), - eng: eng.clone(), - fd, - next_id: NumCell::new(1), - start: eng.now(), - current_expiration: Default::default(), - dispatchers: Default::default(), - expirations: Default::default(), - dispatcher: Default::default(), - cached_futures: Default::default(), - }); - data.dispatcher - .set(Some(eng.spawn("wheel", data.clone().dispatch()))); - Ok(Rc::new(Wheel { data })) - } - - pub fn clear(&self) { - self.data.kill(); - } - - fn future(&self) -> WheelTimeoutFuture { - let data = self.data.cached_futures.pop().unwrap_or_else(|| { - Rc::new(WheelTimeoutData { - id: Cell::new(0), - expired: Cell::new(None), - wheel: self.data.clone(), - waker: Cell::new(None), - }) - }); - data.id.set(self.data.next_id.fetch_add(1)); - WheelTimeoutFuture { data } - } - - pub fn timeout(&self, ms: u64) -> WheelTimeoutFuture { - if self.data.destroyed.get() { - return WheelTimeoutFuture { - data: Rc::new(WheelTimeoutData { - id: Cell::new(0), - expired: Cell::new(Some(Err(WheelError::Destroyed))), - wheel: self.data.clone(), - waker: Default::default(), - }), - }; - } - let future = self.future(); - let now = self.data.eng.now(); - let expiration = (now + Duration::from_millis(ms)).round_to_ms(); - let current = self.data.current_expiration.get(); - if current.is_none() || expiration - self.data.start < current.unwrap() - self.data.start { - let res = uapi::timerfd_settime( - self.data.fd.raw(), - c::TFD_TIMER_ABSTIME, - &c::itimerspec { - it_interval: uapi::pod_zeroed(), - it_value: expiration.0, - }, - ); - if let Err(e) = res.to_os_error() { - future.data.expired.set(Some(Err(WheelError::SetFailed(e)))); - return future; - } - self.data.current_expiration.set(Some(expiration)); - } - self.data.expirations.borrow_mut().push(Reverse(WheelEntry { - expiration, - id: future.data.id.get(), - })); - self.data - .dispatchers - .set(future.data.id.get(), future.data.clone()); - future - } -} - -impl WheelData { - fn kill(&self) { - self.destroyed.set(true); - self.dispatcher.set(None); - self.cached_futures.take(); - for dispatcher in self.dispatchers.lock().drain_values() { - dispatcher.complete(Err(WheelError::Destroyed)); - } - } - - async fn dispatch(self: Rc) { - let mut n = TypedBuf::new(); - loop { - if let Err(e) = self.dispatch_once(&mut n).await { - log::error!("Could not dispatch wheel expirations: {}", ErrorFmt(e)); - self.kill(); - return; - } - } - } - - async fn dispatch_once(&self, n: &mut TypedBuf) -> Result<(), WheelError> { - if let Err(e) = self.ring.read(&self.fd, n.buf()).await { - return Err(WheelError::Read(e)); - } - let now = self.eng.now(); - let dist = now - self.start; - { - let mut expirations = self.expirations.borrow_mut(); - while let Some(Reverse(entry)) = expirations.peek() { - if entry.expiration - self.start > dist { - break; - } - if let Some(dispatcher) = self.dispatchers.remove(&entry.id) { - dispatcher.complete(Ok(())); - } - expirations.pop(); - } - self.current_expiration.set(None); - while let Some(Reverse(entry)) = expirations.peek() { - if self.dispatchers.get(&entry.id).is_some() { - uapi::timerfd_settime( - self.fd.raw(), - c::TFD_TIMER_ABSTIME, - &c::itimerspec { - it_interval: uapi::pod_zeroed(), - it_value: entry.expiration.0, - }, - ) - .map_os_err(WheelError::SetFailed)?; - self.current_expiration.set(Some(entry.expiration)); - break; - } - expirations.pop(); - } - } - Ok(()) - } -} +pub use jay_wheel::*; diff --git a/wheel/Cargo.toml b/wheel/Cargo.toml new file mode 100644 index 00000000..5b69da86 --- /dev/null +++ b/wheel/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "jay-wheel" +version = "0.1.0" +edition = "2024" +license = "GPL-3.0-only" + +[dependencies] +jay-async-engine = { version = "0.1.0", path = "../async-engine" } +jay-io-uring = { version = "0.1.0", path = "../io-uring" } +jay-time = { version = "0.1.0", path = "../time" } +jay-utils = { version = "0.1.0", path = "../utils" } + +log = { version = "0.4.20", features = ["std"] } +thiserror = "2.0.11" +uapi = "0.2.13" diff --git a/wheel/src/lib.rs b/wheel/src/lib.rs new file mode 100644 index 00000000..360d20c7 --- /dev/null +++ b/wheel/src/lib.rs @@ -0,0 +1,253 @@ +use { + jay_async_engine::{AsyncEngine, SpawnedFuture}, + jay_io_uring::{IoUring, IoUringError}, + jay_time::Time, + jay_utils::{ + buf::TypedBuf, + copyhashmap::CopyHashMap, + errorfmt::ErrorFmt, + hash_map_ext::HashMapExt, + numcell::NumCell, + oserror::{OsError, OsErrorExt, OsErrorExt2}, + stack::Stack, + }, + std::{ + cell::{Cell, RefCell}, + cmp::Reverse, + collections::BinaryHeap, + future::Future, + pin::Pin, + rc::Rc, + task::{Context, Poll, Waker}, + time::Duration, + }, + thiserror::Error, + uapi::{OwnedFd, c}, +}; + +#[derive(Debug, Error)] +pub enum WheelError { + #[error("Could not create the timerfd")] + CreateFailed(#[source] OsError), + #[error("Could not set the timerfd")] + SetFailed(#[source] OsError), + #[error("The timer wheel is already destroyed")] + Destroyed, + #[error("Could not read from the timerfd")] + Read(#[source] IoUringError), +} + +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd)] +struct WheelEntry { + expiration: Time, + id: u64, +} + +pub struct Wheel { + data: Rc, +} + +impl Drop for Wheel { + fn drop(&mut self) { + self.data.kill(); + } +} + +struct WheelTimeoutData { + id: Cell, + expired: Cell>>, + wheel: Rc, + waker: Cell>, +} + +impl WheelTimeoutData { + fn complete(&self, res: Result<(), WheelError>) { + self.expired.set(Some(res)); + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } +} + +pub struct WheelTimeoutFuture { + data: Rc, +} + +impl Drop for WheelTimeoutFuture { + fn drop(&mut self) { + self.data.wheel.dispatchers.remove(&self.data.id.get()); + self.data.waker.set(None); + if !self.data.wheel.destroyed.get() { + self.data.expired.take(); + self.data.wheel.cached_futures.push(self.data.clone()); + } + } +} + +impl Future for WheelTimeoutFuture { + type Output = Result<(), WheelError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(res) = self.data.expired.take() { + Poll::Ready(res) + } else { + self.data.waker.set(Some(cx.waker().clone())); + Poll::Pending + } + } +} + +pub struct WheelData { + destroyed: Cell, + ring: Rc, + eng: Rc, + fd: Rc, + next_id: NumCell, + start: Time, + current_expiration: Cell>, + dispatchers: CopyHashMap>, + expirations: RefCell>>, + dispatcher: Cell>>, + cached_futures: Stack>, +} + +impl Wheel { + pub fn new(eng: &Rc, ring: &Rc) -> Result, WheelError> { + let fd = uapi::timerfd_create(c::CLOCK_MONOTONIC, c::TFD_CLOEXEC) + .map(Rc::new) + .map_os_err(WheelError::CreateFailed)?; + let data = Rc::new(WheelData { + destroyed: Cell::new(false), + ring: ring.clone(), + eng: eng.clone(), + fd, + next_id: NumCell::new(1), + start: eng.now(), + current_expiration: Default::default(), + dispatchers: Default::default(), + expirations: Default::default(), + dispatcher: Default::default(), + cached_futures: Default::default(), + }); + data.dispatcher + .set(Some(eng.spawn("wheel", data.clone().dispatch()))); + Ok(Rc::new(Wheel { data })) + } + + pub fn clear(&self) { + self.data.kill(); + } + + fn future(&self) -> WheelTimeoutFuture { + let data = self.data.cached_futures.pop().unwrap_or_else(|| { + Rc::new(WheelTimeoutData { + id: Cell::new(0), + expired: Cell::new(None), + wheel: self.data.clone(), + waker: Cell::new(None), + }) + }); + data.id.set(self.data.next_id.fetch_add(1)); + WheelTimeoutFuture { data } + } + + pub fn timeout(&self, ms: u64) -> WheelTimeoutFuture { + if self.data.destroyed.get() { + return WheelTimeoutFuture { + data: Rc::new(WheelTimeoutData { + id: Cell::new(0), + expired: Cell::new(Some(Err(WheelError::Destroyed))), + wheel: self.data.clone(), + waker: Default::default(), + }), + }; + } + let future = self.future(); + let now = self.data.eng.now(); + let expiration = (now + Duration::from_millis(ms)).round_to_ms(); + let current = self.data.current_expiration.get(); + if current.is_none() || expiration - self.data.start < current.unwrap() - self.data.start { + let res = uapi::timerfd_settime( + self.data.fd.raw(), + c::TFD_TIMER_ABSTIME, + &c::itimerspec { + it_interval: uapi::pod_zeroed(), + it_value: expiration.0, + }, + ); + if let Err(e) = res.to_os_error() { + future.data.expired.set(Some(Err(WheelError::SetFailed(e)))); + return future; + } + self.data.current_expiration.set(Some(expiration)); + } + self.data.expirations.borrow_mut().push(Reverse(WheelEntry { + expiration, + id: future.data.id.get(), + })); + self.data + .dispatchers + .set(future.data.id.get(), future.data.clone()); + future + } +} + +impl WheelData { + fn kill(&self) { + self.destroyed.set(true); + self.dispatcher.set(None); + self.cached_futures.take(); + for dispatcher in self.dispatchers.lock().drain_values() { + dispatcher.complete(Err(WheelError::Destroyed)); + } + } + + async fn dispatch(self: Rc) { + let mut n = TypedBuf::new(); + loop { + if let Err(e) = self.dispatch_once(&mut n).await { + log::error!("Could not dispatch wheel expirations: {}", ErrorFmt(e)); + self.kill(); + return; + } + } + } + + async fn dispatch_once(&self, n: &mut TypedBuf) -> Result<(), WheelError> { + if let Err(e) = self.ring.read(&self.fd, n.buf()).await { + return Err(WheelError::Read(e)); + } + let now = self.eng.now(); + let dist = now - self.start; + { + let mut expirations = self.expirations.borrow_mut(); + while let Some(Reverse(entry)) = expirations.peek() { + if entry.expiration - self.start > dist { + break; + } + if let Some(dispatcher) = self.dispatchers.remove(&entry.id) { + dispatcher.complete(Ok(())); + } + expirations.pop(); + } + self.current_expiration.set(None); + while let Some(Reverse(entry)) = expirations.peek() { + if self.dispatchers.get(&entry.id).is_some() { + uapi::timerfd_settime( + self.fd.raw(), + c::TFD_TIMER_ABSTIME, + &c::itimerspec { + it_interval: uapi::pod_zeroed(), + it_value: entry.expiration.0, + }, + ) + .map_os_err(WheelError::SetFailed)?; + self.current_expiration.set(Some(entry.expiration)); + break; + } + expirations.pop(); + } + } + Ok(()) + } +}