diff --git a/src/compositor.rs b/src/compositor.rs index 91f0d8bb..02d09347 100644 --- a/src/compositor.rs +++ b/src/compositor.rs @@ -44,9 +44,9 @@ use { }, user_session::import_environment, utils::{ - clonecell::CloneCell, errorfmt::ErrorFmt, fdcloser::FdCloser, numcell::NumCell, - oserror::OsError, queue::AsyncQueue, refcounted::RefCounted, run_toplevel::RunToplevel, - tri::Try, + clone3::ensure_reaper, clonecell::CloneCell, errorfmt::ErrorFmt, fdcloser::FdCloser, + numcell::NumCell, oserror::OsError, queue::AsyncQueue, refcounted::RefCounted, + run_toplevel::RunToplevel, tri::Try, }, version::VERSION, video::drm::wait_for_sync_obj::WaitForSyncObj, @@ -64,7 +64,8 @@ pub const MAX_EXTENTS: i32 = (1 << 22) - 1; pub fn start_compositor(global: GlobalArgs, args: RunArgs) { sighand::reset_all(); - let forker = create_forker(); + let reaper_pid = ensure_reaper(); + let forker = create_forker(reaper_pid); let portal = portal::run_from_compositor(global.log_level.into()); enable_profiler(); let logger = Logger::install_compositor(global.log_level.into()); @@ -94,8 +95,8 @@ pub fn start_compositor_for_test(future: TestFuture) -> Result<(), CompositorErr res } -fn create_forker() -> Rc { - match ForkerProxy::create() { +fn create_forker(reaper_pid: c::pid_t) -> Rc { + match ForkerProxy::create(reaper_pid) { Ok(f) => Rc::new(f), Err(e) => fatal!("Could not create a forker process: {}", ErrorFmt(e)), } diff --git a/src/forker.rs b/src/forker.rs index 35383b1b..668b865c 100644 --- a/src/forker.rs +++ b/src/forker.rs @@ -9,7 +9,7 @@ use { state::State, utils::{ buffd::BufFdError, - clone3::{Forked, fork_with_pidfd}, + clone3::{Forked, double_fork, fork_with_pidfd}, copyhashmap::CopyHashMap, errorfmt::ErrorFmt, numcell::NumCell, @@ -38,7 +38,6 @@ use { pub struct ForkerProxy { pidfd: Rc, - pid: c::pid_t, socket: Rc, task_in: Cell>>, task_out: Cell>>, @@ -70,6 +69,12 @@ pub enum ForkerError { EncodeFailed(#[source] bincode::Error), #[error("Could not fork")] PidfdForkFailed, + #[error("Could not receive pidfd from child")] + RecvPidfd(#[source] crate::utils::oserror::OsError), + #[error("Could not read cmsg")] + CmsgRead(#[source] crate::utils::oserror::OsError), + #[error("Cmsg has an unexpected form")] + InvalidCmsg, } impl ForkerProxy { @@ -80,17 +85,15 @@ impl ForkerProxy { self.outgoing.clear(); } - pub fn create() -> Result { + pub fn create(reaper_pid: c::pid_t) -> Result { let (parent, child) = match uapi::socketpair(c::AF_UNIX, c::SOCK_STREAM | c::SOCK_CLOEXEC, 0) { Ok(o) => o, Err(e) => return Err(ForkerError::Socketpair(e.into())), }; - let pid = uapi::getpid(); - match fork_with_pidfd(false)? { - Forked::Parent { pid, pidfd } => Ok(ForkerProxy { + match double_fork()? { + Some(pidfd) => Ok(ForkerProxy { pidfd: Rc::new(pidfd), - pid, socket: Rc::new(parent), task_in: Cell::new(None), task_out: Cell::new(None), @@ -100,9 +103,9 @@ impl ForkerProxy { pending_pidfds: Default::default(), fds: Default::default(), }), - Forked::Child { .. } => { + None => { drop(parent); - Forker::handle(pid, child) + Forker::handle(reaper_pid, child) } } } @@ -284,8 +287,6 @@ impl ForkerProxy { "Cannot wait for the forker pidfd to become readable: {}", ErrorFmt(e) ); - } else { - let _ = uapi::waitpid(self.pid, 0); } log::error!("The ol' forker died. Cannot spawn further processes."); self.clear(); diff --git a/src/utils/clone3.rs b/src/utils/clone3.rs index a861134b..2d2e0feb 100644 --- a/src/utils/clone3.rs +++ b/src/utils/clone3.rs @@ -1,6 +1,10 @@ use { - crate::forker::ForkerError, - uapi::{OwnedFd, c}, + crate::{ + forker::ForkerError, + utils::{errorfmt::ErrorFmt, on_drop::OnDrop, process_name::set_process_name}, + }, + std::{env, mem::MaybeUninit, process, slice, str::FromStr}, + uapi::{Msghdr, MsghdrMut, OwnedFd, c}, }; #[derive(Default, Copy, Clone)] @@ -21,9 +25,11 @@ struct clone_args { pub enum Forked { Parent { pid: c::pid_t, pidfd: OwnedFd }, - Child { _pidfd: Option }, + Child { pidfd: Option }, } +const REAPER_VAR: &str = "JAY_REAPER_PID"; + pub fn fork_with_pidfd(pidfd_for_child: bool) -> Result { let mut pidfd: c::c_int = 0; let mut args = clone_args { @@ -46,9 +52,8 @@ pub fn fork_with_pidfd(pidfd_for_child: bool) -> Result { return Err(ForkerError::Fork(e.into())); } let res = if pid == 0 { - Forked::Child { - _pidfd: child_pidfd, - } + env::remove_var(REAPER_VAR); + Forked::Child { pidfd: child_pidfd } } else { Forked::Parent { pid: pid as _, @@ -58,3 +63,109 @@ pub fn fork_with_pidfd(pidfd_for_child: bool) -> Result { Ok(res) } } + +pub fn double_fork() -> Result, ForkerError> { + let (p, c) = uapi::socketpair(c::AF_UNIX, c::SOCK_DGRAM | c::SOCK_CLOEXEC, 0) + .map_err(|e| ForkerError::Socketpair(e.into()))?; + match fork_with_pidfd(false)? { + Forked::Parent { pid, .. } => { + drop(c); + let mut buf = [MaybeUninit::::uninit(); 128]; + let iov: &mut [&mut [u8]] = &mut []; + let mut msghdr = MsghdrMut { + iov, + control: Some(&mut buf), + name: uapi::sockaddr_none_mut(), + flags: 0, + }; + let _wait = OnDrop(|| { + let _ = uapi::waitpid(pid, 0); + }); + let (_, _, mut ctrl) = uapi::recvmsg(p.raw(), &mut msghdr, c::MSG_CMSG_CLOEXEC) + .map_err(|e| ForkerError::RecvPidfd(e.into()))?; + let (_, hdr, data) = + uapi::cmsg_read(&mut ctrl).map_err(|e| ForkerError::CmsgRead(e.into()))?; + if hdr.cmsg_level != c::SOL_SOCKET || hdr.cmsg_type != c::SCM_RIGHTS { + return Err(ForkerError::InvalidCmsg); + } + let Ok(fd) = uapi::pod_read(data) else { + return Err(ForkerError::InvalidCmsg); + }; + Ok(Some(OwnedFd::new(fd))) + } + Forked::Child { .. } => { + drop(p); + if let Ok(f) = fork_with_pidfd(true) { + match f { + Forked::Parent { pidfd, .. } => { + let pidfd = pidfd.raw(); + let mut buf = [MaybeUninit::uninit(); 128]; + let hdr = c::cmsghdr { + cmsg_len: 0, + cmsg_level: c::SOL_SOCKET, + cmsg_type: c::SCM_RIGHTS, + }; + let _ = uapi::cmsg_write(&mut &mut buf[..], hdr, &pidfd); + let iov: &[&[u8]] = &[]; + let msghdr = Msghdr { + iov, + control: Some(&buf[..uapi::cmsg_space(size_of_val(&pidfd))]), + name: uapi::sockaddr_none_ref(), + }; + let _ = uapi::sendmsg(c.raw(), &msghdr, 0); + } + Forked::Child { pidfd } => { + let pidfd = pidfd.unwrap(); + let mut pollfd = c::pollfd { + fd: pidfd.raw(), + events: c::POLLIN as _, + revents: 0, + }; + let _ = uapi::poll(slice::from_mut(&mut pollfd), -1); + return Ok(None); + } + } + }; + unsafe { + c::_exit(0); + } + } + } +} + +pub fn ensure_reaper() -> c::pid_t { + if let Ok(id) = env::var(REAPER_VAR) { + if let Ok(id) = c::pid_t::from_str(&id) { + if uapi::getppid() == id { + return id; + } + } + } + let reaper_pid = uapi::getpid(); + unsafe { + c::prctl(c::PR_SET_CHILD_SUBREAPER, 1); + } + let res = match fork_with_pidfd(false) { + Ok(r) => r, + Err(e) => { + fatal!("Could not fork reaper: {}", ErrorFmt(e)); + } + }; + let Forked::Parent { + pid: main_process_id, + .. + } = res + else { + unsafe { + env::set_var(REAPER_VAR, reaper_pid.to_string()); + } + return reaper_pid; + }; + set_process_name("jay reaper"); + while let Ok((pid, status)) = uapi::wait() { + if pid == main_process_id { + process::exit(uapi::WEXITSTATUS(status)); + } + } + process::exit(1); +}