use crate::{config::TaskType, Error};
use exit_future::Signal;
use futures::{
future::{pending, select, try_join_all, BoxFuture, Either},
Future, FutureExt, StreamExt,
};
use parking_lot::Mutex;
use prometheus_endpoint::{
exponential_buckets, register, CounterVec, HistogramOpts, HistogramVec, Opts, PrometheusError,
Registry, U64,
};
use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender};
use std::{
collections::{hash_map::Entry, HashMap},
panic,
pin::Pin,
result::Result,
sync::Arc,
};
use tokio::runtime::Handle;
use tracing_futures::Instrument;
mod prometheus_future;
#[cfg(test)]
mod tests;
/// Name of the group that a task is placed in when no explicit group is given.
pub const DEFAULT_GROUP_NAME: &str = "default";

/// The group a spawned task belongs to.
///
/// The resolved group name is used as the `task_group` metrics label and as part of the
/// [`Task`] key recorded in the [`TaskRegistry`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GroupName {
    /// The task is placed in [`DEFAULT_GROUP_NAME`].
    Default,
    /// The task is placed in the given group.
    Specific(&'static str),
}

impl From<Option<&'static str>> for GroupName {
    /// `None` maps to [`GroupName::Default`], `Some(name)` to [`GroupName::Specific`].
    fn from(name: Option<&'static str>) -> Self {
        name.map_or(Self::Default, Self::Specific)
    }
}

impl From<&'static str> for GroupName {
    /// A bare name always selects that specific group.
    fn from(name: &'static str) -> Self {
        Self::Specific(name)
    }
}
/// A cloneable handle for spawning named, grouped futures onto a tokio runtime.
///
/// Every future spawned through this handle is raced against `on_exit` and stops being polled
/// once that resolves. While the future is alive it is counted in the shared [`TaskRegistry`].
/// When metrics are enabled, its polls and the reason it ended are also reported.
#[derive(Clone)]
pub struct SpawnTaskHandle {
    /// Exit future shared with the owning task manager; spawned tasks are raced against it.
    on_exit: exit_future::Exit,
    /// Runtime handle onto which futures are spawned.
    tokio_handle: Handle,
    /// Prometheus metrics, present only when a registry was supplied.
    metrics: Option<Metrics>,
    /// Registry of currently running tasks, shared with the task manager.
    task_registry: TaskRegistry,
}

impl SpawnTaskHandle {
    /// Spawns `task` as an ordinary asynchronous tokio task.
    ///
    /// `name` and `group` label the task in metrics and in the [`TaskRegistry`].
    pub fn spawn(
        &self,
        name: &'static str,
        group: impl Into<GroupName>,
        task: impl Future<Output = ()> + Send + 'static,
    ) {
        self.spawn_inner(name, group, task, TaskType::Async)
    }

    /// Spawns `task` via `Handle::spawn_blocking`, driving it there with `Handle::block_on`.
    ///
    /// `name` and `group` label the task in metrics and in the [`TaskRegistry`].
    pub fn spawn_blocking(
        &self,
        name: &'static str,
        group: impl Into<GroupName>,
        task: impl Future<Output = ()> + Send + 'static,
    ) {
        self.spawn_inner(name, group, task, TaskType::Blocking)
    }

    /// Shared implementation of [`Self::spawn`] and [`Self::spawn_blocking`].
    fn spawn_inner(
        &self,
        name: &'static str,
        group: impl Into<GroupName>,
        task: impl Future<Output = ()> + Send + 'static,
        task_type: TaskType,
    ) {
        let exit = self.on_exit.clone();
        let metrics = self.metrics.clone();
        let registry = self.task_registry.clone();

        let group: &'static str = match group.into() {
            GroupName::Default => DEFAULT_GROUP_NAME,
            GroupName::Specific(specific) => specific,
        };

        if let Some(m) = self.metrics.as_ref() {
            m.tasks_spawned.with_label_values(&[name, group]).inc();
            // Creates the `finished` series with a zero increment — presumably so it is
            // exported before any task actually ends.
            m.tasks_ended.with_label_values(&[name, "finished", group]).inc_by(0);
        }

