use futures::channel::oneshot;
use parking_lot::{RwLock, RwLockWriteGuard};
use sp_runtime::traits::Block as BlockT;
use std::{
collections::{hash_map::Entry, HashMap, HashSet},
sync::Arc,
};
/// Errors produced while managing chain-head subscriptions.
#[derive(Debug)]
pub enum SubscriptionManagementError {
    /// The subscription already pins as many blocks as it is allowed to, so another distinct
    /// block cannot be pinned.
    ExceededLimits,
    /// A free-form error carrying a textual description.
    Custom(String),
}
/// Mutable per-subscription state, shared between all clones of a [`SubscriptionHandle`].
struct SubscriptionInner<Block: BlockT> {
    /// Whether the subscriber asked for runtime updates.
    runtime_updates: bool,
    /// Sending half of the stop channel; taken (and thus `None`) once a stop was signalled.
    tx_stop: Option<oneshot::Sender<()>>,
    /// Hashes of the blocks currently pinned by this subscription.
    blocks: HashSet<Block::Hash>,
    /// Upper bound on the number of entries allowed in `blocks`.
    max_pinned_blocks: usize,
}
/// Cheaply clonable handle to a single subscription.
///
/// Every clone refers to the same underlying state through `Arc`s, so changes made via one
/// handle are observed by all others.
#[derive(Clone)]
pub struct SubscriptionHandle<Block: BlockT> {
    /// Pinned blocks, stop channel and configuration of the subscription.
    inner: Arc<RwLock<SubscriptionInner<Block>>>,
    /// Best block reported for this subscription, guarded by its own lock independent of `inner`.
    best_block: Arc<RwLock<Option<Block::Hash>>>,
}
impl<Block: BlockT> SubscriptionHandle<Block> {
    /// Create the handle for a freshly registered subscription.
    ///
    /// The subscription starts with no pinned blocks and no best block.
    fn new(runtime_updates: bool, tx_stop: oneshot::Sender<()>, max_pinned_blocks: usize) -> Self {
        SubscriptionHandle {
            inner: Arc::new(RwLock::new(SubscriptionInner {
                runtime_updates,
                tx_stop: Some(tx_stop),
                blocks: HashSet::new(),
                max_pinned_blocks,
            })),
            best_block: Arc::new(RwLock::new(None)),
        }
    }

    /// Signal the subscription to stop.
    ///
    /// Only the first call sends on the stop channel; subsequent calls are no-ops because the
    /// sender has already been taken out of the shared state.
    pub fn stop(&self) {
        // Take the sender in its own statement so the write lock is released before sending.
        let tx_stop = self.inner.write().tx_stop.take();
        if let Some(tx_stop) = tx_stop {
            // The receiver may already have been dropped; there is nothing useful to do then.
            let _ = tx_stop.send(());
        }
    }

    /// Pin `hash` for this subscription.
    ///
    /// Returns `Ok(true)` if the block was newly pinned and `Ok(false)` if it was already
    /// pinned. A block that is already pinned is accepted even when the limit is reached.
    ///
    /// # Errors
    ///
    /// Returns [`SubscriptionManagementError::ExceededLimits`] when `hash` is not yet pinned
    /// and the subscription already pins `max_pinned_blocks` blocks.
    pub fn pin_block(&self, hash: Block::Hash) -> Result<bool, SubscriptionManagementError> {
        let mut inner = self.inner.write();

        if inner.blocks.contains(&hash) {
            return Ok(false)
        }

        // `>=` rather than `==`: the limit stays enforced even if the set were ever to grow
        // past it, instead of silently allowing unbounded growth.
        if inner.blocks.len() >= inner.max_pinned_blocks {
            return Err(SubscriptionManagementError::ExceededLimits)
        }

        Ok(inner.blocks.insert(hash))
    }

    /// Unpin `hash`, returning `true` if it was pinned by this subscription.
    pub fn unpin_block(&self, hash: &Block::Hash) -> bool {
        self.inner.write().blocks.remove(hash)
    }

    /// Return `true` if `hash` is currently pinned by this subscription.
    pub fn contains_block(&self, hash: &Block::Hash) -> bool {
        self.inner.read().blocks.contains(hash)
    }

    /// Return `true` if the subscriber asked for runtime updates.
    pub fn has_runtime_updates(&self) -> bool {
        self.inner.read().runtime_updates
    }

    /// Acquire exclusive access to the best block recorded for this subscription.
    ///
    /// The best block lives behind its own lock, separate from the pinned-block state, so
    /// holding this guard does not block `pin_block`/`unpin_block`.
    pub fn best_block_write(&self) -> RwLockWriteGuard<'_, Option<Block::Hash>> {
        self.best_block.write()
    }
}
/// Registry of all active subscriptions, keyed by subscription id.
pub struct SubscriptionManagement<Block: BlockT> {
    /// Subscription id -> handle. Guarded by a read-write lock so lookups can proceed together.
    inner: RwLock<HashMap<String, SubscriptionHandle<Block>>>,
}
impl<Block: BlockT> SubscriptionManagement<Block> {
    /// Create an empty manager with no registered subscriptions.
    pub fn new() -> Self {
        SubscriptionManagement { inner: RwLock::new(HashMap::new()) }
    }

