// wry/src/forker/proxy.rs — forked from wry/wry
// (code-hosting page residue: 1 · 0 · Fork 0 · 260 lines · 7.6 KiB · Rust)

use {
crate::{
async_engine::SpawnedFuture,
compositor::LIBEI_SOCKET,
forker::{
ForkerError,
io::{IoIn, IoOut},
protocol::{ForkerMessage, ServerMessage},
worker::Forker,
},
state::State,
utils::{
clone3::double_fork,
copyhashmap::CopyHashMap,
errorfmt::ErrorFmt,
numcell::NumCell,
oserror::OsErrorExt2,
queue::AsyncQueue,
},
xwayland,
},
log::Level,
std::{
cell::{Cell, RefCell},
rc::{Rc, Weak},
task::{Poll, Waker},
},
uapi::{OwnedFd, c},
};
/// Compositor-side handle to the forker helper process.
///
/// Requests to the forker are queued in `outgoing` and written by a background task;
/// replies are read and dispatched by another background task (both started by `install`).
pub struct ForkerProxy {
/// pidfd of the forker process, as returned by `double_fork` on the parent side.
pidfd: Rc<OwnedFd>,
/// Parent end of the socketpair shared with the forker.
socket: Rc<OwnedFd>,
/// Task that reads and dispatches messages coming from the forker.
task_in: Cell<Option<SpawnedFuture<()>>>,
/// Task that writes queued messages to the forker.
task_out: Cell<Option<SpawnedFuture<()>>>,
/// Task that waits for the forker process to go away.
task_proc: Cell<Option<SpawnedFuture<()>>>,
/// Messages waiting to be written by the outgoing task.
outgoing: AsyncQueue<ServerMessage>,
/// Source of ids used to match pidfd replies to the requests that asked for them.
next_id: NumCell<u32>,
/// Requests still waiting for a pidfd reply, keyed by id. Held weakly: the awaiting
/// future owns the handoff, and a reply for a dropped waiter is simply discarded.
pending_pidfds: CopyHashMap<u32, Weak<PidfdHandoff>>,
/// File descriptors to attach to the next message written to the forker.
fds: RefCell<Vec<Rc<OwnedFd>>>,
}
/// One-shot slot through which `ForkerProxy::handle_pidfd` delivers a spawn result to the
/// future awaiting it in `ForkerProxy::pidfd`.
struct PidfdHandoff {
/// The result once it has arrived: the child's pidfd and pid on success, or an error.
pidfd: Cell<Option<Result<(Rc<OwnedFd>, c::pid_t), ForkerError>>>,
/// Waker of the awaiting task, stored each time it is polled while the result is absent.
waiter: Cell<Option<Waker>>,
}
impl ForkerProxy {
    /// Stops all background tasks and discards messages that have not been written yet.
    ///
    /// NOTE(review): waiters registered in `pending_pidfds` are not completed here, so an
    /// `xwayland` call in flight when the forker goes away presumably never resolves.
    pub fn clear(&self) {
        self.task_in.take();
        self.task_out.take();
        self.task_proc.take();
        self.outgoing.clear();
    }

    /// Creates a socketpair and forks the forker process via `double_fork`.
    ///
    /// When `double_fork` yields a pidfd (the parent side), returns a proxy owning that
    /// pidfd and the parent end of the socket; no tasks run until `install` is called.
    /// Otherwise (the child side) the parent end is closed and control passes to
    /// `Forker::handle` with the child end.
    ///
    /// # Errors
    ///
    /// Returns `ForkerError::Socketpair` if the socketpair cannot be created and
    /// propagates any error from `double_fork`.
    pub fn create(reaper_pid: c::pid_t) -> Result<Self, ForkerError> {
        let (parent, child) = uapi::socketpair(c::AF_UNIX, c::SOCK_STREAM | c::SOCK_CLOEXEC, 0)
            .map_os_err(ForkerError::Socketpair)?;
        match double_fork()? {
            Some(pidfd) => Ok(ForkerProxy {
                pidfd: Rc::new(pidfd),
                socket: Rc::new(parent),
                task_in: Cell::new(None),
                task_out: Cell::new(None),
                task_proc: Cell::new(None),
                outgoing: Default::default(),
                next_id: Default::default(),
                pending_pidfds: Default::default(),
                fds: Default::default(),
            }),
            None => {
                drop(parent);
                Forker::handle(reaper_pid, child)
            }
        }
    }