        let wrapped = async move {
            // Keeps this task counted in the registry for as long as the future is alive.
            let _registry_token = registry.register_task(name, group);

            match metrics {
                None => {
                    futures::pin_mut!(task);
                    let _ = select(exit, task).await;
                },
                Some(metrics) => {
                    let instrumented =
                        panic::AssertUnwindSafe(prometheus_future::with_poll_durations(
                            metrics.poll_duration.with_label_values(&[name, group]),
                            metrics.poll_start.with_label_values(&[name, group]),
                            task,
                        ))
                        .catch_unwind();
                    futures::pin_mut!(instrumented);

                    // Classify how the task ended; a panic payload is re-raised only after
                    // the metric has been recorded.
                    let (reason, panic_payload) = match select(exit, instrumented).await {
                        Either::Left(((), _)) => ("interrupted", None),
                        Either::Right((Ok(()), _)) => ("finished", None),
                        Either::Right((Err(payload), _)) => ("panic", Some(payload)),
                    };
                    metrics.tasks_ended.with_label_values(&[name, reason, group]).inc();
                    if let Some(payload) = panic_payload {
                        panic::resume_unwind(payload);
                    }
                },
            }
        }
        .in_current_span();

        match task_type {
            TaskType::Async => {
                self.tokio_handle.spawn(wrapped);
            },
            TaskType::Blocking => {
                let runtime = self.tokio_handle.clone();
                self.tokio_handle.spawn_blocking(move || runtime.block_on(wrapped));
            },
        }
    }
}
impl sp_core::traits::SpawnNamed for SpawnTaskHandle {
fn spawn_blocking(
&self,
name: &'static str,
group: Option<&'static str>,
future: BoxFuture<'static, ()>,
) {
self.spawn_inner(name, group, future, TaskType::Blocking)
}
fn spawn(
&self,
name: &'static str,
group: Option<&'static str>,
future: BoxFuture<'static, ()>,
) {
self.spawn_inner(name, group, future, TaskType::Async)
}
}
/// A handle for spawning *essential* tasks.
///
/// Essential tasks are not expected to ever complete. Whenever one ends, an error is logged and
/// the essential-task channel is closed. This applies whether the task returned normally or
/// panicked.
///
/// Closing the channel is expected to end the receiver's stream. `TaskManager::future` treats
/// any output of that receiver as `Essential task failed.`.
#[derive(Clone)]
pub struct SpawnEssentialTaskHandle {
    /// Sender side of the essential-task channel; closed when any essential task ends.
    essential_failed_tx: TracingUnboundedSender<()>,
    /// Regular spawn handle that actually runs the wrapped tasks.
    inner: SpawnTaskHandle,
}

impl SpawnEssentialTaskHandle {
    /// Creates a handle that reports essential-task termination through `essential_failed_tx`
    /// and runs tasks via `spawn_task_handle`.
    pub fn new(
        essential_failed_tx: TracingUnboundedSender<()>,
        spawn_task_handle: SpawnTaskHandle,
    ) -> SpawnEssentialTaskHandle {
        SpawnEssentialTaskHandle { essential_failed_tx, inner: spawn_task_handle }
    }

    /// Spawns the essential `task` as an ordinary asynchronous task.
    pub fn spawn(
        &self,
        name: &'static str,
        group: impl Into<GroupName>,
        task: impl Future<Output = ()> + Send + 'static,
    ) {
        self.spawn_inner(name, group, task, TaskType::Async)
    }

    /// Spawns the essential `task` on the blocking thread pool.
    pub fn spawn_blocking(
        &self,
        name: &'static str,
        group: impl Into<GroupName>,
        task: impl Future<Output = ()> + Send + 'static,
    ) {
        self.spawn_inner(name, group, task, TaskType::Blocking)
    }

