pub mod security;
use crate::LOG_TARGET;
use cpu_time::ProcessTime;
use futures::never::Never;
use std::{
any::Any,
path::PathBuf,
sync::mpsc::{Receiver, RecvTimeoutError},
time::Duration,
};
use tokio::{io, net::UnixStream, runtime::Runtime};
/// Declares `print_help` and `main` for a PVF worker binary.
///
/// The generated `main`:
/// - prints help when run without arguments or with `--help`/`-h`;
/// - prints `$worker_version` for `--version`/`-v`;
/// - panics if the first argument is not `$expected_command`;
/// - parses `--socket-path <path>` and `--node-impl-version <version>` flag/value pairs and
///   calls `$entrypoint(&socket_path, node_version, Some($worker_version))`.
///
/// # Panics
///
/// The generated `main` panics on an unexpected subcommand or argument, or when a known flag
/// is missing its value.
#[macro_export]
macro_rules! decl_worker_main {
    ($expected_command:expr, $entrypoint:expr, $worker_version:expr) => {
        fn print_help(expected_command: &str) {
            println!("{} {}", expected_command, $worker_version);
            println!();
            println!("PVF worker that is called by polkadot.");
        }

        fn main() {
            $crate::sp_tracing::try_init_simple();

            let args = std::env::args().collect::<Vec<_>>();
            if args.len() == 1 {
                print_help($expected_command);
                return
            }

            match args[1].as_ref() {
                "--help" | "-h" => {
                    print_help($expected_command);
                    return
                },
                "--version" | "-v" => {
                    println!("{}", $worker_version);
                    return
                },
                subcommand => {
                    if subcommand != $expected_command {
                        panic!(
                            "trying to run {} binary with the {} subcommand",
                            $expected_command, subcommand
                        )
                    }
                },
            }

            let mut node_version = None;
            let mut socket_path: &str = "";

            // Remaining arguments come in `--flag value` pairs.
            for i in (2..args.len()).step_by(2) {
                let flag = args[i].as_str();
                // `get` instead of indexing: a trailing flag without a value should produce a
                // clear panic message rather than an index-out-of-bounds panic.
                let maybe_value = args.get(i + 1).map(|v| v.as_str());
                match (flag, maybe_value) {
                    ("--socket-path", Some(value)) => socket_path = value,
                    ("--node-impl-version", Some(value)) => node_version = Some(value),
                    ("--socket-path" | "--node-impl-version", None) => {
                        panic!("Missing value for argument: {}", flag)
                    },
                    (arg, _) => panic!("Unexpected argument found: {}", arg),
                }
            }

            $entrypoint(&socket_path, node_version, Some($worker_version));
        }
    };
}
/// Extra slack added on top of the remaining CPU-time budget each time
/// `cpu_time_monitor_loop` waits for the job to signal completion.
pub const JOB_TIMEOUT_OVERHEAD: Duration = Duration::from_millis(50);
/// Interprets `bytes` as a UTF-8 string and converts it into a [`PathBuf`].
///
/// Returns `None` when `bytes` is not valid UTF-8.
pub fn bytes_to_path(bytes: &[u8]) -> Option<PathBuf> {
    match std::str::from_utf8(bytes) {
        Ok(text) => Some(PathBuf::from(text)),
        Err(_) => None,
    }
}
/// Runs the main loop of a PVF worker process.
///
/// - If both `node_version` and `worker_version` are given and they differ, logs an error,
///   sends `SIGTERM` to the parent process (see `kill_parent_node_in_emergency`) and returns
///   without connecting.
/// - Otherwise removes all environment variables except `RUST_LOG`, connects to the Unix
///   socket at `socket_path`, removes the socket file and hands the stream to `event_loop`.
///
/// `event_loop` can only complete with an error (its success type is `Never`). That error, or
/// a failure to connect, is logged and the function returns.
///
/// `debug_id` identifies the worker kind in log output.
///
/// # Panics
///
/// Panics if the tokio runtime cannot be created.
pub fn worker_event_loop<F, Fut>(
    debug_id: &'static str,
    socket_path: &str,
    node_version: Option<&str>,
    worker_version: Option<&str>,
    mut event_loop: F,
) where
    F: FnMut(UnixStream) -> Fut,
    Fut: futures::Future<Output = io::Result<Never>>,
{
    let worker_pid = std::process::id();
    gum::debug!(target: LOG_TARGET, %worker_pid, "starting pvf worker ({})", debug_id);

    if let (Some(node_version), Some(worker_version)) = (node_version, worker_version) {
        if node_version != worker_version {
            gum::error!(
                target: LOG_TARGET,
                %worker_pid,
                %node_version,
                %worker_version,
                "Node and worker version mismatch, node needs restarting, forcing shutdown",
            );
            kill_parent_node_in_emergency();
            let err = io::Error::new(io::ErrorKind::Unsupported, "Version mismatch");
            worker_shutdown_message(debug_id, worker_pid, err);
            return
        }
    }

    remove_env_vars(debug_id);

    let rt = Runtime::new().expect("Creates tokio runtime. If this panics the worker will die and the host will detect that and deal with it.");
    let err = rt
        .block_on(async move {
            let stream = UnixStream::connect(socket_path).await?;
            // Best effort: once connected the socket file is no longer needed, and a failure to
            // remove it is intentionally ignored.
            let _ = tokio::fs::remove_file(socket_path).await;
            event_loop(stream).await
        })
        // The success type is `Never`, so the future can only have resolved to an error.
        .unwrap_err();

    worker_shutdown_message(debug_id, worker_pid, err);
    // Shut the runtime down without waiting for any tasks that may still be running.
    rt.shutdown_background();
}
/// Removes every environment variable from the current process except `RUST_LOG`.
///
/// Variables whose name or value is malformed (see `env_var_format_problems`) are logged with
/// a warning before removal, because `std::env::remove_var` may panic on such input and bring
/// the worker down.
///
/// `debug_id` identifies the worker kind in log output.
fn remove_env_vars(debug_id: &'static str) {
    for (key, value) in std::env::vars_os() {
        // `RUST_LOG` is deliberately kept (presumably so log filtering keeps working).
        if key == "RUST_LOG" {
            continue
        }

        let err_reasons = env_var_format_problems(&key, &value);
        if !err_reasons.is_empty() {
            gum::warn!(
                target: LOG_TARGET,
                %debug_id,
                ?key,
                ?value,
                "Attempting to remove badly-formatted env var, this may cause the PVF worker to crash. Please remove it yourself. Reasons: {:?}",
                err_reasons
            );
        }

        std::env::remove_var(key);
    }
}