    /// Publishes this proxy as `state.forker` and spawns the three background tasks:
    /// one watching the forker process, one reading from and one writing to its socket.
    pub fn install(self: &Rc<Self>, state: &Rc<State>) {
        state.forker.set(Some(self.clone()));
        self.task_proc.set(Some(state.eng.spawn(
            "forker check process",
            self.clone().check_process(state.clone()),
        )));
        self.task_in.set(Some(
            state
                .eng
                .spawn("forker incoming", self.clone().incoming(state.clone())),
        ));
        self.task_out.set(Some(
            state
                .eng
                .spawn("forker outgoing", self.clone().outgoing(state.clone())),
        ));
    }

    /// Queues a request for the forker to set the environment variable `key` to `val`.
    pub fn setenv(&self, key: &[u8], val: &[u8]) {
        self.outgoing.push(ServerMessage::SetEnv {
            var: key.to_vec(),
            val: Some(val.to_vec()),
        })
    }

    /// Queues a request for the forker to remove the environment variable `key`.
    pub fn unsetenv(&self, key: &[u8]) {
        self.outgoing.push(ServerMessage::SetEnv {
            var: key.to_vec(),
            val: None,
        })
    }

    /// Waits for the forker's pidfd reply to request `id`.
    ///
    /// The handoff is registered on the first poll and filled in by `handle_pidfd`.
    /// NOTE(review): if this future is dropped before the reply arrives, the map entry
    /// stays until a reply with this id is received.
    async fn pidfd(&self, id: u32) -> Result<(Rc<OwnedFd>, c::pid_t), ForkerError> {
        let handoff = Rc::new(PidfdHandoff {
            pidfd: Cell::new(None),
            waiter: Cell::new(None),
        });
        self.pending_pidfds.set(id, Rc::downgrade(&handoff));
        futures_util::future::poll_fn(|ctx| {
            if let Some(pidfd) = handoff.pidfd.take() {
                Poll::Ready(pidfd)
            } else {
                handoff.waiter.set(Some(ctx.waker().clone()));
                Poll::Pending
            }
        })
        .await
    }

    /// Asks the forker to start Xwayland and waits for its pidfd and pid.
    ///
    /// The given descriptors are sent as fds 2 (stderr) through 6, `WAYLAND_SOCKET` is
    /// set to `6` and `LIBEI_SOCKET` is removed from the child's environment.
    ///
    /// # Errors
    ///
    /// Returns `ForkerError::PidfdForkFailed` if the forker reports a failed spawn or its
    /// reply carries no pidfd.
    pub async fn xwayland(
        &self,
        state: &State,
        stderr: Rc<OwnedFd>,
        dfd: Rc<OwnedFd>,
        listenfd: Rc<OwnedFd>,
        wmfd: Rc<OwnedFd>,
        waylandfd: Rc<OwnedFd>,
    ) -> Result<(Rc<OwnedFd>, c::pid_t), ForkerError> {
        let (prog, args) = xwayland::build_args(state, self).await;
        let env = vec![
            ("WAYLAND_SOCKET".to_string(), Some("6".to_string())),
            (LIBEI_SOCKET.to_string(), None),
        ];
        let fds = vec![
            (2, stderr),
            (3, dfd),
            (4, listenfd),
            (5, wmfd),
            (6, waylandfd),
        ];
        let pidfd_id = self.next_id.fetch_add(1);
        self.spawn_(prog, args, env, fds, Some(pidfd_id));
        self.pidfd(pidfd_id).await
    }

    /// Fire-and-forget spawn: asks the forker to run `prog` with `args`, the environment
    /// overrides in `env` (`None` removes a variable) and each `(target, fd)` of `fds`
    /// sent along under descriptor number `target`. No pid or status is reported back.
    pub fn spawn(
        &self,
        prog: String,
        args: Vec<String>,
        env: Vec<(String, Option<String>)>,
        fds: Vec<(i32, Rc<OwnedFd>)>,
    ) {
        self.spawn_(prog, args, env, fds, None)
    }