    /// Register a new subscription under `subscription_id`.
    ///
    /// Returns the receiving half of the subscription's stop channel together with a handle to
    /// its state. Returns `None` if `subscription_id` is already registered, in which case the
    /// existing subscription is left untouched.
    pub fn insert_subscription(
        &self,
        subscription_id: String,
        runtime_updates: bool,
        max_pinned_blocks: usize,
    ) -> Option<(oneshot::Receiver<()>, SubscriptionHandle<Block>)> {
        let mut subs = self.inner.write();
        match subs.entry(subscription_id) {
            Entry::Occupied(_) => None,
            Entry::Vacant(entry) => {
                let (tx_stop, rx_stop) = oneshot::channel();
                let handle =
                    SubscriptionHandle::<Block>::new(runtime_updates, tx_stop, max_pinned_blocks);
                entry.insert(handle.clone());
                Some((rx_stop, handle))
            },
        }
    }

    /// Remove the subscription registered under `subscription_id`, if any.
    ///
    /// This only drops the manager's copy of the handle. It does not signal the stop channel,
    /// and handles obtained earlier keep working on the shared state.
    pub fn remove_subscription(&self, subscription_id: &str) {
        self.inner.write().remove(subscription_id);
    }

    /// Look up the subscription registered under `subscription_id`.
    ///
    /// Returns a clone of its handle, which shares the same underlying state, or `None` if no
    /// such subscription exists.
    pub fn get_subscription(&self, subscription_id: &str) -> Option<SubscriptionHandle<Block>> {
        // A lookup only needs shared access; the write lock would needlessly serialise readers.
        self.inner.read().get(subscription_id).cloned()
    }
}

impl<Block: BlockT> Default for SubscriptionManagement<Block> {
    /// Equivalent to [`SubscriptionManagement::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sp_core::H256;
    use substrate_test_runtime_client::runtime::Block;

    #[test]
    fn subscription_check_id() {
        let subs = SubscriptionManagement::<Block>::new();
        let id = "abc".to_string();
        let hash = H256::random();
        let handle = subs.get_subscription(&id);
        assert!(handle.is_none());
        let (_, handle) = subs.insert_subscription(id.clone(), false, 10).unwrap();
        assert!(!handle.contains_block(&hash));
        subs.remove_subscription(&id);
        let handle = subs.get_subscription(&id);
        assert!(handle.is_none());
    }

    #[test]
    fn subscription_check_block() {
        let subs = SubscriptionManagement::<Block>::new();
        let id = "abc".to_string();
        let hash = H256::random();
        let (_, handle) = subs.insert_subscription(id.clone(), false, 10).unwrap();
        assert!(!handle.contains_block(&hash));
        assert!(!handle.unpin_block(&hash));
        handle.pin_block(hash).unwrap();
        assert!(handle.contains_block(&hash));
        assert!(!handle.unpin_block(&H256::random()));
        assert!(handle.unpin_block(&hash));
        assert!(!handle.contains_block(&hash));
    }

    #[test]
    fn subscription_check_shared_state() {
        let subs = SubscriptionManagement::<Block>::new();
        let id = "abc".to_string();
        let hash = H256::random();
        let (_, handle) = subs.insert_subscription(id.clone(), false, 10).unwrap();
        handle.pin_block(hash).unwrap();

        // The handle returned by the manager shares state with the one returned on insertion.
        let looked_up = subs.get_subscription(&id).unwrap();
        assert!(looked_up.contains_block(&hash));
        assert!(looked_up.unpin_block(&hash));
        assert!(!handle.contains_block(&hash));
    }

    #[test]
    fn subscription_check_stop_event() {
        let subs = SubscriptionManagement::<Block>::new();
        let id = "abc".to_string();
        let (mut rx_stop, handle) = subs.insert_subscription(id.clone(), false, 10).unwrap();
        let res = rx_stop.try_recv().unwrap();
        assert!(res.is_none());
        let res = subs.insert_subscription(id.clone(), false, 10);
        assert!(res.is_none());
        handle.stop();
        let res = rx_stop.try_recv().unwrap();
        assert!(res.is_some());
        // Stopping again is a no-op and must not panic.
        handle.stop();
    }

    #[test]
    fn subscription_check_data() {
        let subs = SubscriptionManagement::<Block>::new();
        let id = "abc".to_string();
        let (_, handle) = subs.insert_subscription(id.clone(), false, 10).unwrap();
        assert!(!handle.has_runtime_updates());
        let id2 = "abcd".to_string();
        let (_, handle) = subs.insert_subscription(id2.clone(), true, 10).unwrap();
        assert!(handle.has_runtime_updates());
    }

    #[test]
    fn subscription_check_max_pinned() {
        let subs = SubscriptionManagement::<Block>::new();
        let id = "abc".to_string();
        let hash = H256::random();
        let hash_2 = H256::random();
        let (_, handle) = subs.insert_subscription(id.clone(), false, 1).unwrap();
        // A new block reports `true`; re-pinning the same block reports `false` even at the limit.
        assert!(handle.pin_block(hash).unwrap());
        assert!(!handle.pin_block(hash).unwrap());
        assert!(matches!(
            handle.pin_block(hash_2),
            Err(SubscriptionManagementError::ExceededLimits)
        ));
        // Unpinning releases the slot for another block.
        assert!(handle.unpin_block(&hash));
        assert!(handle.pin_block(hash_2).unwrap());
    }
}