/// Returns the reasons why the environment variable `key`/`value` pair is malformed, or an
/// empty list if it is well-formed.
fn env_var_format_problems(key: &std::ffi::OsStr, value: &std::ffi::OsStr) -> Vec<&'static str> {
    // A lossy conversion still preserves the ASCII characters '=' and '\0', because they can
    // never be part of a multi-byte UTF-8 sequence. Names and values that are not valid UTF-8
    // are therefore checked too. A strict `to_str()` would skip the checks for them entirely.
    let key_str = key.to_string_lossy();
    let value_str = value.to_string_lossy();

    let mut reasons = Vec::new();
    if key.is_empty() {
        reasons.push("key is empty");
    }
    if key_str.contains('=') {
        reasons.push("key contains '='");
    }
    if key_str.contains('\0') {
        reasons.push("key contains null character");
    }
    if value_str.contains('\0') {
        reasons.push("value contains null character");
    }
    reasons
}
/// Logs at debug level that the worker (`debug_id`, process `worker_pid`) is quitting,
/// including the error `err` that caused it.
fn worker_shutdown_message(debug_id: &'static str, worker_pid: u32, err: io::Error) {
    gum::debug!(target: LOG_TARGET, %worker_pid, "quitting pvf worker ({}): {:?}", debug_id, err);
}
/// Waits until the job signals completion on `finished_rx` or the CPU time measured since
/// `cpu_time_start` exceeds `timeout`, whichever happens first.
///
/// Returns `None` if a message arrives, or if the sender is dropped, before the timeout is
/// exceeded. Returns `Some(elapsed)` with the measured CPU time once it exceeds `timeout`.
pub fn cpu_time_monitor_loop(
    cpu_time_start: ProcessTime,
    timeout: Duration,
    finished_rx: Receiver<()>,
) -> Option<Duration> {
    loop {
        let cpu_time_elapsed = cpu_time_start.elapsed();
        if cpu_time_elapsed > timeout {
            return Some(cpu_time_elapsed)
        }

        // Wait for at least the remaining budget plus some slack. `saturating_add` avoids an
        // overflow panic when `timeout` is close to `Duration::MAX`.
        let sleep_interval =
            timeout.saturating_sub(cpu_time_elapsed).saturating_add(JOB_TIMEOUT_OVERHEAD);
        match finished_rx.recv_timeout(sleep_interval) {
            // The wait is in wall-clock time, but the budget is CPU time, which may not have
            // been used up yet. Re-check the budget.
            Err(RecvTimeoutError::Timeout) => continue,
            Ok(()) | Err(RecvTimeoutError::Disconnected) => return None,
        }
    }
}
/// Converts a panic payload (as returned by `catch_unwind` or `JoinHandle::join`) into a
/// human-readable message.
///
/// Payloads of type `&'static str` and `String` yield their text. Any other payload type
/// yields `"unknown panic payload"`.
pub fn stringify_panic_payload(payload: Box<dyn Any + Send + 'static>) -> String {
    if let Some(text) = payload.downcast_ref::<&'static str>() {
        return (*text).to_owned()
    }
    payload
        .downcast::<String>()
        .map(|owned| *owned)
        .unwrap_or_else(|_| "unknown panic payload".to_string())
}
/// Sends `SIGTERM` to this process's parent, if the parent PID is greater than 1.
///
/// This is best effort: the result of `kill` is not checked.
fn kill_parent_node_in_emergency() {
    // SAFETY: `getppid` takes no arguments, touches no caller-provided memory and always
    // succeeds.
    let ppid = unsafe { libc::getppid() };
    // PIDs 0 and 1 are never signalled. A parent PID of 1 typically means the original
    // parent is gone and the process was re-parented to init.
    if ppid > 1 {
        // SAFETY: `kill` has no memory-safety preconditions. It only takes a PID and a signal
        // number and reports failure through its return value, which is deliberately ignored.
        unsafe {
            libc::kill(ppid, libc::SIGTERM);
        }
    }
}
/// Helpers for running work on dedicated threads and waiting for the first of them to finish.
///
/// Threads spawned with [`spawn_worker_thread`] or [`spawn_worker_thread_with_stack_size`]
/// share a [`Cond`]. When a thread's closure returns or panics, the thread records its
/// [`WaitOutcome`] in the shared state (only if nothing was recorded yet) and wakes all
/// waiters.
pub mod thread {
    use std::{
        panic,
        sync::{Arc, Condvar, Mutex},
        thread,
        time::Duration,
    };