    /// Shared implementation of `spawn` and `xwayland`.
    ///
    /// The descriptors themselves are parked in `self.fds` for the outgoing task to
    /// attach; the `Spawn` message carries only their target numbers. With
    /// `pidfd_id = Some(id)`, the forker is asked to reply with a pidfd under `id`.
    fn spawn_(
        &self,
        prog: String,
        args: Vec<String>,
        env: Vec<(String, Option<String>)>,
        fds: Vec<(i32, Rc<OwnedFd>)>,
        pidfd_id: Option<u32>,
    ) {
        self.fds
            .borrow_mut()
            .extend(fds.iter().map(|(_, fd)| fd.clone()));
        let fds = fds.into_iter().map(|(target, _)| target).collect();
        self.outgoing.push(ServerMessage::Spawn {
            prog,
            args,
            env,
            fds,
            pidfd_id,
        })
    }

    /// Reads and dispatches messages from the forker until a read fails, then logs the
    /// error and drops this task's own handle.
    async fn incoming(self: Rc<Self>, state: Rc<State>) {
        let mut io = IoIn::new(&self.socket, &state.ring);
        loop {
            let msg = match io.read_msg().await {
                Ok(msg) => msg,
                Err(e) => {
                    log::error!("Could not read from the ol' forker: {}", ErrorFmt(e));
                    self.task_in.take();
                    return;
                }
            };
            self.handle_msg(msg, &mut io);
        }
    }

    /// Dispatches a single message received from the forker.
    fn handle_msg(&self, msg: ForkerMessage, io: &mut IoIn) {
        match msg {
            ForkerMessage::Log { level, msg } => self.handle_log(level, &msg),
            ForkerMessage::PidFd { id, success, pid } => self.handle_pidfd(id, success, io, pid),
        }
    }

    /// Delivers the result of pidfd request `id` to its waiter, if it is still waiting.
    ///
    /// The descriptor is popped whenever the forker reports success, even if nobody is
    /// waiting any more, so it is consumed together with its message. A success reply
    /// without a descriptor (for example one lost in transit) is reported to the waiter
    /// as `ForkerError::PidfdForkFailed` instead of panicking the compositor.
    fn handle_pidfd(&self, id: u32, success: bool, io: &mut IoIn, pid: c::pid_t) {
        let res = if success {
            match io.pop_fd() {
                Some(fd) => Ok((fd, pid)),
                None => {
                    log::error!(
                        "The ol' forker reported a pidfd for request {} but sent no fd",
                        id
                    );
                    Err(ForkerError::PidfdForkFailed)
                }
            }
        } else {
            Err(ForkerError::PidfdForkFailed)
        };
        if let Some(handoff) = self.pending_pidfds.remove(&id)
            && let Some(handoff) = handoff.upgrade()
        {
            handoff.pidfd.set(Some(res));
            if let Some(w) = handoff.waiter.take() {
                w.wake();
            }
        }
    }

    /// Re-emits a log line from the forker. Levels 1-5 map to Error..Trace; any other
    /// value is logged as Error.
    fn handle_log(&self, level: usize, msg: &str) {
        let level = match level {
            1 => Level::Error,
            2 => Level::Warn,
            3 => Level::Info,
            4 => Level::Debug,
            5 => Level::Trace,
            _ => Level::Error,
        };
        log::log!(level, "{}", msg);
    }

    /// Writes queued messages to the forker.
    ///
    /// Every descriptor parked in `self.fds` at the time a message is popped is attached
    /// to that message. On a write error the proxy is torn down and removed from `state`.
    async fn outgoing(self: Rc<Self>, state: Rc<State>) {
        let mut io = IoOut::new(&self.socket, &state.ring);
        loop {
            let msg = self.outgoing.pop().await;
            for fd in self.fds.borrow_mut().drain(..) {
                io.push_fd(fd);
            }
            if let Err(e) = io.write_msg(msg).await {
                log::error!("Could not write to the ol' forker: {}", ErrorFmt(e));
                self.clear();
                state.forker.set(None);
                return;
            }
        }
    }

    /// Waits until the forker's pidfd becomes readable (the process has terminated), or
    /// until waiting fails, then tears the proxy down and removes it from `state`.
    async fn check_process(self: Rc<Self>, state: Rc<State>) {
        if let Err(e) = state.ring.readable(&self.pidfd).await {
            log::error!(
                "Cannot wait for the forker pidfd to become readable: {}",
                ErrorFmt(e)
            );
        }
        log::error!("The ol' forker died. Cannot spawn further processes.");
        self.clear();
        state.forker.set(None);
    }
}