    /// Wraps `task` so that its termination is reported, then hands it to the inner handle.
    ///
    /// If the task manager exits first, the inner handle drops the wrapped future and nothing
    /// is reported.
    fn spawn_inner(
        &self,
        name: &'static str,
        group: impl Into<GroupName>,
        task: impl Future<Output = ()> + Send + 'static,
        task_type: TaskType,
    ) {
        let essential_failed = self.essential_failed_tx.clone();
        // The panic is caught here. As a result, the inner handle sees the wrapped future
        // complete normally and records the task as `finished` rather than `panic`.
        let essential_task = std::panic::AssertUnwindSafe(task).catch_unwind().map(move |_| {
            log::error!("Essential task `{}` failed. Shutting down service.", name);
            let _ = essential_failed.close_channel();
        });
        // `spawn_inner` returns `()`, so there is nothing to bind or discard.
        self.inner.spawn_inner(name, group, essential_task, task_type);
    }
}
impl sp_core::traits::SpawnEssentialNamed for SpawnEssentialTaskHandle {
fn spawn_essential_blocking(
&self,
name: &'static str,
group: Option<&'static str>,
future: BoxFuture<'static, ()>,
) {
self.spawn_blocking(name, group, future);
}
fn spawn_essential(
&self,
name: &'static str,
group: Option<&'static str>,
future: BoxFuture<'static, ()>,
) {
self.spawn(name, group, future);
}
}
/// Owns the exit signal, metrics, essential-task channel and task registry for a set of tasks.
///
/// It hands out spawn handles that share these resources, and can own child managers.
pub struct TaskManager {
    /// Exit future cloned into every spawn handle; spawned tasks are raced against it.
    on_exit: exit_future::Exit,
    /// Counterpart of `on_exit`, never read directly.
    /// NOTE(review): presumably fires `on_exit` when dropped — confirm with `exit_future`.
    _signal: Signal,
    /// Runtime handle used by all spawn handles.
    tokio_handle: Handle,
    /// Prometheus metrics, present only when a registry was supplied to [`TaskManager::new`].
    metrics: Option<Metrics>,
    /// Sender given to essential-task handles; closed when an essential task ends.
    essential_failed_tx: TracingUnboundedSender<()>,
    /// Receiver polled by [`TaskManager::future`] to detect essential-task termination.
    essential_failed_rx: TracingUnboundedReceiver<()>,
    /// Values kept alive until this manager is dropped (see [`TaskManager::keep_alive`]).
    keep_alive: Box<dyn std::any::Any + Send>,
    /// Child managers whose futures are awaited by [`TaskManager::future`].
    children: Vec<TaskManager>,
    /// Registry of currently running tasks.
    task_registry: TaskRegistry,
}

impl TaskManager {
    /// Creates a manager that spawns onto `tokio_handle`.
    ///
    /// If `prometheus_registry` is given, task metrics are registered in it.
    ///
    /// # Errors
    /// Returns the [`PrometheusError`] raised while registering metrics.
    pub fn new(
        tokio_handle: Handle,
        prometheus_registry: Option<&Registry>,
    ) -> Result<Self, PrometheusError> {
        let (signal, on_exit) = exit_future::signal();
        let (essential_failed_tx, essential_failed_rx) =
            tracing_unbounded("mpsc_essential_tasks", 100);
        let metrics = match prometheus_registry {
            Some(registry) => Some(Metrics::register(registry)?),
            None => None,
        };

        Ok(Self {
            on_exit,
            _signal: signal,
            tokio_handle,
            metrics,
            essential_failed_tx,
            essential_failed_rx,
            keep_alive: Box::new(()),
            children: Vec::new(),
            task_registry: Default::default(),
        })
    }

    /// Returns a handle for spawning regular tasks managed by this manager.
    pub fn spawn_handle(&self) -> SpawnTaskHandle {
        SpawnTaskHandle {
            on_exit: self.on_exit.clone(),
            tokio_handle: self.tokio_handle.clone(),
            metrics: self.metrics.clone(),
            task_registry: self.task_registry.clone(),
        }
    }

    /// Returns a handle for spawning essential tasks.
    ///
    /// The end of any such task makes [`TaskManager::future`] fail.
    pub fn spawn_essential_handle(&self) -> SpawnEssentialTaskHandle {
        let tx = self.essential_failed_tx.clone();
        SpawnEssentialTaskHandle::new(tx, self.spawn_handle())
    }

    /// Returns a future that runs until this manager, or one of its children, stops.
    ///
    /// The future resolves in one of three ways:
    /// - `Ok(())` once the exit future resolves.
    /// - `Err(Error::Other("Essential task failed."))` once the essential-task receiver yields.
    /// - The first error returned by a child manager's future.
    pub fn future<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
        Box::pin(async move {
            let mut essential_failed = self.essential_failed_rx.next().fuse();
            let mut exited = self.on_exit.clone().fuse();
            let child_futures = self.children.iter_mut().map(|child| child.future());
            // The trailing `pending()` keeps the join from ever succeeding, even with no
            // children. This branch can therefore only fire with a child's error.
            let mut child_failed =
                try_join_all(child_futures.chain(std::iter::once(pending().boxed()))).fuse();

            futures::select! {
                _ = essential_failed => Err(Error::Other("Essential task failed.".into())),
                _ = exited => Ok(()),
                res = child_failed => Err(res.map(|_| ()).expect_err("this future never ends; qed")),
            }
        })
    }

    /// Keeps `to_keep_alive` alive until this manager is dropped.
    ///
    /// The value is chained onto everything kept alive previously.
    pub fn keep_alive<T: 'static + Send>(&mut self, to_keep_alive: T) {
        let previous = std::mem::replace(&mut self.keep_alive, Box::new(()));
        self.keep_alive = Box::new((to_keep_alive, previous));
    }

    /// Adds `child`; its future becomes part of [`TaskManager::future`].
    pub fn add_child(&mut self, child: TaskManager) {
        self.children.push(child);
    }

    /// Consumes the manager, returning its task registry and dropping everything else.
    pub fn into_task_registry(self) -> TaskRegistry {
        self.task_registry
    }
}
/// Prometheus metrics reported for tasks spawned through a [`SpawnTaskHandle`].
#[derive(Clone)]
struct Metrics {
    /// Histogram of `Future::poll` durations, labelled by task name and group.
    poll_duration: HistogramVec,
    /// Number of `Future::poll` invocations started, labelled by task name and group.
    poll_start: CounterVec<U64>,
    /// Number of tasks spawned, labelled by task name and group.
    tasks_spawned: CounterVec<U64>,
    /// Number of tasks that ended, labelled by task name, end reason and group.
    tasks_ended: CounterVec<U64>,
}

