Files
lighthouse/common/task_executor/src/lib.rs
Abhivansh 9b3d7e3a54 refactor: remove service_name (#8606)
Which issue # does this PR address?
#8586


  Please list or describe the changes introduced by this PR.
Remove `service_name` from `TaskExecutor`


Co-Authored-By: Abhivansh <31abhivanshj@gmail.com>
2026-01-02 00:07:40 +00:00

412 lines
15 KiB
Rust

mod metrics;
mod rayon_pool_provider;
pub mod test_utils;
use futures::channel::mpsc::Sender;
use futures::prelude::*;
use std::sync::{Arc, Weak};
use tokio::runtime::{Handle, Runtime};
use tracing::debug;
use crate::rayon_pool_provider::RayonPoolProvider;
pub use crate::rayon_pool_provider::RayonPoolType;
pub use tokio::task::JoinHandle;
/// Explains why Lighthouse is shutting down.
///
/// Each variant carries a static, human-readable description of the cause.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum ShutdownReason {
    /// The node is shutting down without any error having occurred.
    Success(&'static str),
    /// The node is shutting down because of an error condition.
    Failure(&'static str),
}

impl ShutdownReason {
    /// Returns the static description carried by this reason, whichever variant it is.
    pub fn message(&self) -> &'static str {
        let (Self::Success(text) | Self::Failure(text)) = *self;
        text
    }
}
/// Source of a Tokio runtime `Handle`, obtained in one of two ways:
///
/// 1. From a `Weak<Runtime>`, upgrading it and calling `Runtime::handle` on demand.
/// 2. From a `Handle` that is stored directly and cloned on demand.
///
/// Production code has direct access to a `Weak<Runtime>`, while tests (where the `Runtime`
/// itself is outside our scope) can supply a `Handle` instead; this lets `TaskExecutor` work in
/// both settings.
#[derive(Clone)]
pub enum HandleProvider {
    Runtime(Weak<Runtime>),
    Handle(Handle),
}

impl From<Handle> for HandleProvider {
    fn from(handle: Handle) -> Self {
        Self::Handle(handle)
    }
}

impl From<Weak<Runtime>> for HandleProvider {
    fn from(weak_runtime: Weak<Runtime>) -> Self {
        Self::Runtime(weak_runtime)
    }
}

impl HandleProvider {
    /// Produces a `Handle` for the underlying runtime.
    ///
    /// A directly-held `Handle` is always returned (as a clone). For the `Weak<Runtime>` case
    /// this returns `None` once the weak reference can no longer be upgraded, i.e. the `Runtime`
    /// has been dropped (which generally means Lighthouse is shutting down).
    pub fn handle(&self) -> Option<Handle> {
        let runtime = match self {
            Self::Handle(handle) => return Some(handle.clone()),
            Self::Runtime(weak_runtime) => weak_runtime.upgrade()?,
        };
        Some(runtime.handle().clone())
    }
}
/// A cloneable wrapper over a runtime handle which can spawn async and blocking tasks.
///
/// Besides spawning, it carries the executor-wide exit signal, a channel tasks can use to
/// request a shutdown, and a shared provider of rayon thread pools.
#[derive(Clone)]
pub struct TaskExecutor {
    /// Supplies the handle to the runtime on which tasks are spawned.
    handle_provider: HandleProvider,
    /// Exit-signal receiver. Tasks spawned with the exit wrapper are cancelled once this channel
    /// receives `()` or its sender side is dropped.
    exit: async_channel::Receiver<()>,
    /// Sender given to tasks, so that if they encounter a state in which execution cannot
    /// continue they can request that everything shuts down.
    ///
    /// The task must provide a reason for shutting down.
    signal_tx: Sender<ShutdownReason>,
    /// Shared provider of rayon thread pools, looked up by `RayonPoolType`. The `Arc` means all
    /// clones of this executor use the same pools.
    rayon_pool_provider: Arc<RayonPoolProvider>,
}
impl TaskExecutor {
/// Create a new task executor.
///
/// ## Note
///
/// This function should only be used during testing. In production, prefer to obtain an
/// instance of `Self` via a `environment::RuntimeContext` (see the `lighthouse/environment`
/// crate).
pub fn new<T: Into<HandleProvider>>(
handle: T,
exit: async_channel::Receiver<()>,
signal_tx: Sender<ShutdownReason>,
) -> Self {
Self {
handle_provider: handle.into(),
exit,
signal_tx,
rayon_pool_provider: Arc::new(RayonPoolProvider::default()),
}
}
/// A convenience wrapper for `Self::spawn` which ignores a `Result` as long as both `Ok`/`Err`
/// are of type `()`.
///
/// The purpose of this function is to create a compile error if some function which previously
/// returned `()` starts returning something else. Such a case may otherwise result in
/// accidental error suppression.
pub fn spawn_ignoring_error(
&self,
task: impl Future<Output = Result<(), ()>> + Send + 'static,
name: &'static str,
) {
self.spawn(task.map(|_| ()), name)
}
/// Spawn a task to monitor the completion of another task.
///
/// If the other task exits by panicking, then the monitor task will shut down the executor.
fn spawn_monitor<R: Send>(
&self,
task_handle: impl Future<Output = Result<R, tokio::task::JoinError>> + Send + 'static,
name: &'static str,
) {
let mut shutdown_sender = self.shutdown_sender();
if let Some(handle) = self.handle() {
let fut = async move {
let timer = metrics::start_timer_vec(&metrics::TASKS_HISTOGRAM, &[name]);
if let Err(join_error) = task_handle.await
&& let Ok(_panic) = join_error.try_into_panic()
{
let _ =
shutdown_sender.try_send(ShutdownReason::Failure("Panic (fatal error)"));
}
drop(timer);
};
#[cfg(tokio_unstable)]
tokio::task::Builder::new()
.name(&format!("{name}-monitor"))
.spawn_on(fut, &handle)
.expect("Failed to spawn monitor task");
#[cfg(not(tokio_unstable))]
handle.spawn(fut);
} else {
debug!("Couldn't spawn monitor task. Runtime shutting down")
}
}
/// Spawn a future on the tokio runtime.
///
/// The future is wrapped in an `async-channel::Receiver`. The task is cancelled when the corresponding
/// Sender is dropped.
///
/// The future is monitored via another spawned future to ensure that it doesn't panic. In case
/// of a panic, the executor will be shut down via `self.signal_tx`.
///
/// This function generates prometheus metrics on number of tasks and task duration.
pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static, name: &'static str) {
if let Some(task_handle) = self.spawn_handle(task, name) {
self.spawn_monitor(task_handle, name)
}
}
/// Spawn a future on the tokio runtime. This function does not wrap the task in an `async-channel::Receiver`
/// like [spawn](#method.spawn).
/// The caller of this function is responsible for wrapping up the task with an `async-channel::Receiver` to
/// ensure that the task gets cancelled appropriately.
/// This function generates prometheus metrics on number of tasks and task duration.
///
/// This is useful in cases where the future to be spawned needs to do additional cleanup work when
/// the task is completed/canceled (e.g. writing local variables to disk) or the task is created from
/// some framework which does its own cleanup (e.g. a hyper server).
pub fn spawn_without_exit(
&self,
task: impl Future<Output = ()> + Send + 'static,
name: &'static str,
) {
if let Some(int_gauge) = metrics::get_int_gauge(&metrics::ASYNC_TASKS_COUNT, &[name]) {
let int_gauge_1 = int_gauge.clone();
let future = task.then(move |_| {
int_gauge_1.dec();
futures::future::ready(())
});
int_gauge.inc();
if let Some(handle) = self.handle() {
#[cfg(tokio_unstable)]
tokio::task::Builder::new()
.name(name)
.spawn_on(future, &handle)
.expect("Failed to spawn task");
#[cfg(not(tokio_unstable))]
handle.spawn(future);
} else {
debug!("Couldn't spawn task. Runtime shutting down");
}
}
}
/// Spawn a blocking task on a dedicated tokio thread pool wrapped in an exit future.
/// This function generates prometheus metrics on number of tasks and task duration.
pub fn spawn_blocking<F>(&self, task: F, name: &'static str)
where
F: FnOnce() + Send + 'static,
{
if let Some(task_handle) = self.spawn_blocking_handle(task, name) {
self.spawn_monitor(task_handle, name)
}
}
/// Spawns a blocking task on a dedicated tokio thread pool and installs a rayon context within it.
pub fn spawn_blocking_with_rayon<F>(
self,
task: F,
rayon_pool_type: RayonPoolType,
name: &'static str,
) where
F: FnOnce() + Send + 'static,
{
let thread_pool = self.rayon_pool_provider.get_thread_pool(rayon_pool_type);
self.spawn_blocking(
move || {
thread_pool.install(|| {
task();
});
},
name,
)
}
/// Spawns a blocking computation on a rayon thread pool and awaits the result.
pub async fn spawn_blocking_with_rayon_async<F, R>(
&self,
rayon_pool_type: RayonPoolType,
task: F,
) -> Result<R, tokio::sync::oneshot::error::RecvError>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let thread_pool = self.rayon_pool_provider.get_thread_pool(rayon_pool_type);
let (tx, rx) = tokio::sync::oneshot::channel();
thread_pool.spawn(move || {
let result = task();
let _ = tx.send(result);
});
rx.await
}
/// Spawn a future on the tokio runtime wrapped in an `async-channel::Receiver` returning an optional
/// join handle to the future.
/// The task is cancelled when the corresponding async-channel is dropped.
///
/// This function generates prometheus metrics on number of tasks and task duration.
pub fn spawn_handle<R: Send + 'static>(
&self,
task: impl Future<Output = R> + Send + 'static,
name: &'static str,
) -> Option<tokio::task::JoinHandle<Option<R>>> {
let exit = self.exit();
if let Some(int_gauge) = metrics::get_int_gauge(&metrics::ASYNC_TASKS_COUNT, &[name]) {
// Task is shutdown before it completes if `exit` receives
let int_gauge_1 = int_gauge.clone();
int_gauge.inc();
if let Some(handle) = self.handle() {
let fut = async move {
futures::pin_mut!(exit);
let result = match future::select(Box::pin(task), exit).await {
future::Either::Left((value, _)) => Some(value),
future::Either::Right(_) => {
debug!(task = name, "Async task shutdown, exit received");
None
}
};
int_gauge_1.dec();
result
};
#[cfg(tokio_unstable)]
return Some(
tokio::task::Builder::new()
.name(name)
.spawn_on(fut, &handle)
.expect("Failed to spawn task"),
);
#[cfg(not(tokio_unstable))]
Some(handle.spawn(fut))
} else {
debug!("Couldn't spawn task. Runtime shutting down");
None
}
} else {
None
}
}
/// Spawn a blocking task on a dedicated tokio thread pool wrapped in an exit future returning
/// a join handle to the future.
/// If the runtime doesn't exist, this will return None.
/// The Future returned behaves like the standard JoinHandle which can return an error if the
/// task failed.
/// This function generates prometheus metrics on number of tasks and task duration.
pub fn spawn_blocking_handle<F, R>(
&self,
task: F,
name: &'static str,
) -> Option<impl Future<Output = Result<R, tokio::task::JoinError>> + Send + 'static + use<F, R>>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let timer = metrics::start_timer_vec(&metrics::BLOCKING_TASKS_HISTOGRAM, &[name]);
metrics::inc_gauge_vec(&metrics::BLOCKING_TASKS_COUNT, &[name]);
let join_handle = if let Some(handle) = self.handle() {
handle.spawn_blocking(task)
} else {
debug!("Couldn't spawn task. Runtime shutting down");
return None;
};
let future = async move {
let result = match join_handle.await {
Ok(result) => Ok(result),
Err(error) => {
debug!(%error, "Blocking task ended unexpectedly");
Err(error)
}
};
drop(timer);
metrics::dec_gauge_vec(&metrics::BLOCKING_TASKS_COUNT, &[name]);
result
};
Some(future)
}
/// Block the current (non-async) thread on the completion of some future.
///
/// ## Warning
///
/// This method is "dangerous" since calling it from an async thread will result in a panic! Any
/// use of this outside of testing should be very deeply considered as Lighthouse has been
/// burned by this function in the past.
///
/// Determining what is an "async thread" is rather challenging; just because a function isn't
/// marked as `async` doesn't mean it's not being called from an `async` function or there isn't
/// a `tokio` context present in the thread-local storage due to some `rayon` funkiness. Talk to
/// @paulhauner if you plan to use this function in production. He has put metrics in here to
/// track any use of it, so don't think you can pull a sneaky one on him.
pub fn block_on_dangerous<F: Future>(
&self,
future: F,
name: &'static str,
) -> Option<F::Output> {
let timer = metrics::start_timer_vec(&metrics::BLOCK_ON_TASKS_HISTOGRAM, &[name]);
metrics::inc_gauge_vec(&metrics::BLOCK_ON_TASKS_COUNT, &[name]);
let handle = self.handle()?;
let exit = self.exit();
debug!(name, "Starting block_on task");
handle.block_on(async {
let output = tokio::select! {
output = future => {
debug!(
name,
"Completed block_on task"
);
Some(output)
}
_ = exit => {
debug!(
name,
"Cancelled block_on task"
);
None
}
};
metrics::dec_gauge_vec(&metrics::BLOCK_ON_TASKS_COUNT, &[name]);
drop(timer);
output
})
}
/// Returns a `Handle` to the current runtime.
pub fn handle(&self) -> Option<Handle> {
self.handle_provider.handle()
}
/// Returns a future that completes when `async-channel::Sender` is dropped or () is sent,
/// which translates to the exit signal being triggered.
pub fn exit(&self) -> impl Future<Output = ()> + 'static {
let exit = self.exit.clone();
async move {
let _ = exit.recv().await;
}
}
/// Get a channel to request shutting down.
pub fn shutdown_sender(&self) -> Sender<ShutdownReason> {
self.signal_tx.clone()
}
}