    /// The outcome recorded in a [`Cond`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum WaitOutcome {
        /// Outcome a caller may attach to a spawned thread (by its name, presumably the thread
        /// doing the actual work).
        Finished,
        /// Outcome a caller may attach to a spawned thread (by its name, presumably a
        /// timeout-monitoring thread).
        TimedOut,
        /// No outcome has been recorded yet.
        Pending,
    }

    impl WaitOutcome {
        /// Returns `true` if no outcome has been recorded yet.
        pub fn is_pending(&self) -> bool {
            matches!(self, Self::Pending)
        }
    }

    /// Shared state: the recorded outcome plus the condition variable used to announce it.
    pub type Cond = Arc<(Mutex<WaitOutcome>, Condvar)>;

    /// Creates a fresh [`Cond`] in the [`WaitOutcome::Pending`] state.
    pub fn get_condvar() -> Cond {
        Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()))
    }

    /// Spawns a thread named `name` that runs `f`.
    ///
    /// When `f` returns or panics, `outcome` is recorded in `cond` (unless another outcome was
    /// recorded first) and all waiters on `cond` are woken. A panic in `f` still propagates to
    /// the returned [`thread::JoinHandle`].
    ///
    /// # Errors
    ///
    /// Returns the OS error if the thread could not be spawned.
    pub fn spawn_worker_thread<F, R>(
        name: &str,
        f: F,
        cond: Cond,
        outcome: WaitOutcome,
    ) -> std::io::Result<thread::JoinHandle<R>>
    where
        F: FnOnce() -> R,
        F: Send + 'static + panic::UnwindSafe,
        R: Send + 'static,
    {
        spawn_notifying(thread::Builder::new().name(name.into()), f, cond, outcome)
    }

    /// Like [`spawn_worker_thread`], but the thread is created with a stack of `stack_size`
    /// bytes.
    ///
    /// # Errors
    ///
    /// Returns the OS error if the thread could not be spawned.
    pub fn spawn_worker_thread_with_stack_size<F, R>(
        name: &str,
        f: F,
        cond: Cond,
        outcome: WaitOutcome,
        stack_size: usize,
    ) -> std::io::Result<thread::JoinHandle<R>>
    where
        F: FnOnce() -> R,
        F: Send + 'static + panic::UnwindSafe,
        R: Send + 'static,
    {
        spawn_notifying(
            thread::Builder::new().name(name.into()).stack_size(stack_size),
            f,
            cond,
            outcome,
        )
    }

    /// Spawns `f` on `builder`, recording `outcome` in `cond` once `f` returns or panics.
    fn spawn_notifying<F, R>(
        builder: thread::Builder,
        f: F,
        cond: Cond,
        outcome: WaitOutcome,
    ) -> std::io::Result<thread::JoinHandle<R>>
    where
        F: FnOnce() -> R,
        F: Send + 'static + panic::UnwindSafe,
        R: Send + 'static,
    {
        builder.spawn(move || cond_notify_on_done(f, cond, outcome))
    }

    /// Runs `f`, then records `outcome` in `cond` whether `f` returned or panicked.
    ///
    /// A panic from `f` is re-raised after notifying, so callers still observe it.
    fn cond_notify_on_done<F, R>(f: F, cond: Cond, outcome: WaitOutcome) -> R
    where
        F: FnOnce() -> R,
        F: panic::UnwindSafe,
    {
        let result = panic::catch_unwind(f);
        cond_notify_all(cond, outcome);
        match result {
            Ok(inner) => inner,
            Err(err) => panic::resume_unwind(err),
        }
    }

    /// Records `outcome` in `cond` and wakes all waiters, unless an outcome is already set.
    /// The first recorded outcome wins.
    fn cond_notify_all(cond: Cond, outcome: WaitOutcome) {
        let (lock, cvar) = &*cond;
        let mut flag = lock.lock().unwrap();
        if !flag.is_pending() {
            // Someone else already recorded an outcome; keep theirs.
            return
        }
        *flag = outcome;
        cvar.notify_all();
    }

    /// Blocks until an outcome has been recorded in `cond`, then returns it.
    pub fn wait_for_threads(cond: Cond) -> WaitOutcome {
        let (lock, cvar) = &*cond;
        let guard = cvar.wait_while(lock.lock().unwrap(), |flag| flag.is_pending()).unwrap();
        *guard
    }

    /// Blocks for at most `dur` waiting for an outcome to be recorded in `cond`.
    ///
    /// Returns `None` if the wait timed out while the state was still pending, otherwise the
    /// recorded outcome.
    #[cfg_attr(not(any(target_os = "linux", feature = "jemalloc-allocator")), allow(dead_code))]
    pub fn wait_for_threads_with_timeout(cond: &Cond, dur: Duration) -> Option<WaitOutcome> {
        let (lock, cvar) = &**cond;
        let (guard, timeout_result) = cvar
            .wait_timeout_while(lock.lock().unwrap(), dur, |flag| flag.is_pending())
            .unwrap();
        if timeout_result.timed_out() {
            None
        } else {
            Some(*guard)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn get_condvar_should_be_pending() {
            let condvar = get_condvar();
            let outcome = *condvar.0.lock().unwrap();
            assert!(outcome.is_pending());
        }

        #[test]
        fn wait_for_threads_with_timeout_return_none_on_time_out() {
            let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
            let outcome = wait_for_threads_with_timeout(&condvar, Duration::from_millis(100));
            assert!(outcome.is_none());
        }

        #[test]
        fn wait_for_threads_with_timeout_returns_outcome() {
            let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
            let condvar2 = condvar.clone();
            cond_notify_all(condvar2, WaitOutcome::Finished);
            let outcome = wait_for_threads_with_timeout(&condvar, Duration::from_secs(2));
            assert_eq!(outcome, Some(WaitOutcome::Finished));
        }

        #[test]
        fn spawn_worker_thread_should_notify_on_done() {
            let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
            let response =
                spawn_worker_thread("thread", || 2, condvar.clone(), WaitOutcome::TimedOut);
            let (lock, _) = &*condvar;
            let r = response.unwrap().join().unwrap();
            assert_eq!(r, 2);
            assert_eq!(*lock.lock().unwrap(), WaitOutcome::TimedOut);
        }

        #[test]
        fn spawn_worker_should_not_change_finished_outcome() {
            let condvar = Arc::new((Mutex::new(WaitOutcome::Finished), Condvar::new()));
            let response =
                spawn_worker_thread("thread", move || 2, condvar.clone(), WaitOutcome::TimedOut);
            let r = response.unwrap().join().unwrap();
            assert_eq!(r, 2);
            assert_eq!(*condvar.0.lock().unwrap(), WaitOutcome::Finished);
        }

        #[test]
        fn cond_notify_on_done_should_update_wait_outcome_when_panic() {
            let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
            let err = panic::catch_unwind(panic::AssertUnwindSafe(|| {
                cond_notify_on_done(|| panic!("test"), condvar.clone(), WaitOutcome::Finished)
            }));
            assert_eq!(*condvar.0.lock().unwrap(), WaitOutcome::Finished);
            assert!(err.is_err());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc::channel;

    /// With a zero budget, the monitor should report the measured CPU time rather than `None`.
    #[test]
    fn cpu_time_monitor_loop_should_return_time_elapsed() {
        let started = ProcessTime::now();
        let budget = Duration::from_secs(0);
        // The sender must stay alive; otherwise the monitor would see a disconnect.
        let (_sender, receiver) = channel();
        let outcome = cpu_time_monitor_loop(started, budget, receiver);
        assert!(outcome.is_some());
    }

    /// A completion message sent up front should make the monitor return `None`.
    #[test]
    fn cpu_time_monitor_loop_should_return_none() {
        let started = ProcessTime::now();
        let budget = Duration::from_secs(10);
        let (sender, receiver) = channel();
        sender.send(()).unwrap();
        let outcome = cpu_time_monitor_loop(started, budget, receiver);
        assert!(outcome.is_none());
    }
}