impl Metrics {
    /// Creates every task metric and registers it in `registry`, in declaration order.
    ///
    /// # Errors
    /// Returns the first [`PrometheusError`] raised while creating or registering a metric.
    fn register(registry: &Registry) -> Result<Self, PrometheusError> {
        const NAME_AND_GROUP: &[&str] = &["task_name", "task_group"];

        let poll_duration_opts = HistogramOpts {
            common_opts: Opts::new(
                "substrate_tasks_polling_duration",
                "Duration in seconds of each invocation of Future::poll",
            ),
            buckets: exponential_buckets(0.001, 4.0, 9)
                .expect("function parameters are constant and always valid; qed"),
        };
        let poll_duration =
            register(HistogramVec::new(poll_duration_opts, NAME_AND_GROUP)?, registry)?;

        let poll_start_opts = Opts::new(
            "substrate_tasks_polling_started_total",
            "Total number of times we started invoking Future::poll",
        );
        let poll_start = register(CounterVec::new(poll_start_opts, NAME_AND_GROUP)?, registry)?;

        let tasks_spawned_opts = Opts::new(
            "substrate_tasks_spawned_total",
            "Total number of tasks that have been spawned on the Service",
        );
        let tasks_spawned =
            register(CounterVec::new(tasks_spawned_opts, NAME_AND_GROUP)?, registry)?;

        let tasks_ended_opts = Opts::new(
            "substrate_tasks_ended_total",
            "Total number of tasks for which Future::poll has returned Ready(()) or panicked",
        );
        let tasks_ended = register(
            CounterVec::new(tasks_ended_opts, &["task_name", "reason", "task_group"])?,
            registry,
        )?;

        Ok(Self { poll_duration, poll_start, tasks_spawned, tasks_ended })
    }
}
/// Guard returned by [`TaskRegistry::register_task`].
///
/// When dropped, it decrements its task's running count and removes the entry once the count
/// reaches zero.
struct UnregisterOnDrop {
    /// The task whose count this guard holds.
    task: Task,
    /// Registry the count lives in.
    registry: TaskRegistry,
}

impl Drop for UnregisterOnDrop {
    fn drop(&mut self) {
        let mut running = self.registry.tasks.lock();
        match running.entry(self.task.clone()) {
            Entry::Occupied(mut slot) => {
                let remaining = slot.get_mut();
                *remaining -= 1;
                if *remaining == 0 {
                    slot.remove();
                }
            },
            // Nothing recorded for this task: leave the map untouched.
            Entry::Vacant(_) => {},
        }
    }
}
/// Identifies a spawned task by its name and group.
///
/// Used as the key of the [`TaskRegistry`] map and returned by
/// [`TaskRegistry::running_tasks`].
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct Task {
    /// Name the task was spawned with.
    pub name: &'static str,
    /// Group the task was spawned in; [`DEFAULT_GROUP_NAME`] when no group was given.
    pub group: &'static str,
}

impl Task {
    /// Returns `true` if the task belongs to the default group ([`DEFAULT_GROUP_NAME`]).
    pub fn is_default_group(&self) -> bool {
        self.group == DEFAULT_GROUP_NAME
    }
}
#[derive(Clone, Default)]
pub struct TaskRegistry {
tasks: Arc<Mutex<HashMap<Task, usize>>>,
}
impl TaskRegistry {
fn register_task(&self, name: &'static str, group: &'static str) -> UnregisterOnDrop {
let task = Task { name, group };
{
let mut tasks = self.tasks.lock();
*(*tasks).entry(task.clone()).or_default() += 1;
}
UnregisterOnDrop { task, registry: self.clone() }
}
pub fn running_tasks(&self) -> HashMap<Task, usize> {
(*self.tasks.lock()).clone()
}
}