From 7cc6c945d34a8baa9e58f581ebae831f9fac8367 Mon Sep 17 00:00:00 2001 From: Julian Orth Date: Fri, 13 May 2022 21:17:28 +0200 Subject: [PATCH] io-uring: add recvmsg --- src/io_uring.rs | 16 ++++- src/io_uring/ops.rs | 1 + src/io_uring/ops/recvmsg.rs | 131 ++++++++++++++++++++++++++++++++++++ src/utils/buffd/buf_in.rs | 73 ++++++++------------ src/wheel.rs | 1 - 5 files changed, 173 insertions(+), 49 deletions(-) create mode 100644 src/io_uring/ops/recvmsg.rs diff --git a/src/io_uring.rs b/src/io_uring.rs index bdae93f3..1d508b1e 100644 --- a/src/io_uring.rs +++ b/src/io_uring.rs @@ -4,8 +4,8 @@ use { async_engine::AsyncEngine, io_uring::{ ops::{ - async_cancel::AsyncCancelTask, poll::PollTask, sendmsg::SendmsgTask, - timeout::TimeoutTask, write::WriteTask, + async_cancel::AsyncCancelTask, poll::PollTask, recvmsg::RecvmsgTask, + sendmsg::SendmsgTask, timeout::TimeoutTask, write::WriteTask, }, pending_result::PendingResults, sys::{ @@ -17,6 +17,7 @@ use { utils::{ asyncevent::AsyncEvent, bitflags::BitflagsExt, + buf::Buf, copyhashmap::CopyHashMap, errorfmt::ErrorFmt, mmap::{mmap, Mmapped}, @@ -76,6 +77,8 @@ pub enum IoUringError { Destroyed, #[error("io_uring_enter failed")] Enter(#[source] OsError), + #[error("Kernel sent invalid cmsg data")] + InvalidCmsgData, } pub struct IoUring { @@ -205,7 +208,9 @@ impl IoUring { cached_cancels: Default::default(), cached_polls: Default::default(), cached_sendmsg: Default::default(), + cached_recvmsg: Default::default(), cached_timeouts: Default::default(), + cached_cmsg_bufs: Default::default(), fd_ids_scratch: Default::default(), }); Ok(Rc::new(Self { ring: data })) @@ -251,11 +256,14 @@ struct IoUringData { tasks: CopyHashMap>, pending_results: PendingResults, + cached_writes: Stack>, cached_cancels: Stack>, cached_polls: Stack>, cached_sendmsg: Stack>, + cached_recvmsg: Stack>, cached_timeouts: Stack>, + cached_cmsg_bufs: Stack, fd_ids_scratch: RefCell>, } @@ -432,6 +440,10 @@ impl IoUringData { } } } + + fn cmsg_buf(&self) -> Buf { + self.cached_cmsg_bufs.pop().unwrap_or_else(|| Buf::new(256)) + } } struct Cancellable<'a> { diff --git a/src/io_uring/ops.rs b/src/io_uring/ops.rs index f0786deb..a3607664 100644 --- a/src/io_uring/ops.rs +++ b/src/io_uring/ops.rs @@ -2,6 +2,7 @@ use crate::{io_uring::IoUringError, utils::oserror::OsError}; pub mod async_cancel; pub mod poll; +pub mod recvmsg; pub mod sendmsg; pub mod timeout; pub mod write; diff --git a/src/io_uring/ops/recvmsg.rs b/src/io_uring/ops/recvmsg.rs new file mode 100644 index 00000000..aaa4c864 --- /dev/null +++ b/src/io_uring/ops/recvmsg.rs @@ -0,0 +1,131 @@ +use { + crate::{ + io_uring::{ + pending_result::PendingResult, + sys::{io_uring_sqe, IORING_OP_RECVMSG}, + IoUring, IoUringData, IoUringError, Task, + }, + utils::buf::Buf, + }, + std::{cell::Cell, collections::VecDeque, mem::MaybeUninit, rc::Rc}, + uapi::{c, OwnedFd}, +}; + +impl IoUring { + pub async fn recvmsg( + &self, + fd: &Rc, + bufs: &mut [Buf], + fds: &mut VecDeque, + ) -> Result { + self.ring.check_destroyed()?; + let id = self.ring.id(); + let pr = self.ring.pending_results.acquire(); + let mut cmsg = self.ring.cmsg_buf(); + let cmsg_len; + { + let mut rm = self.ring.cached_recvmsg.pop().unwrap_or_default(); + rm.iovecs.clear(); + for buf in bufs { + rm.bufs.push(buf.clone()); + rm.iovecs.push(c::iovec { + iov_base: buf.as_ptr() as _, + iov_len: buf.len() as _, + }); + } + rm.id = id.id; + rm.fd = fd.raw(); + rm.msghdr.msg_control = cmsg.as_ptr() as _; + rm.msghdr.msg_controllen = cmsg.len() as _; + rm.msghdr.msg_iov = rm.iovecs.as_mut_ptr(); + rm.msghdr.msg_iovlen = rm.iovecs.len() as _; + rm.data = Some(Data { + _cmsg: cmsg.clone(), + _fd: fd.clone(), + pr: pr.clone(), + }); + cmsg_len = rm.cmsg_len.clone(); + self.ring.schedule(rm); + } + macro_rules! return_cmsg { + () => { + self.ring.cached_cmsg_bufs.push(cmsg); + }; + } + match pr.await { + Ok(n) => { + let mut cmsg_data = &cmsg[..cmsg_len.get()]; + while cmsg_data.len() > 0 { + let (_, hdr, data) = match uapi::cmsg_read(&mut cmsg_data) { + Ok(m) => m, + Err(_) => { + return_cmsg!(); + return Err(IoUringError::InvalidCmsgData); + } + }; + if (hdr.cmsg_level, hdr.cmsg_type) == (c::SOL_SOCKET, c::SCM_RIGHTS) { + fds.extend(uapi::pod_iter(data).unwrap()); + } + } + return_cmsg!(); + Ok(n as _) + } + Err(e) => { + return_cmsg!(); + Err(IoUringError::OsError(e)) + } + } + } +} + +struct Data { + _cmsg: Buf, + _fd: Rc, + pr: PendingResult, +} + +pub struct RecvmsgTask { + id: u64, + fd: c::c_int, + bufs: Vec, + iovecs: Vec, + msghdr: c::msghdr, + cmsg_len: Rc>, + data: Option, +} + +impl Default for RecvmsgTask { + fn default() -> Self { + RecvmsgTask { + id: 0, + fd: 0, + bufs: vec![], + iovecs: vec![], + msghdr: unsafe { MaybeUninit::zeroed().assume_init() }, + cmsg_len: Rc::new(Cell::new(0)), + data: None, + } + } +} + +unsafe impl Task for RecvmsgTask { + fn id(&self) -> u64 { + self.id + } + + fn complete(mut self: Box, ring: &IoUringData, res: i32) { + self.cmsg_len.set(self.msghdr.msg_controllen as _); + self.bufs.clear(); + if let Some(data) = self.data.take() { + data.pr.complete(res); + } + ring.cached_recvmsg.push(self); + } + + fn encode(&self, sqe: &mut io_uring_sqe) { + sqe.opcode = IORING_OP_RECVMSG; + sqe.fd = self.fd; + sqe.u2.addr = &self.msghdr as *const _ as _; + sqe.u3.msg_flags = c::MSG_CMSG_CLOEXEC as _; + } +} diff --git a/src/utils/buffd/buf_in.rs b/src/utils/buffd/buf_in.rs index 811ab1ae..54077ebd 100644 --- a/src/utils/buffd/buf_in.rs +++ b/src/utils/buffd/buf_in.rs @@ -1,10 +1,14 @@ use { crate::{ io_uring::IoUring, - utils::buffd::{BufFdError, BUF_SIZE, CMSG_BUF_SIZE, MAX_IN_FD}, + utils::{ + buf::Buf, + buffd::{BufFdError, BUF_SIZE, MAX_IN_FD}, + }, }, + smallvec::SmallVec, std::{collections::VecDeque, mem::MaybeUninit, rc::Rc}, - uapi::{c, Errno, OwnedFd, Pod}, + uapi::{OwnedFd, Pod}, }; pub struct BufFdIn { @@ -13,8 +17,7 @@ pub struct BufFdIn { in_fd: VecDeque, - in_buf: Box<[MaybeUninit; BUF_SIZE]>, - in_cmsg_buf: Box<[MaybeUninit; CMSG_BUF_SIZE]>, + in_buf: Buf, in_left: usize, in_right: usize, } @@ -25,8 +28,7 @@ impl BufFdIn { fd: fd.clone(), ring: ring.clone(), in_fd: Default::default(), - in_buf: Box::new([MaybeUninit::uninit(); BUF_SIZE]), - in_cmsg_buf: Box::new([MaybeUninit::uninit(); CMSG_BUF_SIZE]), + in_buf: Buf::new(BUF_SIZE), in_left: 0, in_right: 0, } @@ -36,73 +38,52 @@ impl BufFdIn { let bytes = unsafe { uapi::as_maybe_uninit_bytes_mut2(buf) }; let mut offset = 0; while offset < bytes.len() { - if self.read_full_(bytes, &mut offset)? { - self.ring.readable(&self.fd).await?; - } + self.read_full_(bytes, &mut offset).await?; } Ok(()) } - fn read_full_( + async fn read_full_( &mut self, bytes: &mut [MaybeUninit], offset: &mut usize, - ) -> Result { + ) -> Result<(), BufFdError> { + let in_buf = uapi::as_maybe_uninit_bytes(&self.in_buf[..]); let num_bytes = (bytes.len() - *offset).min(self.in_right - self.in_left); if num_bytes > 0 { let left = self.in_left % BUF_SIZE; let right = (self.in_left + num_bytes) % BUF_SIZE; if left < right { - bytes[*offset..*offset + num_bytes].copy_from_slice(&self.in_buf[left..right]); + bytes[*offset..*offset + num_bytes].copy_from_slice(&in_buf[left..right]); } else { - bytes[*offset..*offset + (BUF_SIZE - left)].copy_from_slice(&self.in_buf[left..]); + bytes[*offset..*offset + (BUF_SIZE - left)].copy_from_slice(&in_buf[left..]); bytes[*offset + (BUF_SIZE - left)..*offset + num_bytes] - .copy_from_slice(&self.in_buf[..right]); + .copy_from_slice(&in_buf[..right]); } self.in_left += num_bytes; *offset += num_bytes; } if *offset == bytes.len() { - return Ok(false); + return Ok(()); } let left = self.in_left % BUF_SIZE; let right = self.in_right % BUF_SIZE; - let mut iov = if right < left { - [&mut self.in_buf[right..left], &mut []] + let mut iov = SmallVec::<[_; 2]>::new(); + if right < left { + iov.push(self.in_buf.slice(right..left)); } else { - let (l, r) = self.in_buf.split_at_mut(right); - [r, &mut l[..left]] - }; - let mut hdr = uapi::MsghdrMut { - iov: &mut iov[..], - control: Some(&mut self.in_cmsg_buf[..]), - name: uapi::sockaddr_none_mut(), - flags: 0, - }; - let (iov, _, mut cmsg) = match uapi::recvmsg( - self.fd.raw(), - &mut hdr, - c::MSG_DONTWAIT | c::MSG_CMSG_CLOEXEC, - ) { - Ok((iov, _, _)) if iov.is_empty() => return Err(BufFdError::Closed), - Ok(v) => v, - Err(Errno(c::EAGAIN)) => return Ok(true), - Err(e) => return Err(BufFdError::Io(e.into())), - }; - self.in_right += iov.len(); - while cmsg.len() > 0 { - let (_, hdr, data) = match uapi::cmsg_read(&mut cmsg) { - Ok(m) => m, - Err(e) => return Err(BufFdError::Io(e.into())), - }; - if (hdr.cmsg_level, hdr.cmsg_type) == (c::SOL_SOCKET, c::SCM_RIGHTS) { - self.in_fd.extend(uapi::pod_iter(data).unwrap()); - } + iov.push(self.in_buf.slice(right..)); + iov.push(self.in_buf.slice(..left)); + } + match self.ring.recvmsg(&self.fd, &mut iov, &mut self.in_fd).await { + Ok(0) => return Err(BufFdError::Closed), + Ok(n) => self.in_right += n, + Err(e) => return Err(BufFdError::Ring(e.into())), } if self.in_fd.len() > MAX_IN_FD { return Err(BufFdError::TooManyFds); } - Ok(false) + Ok(()) } pub fn get_fd(&mut self) -> Result { diff --git a/src/wheel.rs b/src/wheel.rs index 76f538b3..e1602792 100644 --- a/src/wheel.rs +++ b/src/wheel.rs @@ -167,7 +167,6 @@ impl Wheel { let expiration = (now + Duration::from_millis(ms)).round_to_ms(); let current = self.data.current_expiration.get(); if current.is_none() || expiration - self.data.start < current.unwrap() - self.data.start { - log::info!("programming timer {}", self.data.fd.raw()); let res = uapi::timerfd_settime( self.data.fd.raw(), c::TFD_TIMER_ABSTIME,