use crate::PeerId;
use polkadot_primitives::v2::{AuthorityDiscoveryId, SessionIndex, ValidatorIndex};
use rand::{CryptoRng, Rng};
use std::{
collections::{hash_map, HashMap, HashSet},
fmt::Debug,
};
/// Log target used by this module's `gum` logging.
const LOG_TARGET: &str = "parachain::grid-topology";
/// Default `sample_rate` of a [`RandomRouting`]: the numerator of the per-peer selection
/// ratio `sample_rate / n_peers_total` used by `RandomRouting::sample`.
pub const DEFAULT_RANDOM_SAMPLE_RATE: usize = crate::MIN_GOSSIP_PEERS;
/// Default `target` of a [`RandomRouting`]: the number of random sends after which
/// `RandomRouting::sample` always returns `false`.
pub const DEFAULT_RANDOM_CIRCULATION: usize = 4;
/// Information about one validator placed in the grid topology.
#[derive(Debug, Clone, PartialEq)]
pub struct TopologyPeerInfo {
/// Network peer IDs associated with this validator. When the validator is a grid
/// neighbor, every one of these IDs is added to the corresponding neighbor peer set.
pub peer_ids: Vec<PeerId>,
/// Index of the validator (presumably within the session's validator set — confirm).
pub validator_index: ValidatorIndex,
/// The validator's authority discovery ID.
pub discovery_id: AuthorityDiscoveryId,
}
#[derive(Default, Clone, Debug, PartialEq)]
pub struct SessionGridTopology {
shuffled_indices: Vec<usize>,
canonical_shuffling: Vec<TopologyPeerInfo>,
}
impl SessionGridTopology {
pub fn new(shuffled_indices: Vec<usize>, canonical_shuffling: Vec<TopologyPeerInfo>) -> Self {
SessionGridTopology { shuffled_indices, canonical_shuffling }
}
pub fn compute_grid_neighbors_for(&self, v: ValidatorIndex) -> Option<GridNeighbors> {
if self.shuffled_indices.len() != self.canonical_shuffling.len() {
return None
}
let shuffled_val_index = *self.shuffled_indices.get(v.0 as usize)?;
let neighbors = matrix_neighbors(shuffled_val_index, self.shuffled_indices.len())?;
let mut grid_subset = GridNeighbors::empty();
for r_n in neighbors.row_neighbors {
let n = &self.canonical_shuffling[r_n];
grid_subset.validator_indices_x.insert(n.validator_index);
for p in &n.peer_ids {
grid_subset.peers_x.insert(*p);
}
}
for c_n in neighbors.column_neighbors {
let n = &self.canonical_shuffling[c_n];
grid_subset.validator_indices_y.insert(n.validator_index);
for p in &n.peer_ids {
grid_subset.peers_y.insert(*p);
}
}
Some(grid_subset)
}
}
/// The other positions sharing a row / a column with some position of the grid matrix.
struct MatrixNeighbors<R, C> {
	/// Positions in the same row, ascending, excluding the queried one.
	row_neighbors: R,
	/// Positions in the same column, ascending, excluding the queried one.
	column_neighbors: C,
}

/// Computes the row and column neighbors of `val_index` in a grid of `len` positions.
///
/// Positions are laid out row-major in a matrix of width `floor(sqrt(len))`; the last row
/// may be partial. `val_index` itself never appears in either iterator. Returns `None`
/// when `val_index >= len` (which also covers `len == 0`).
fn matrix_neighbors(
	val_index: usize,
	len: usize,
) -> Option<MatrixNeighbors<impl Iterator<Item = usize>, impl Iterator<Item = usize>>> {
	if val_index >= len {
		return None
	}

	// `len >= 1` here, so the width is at least 1 and `step_by` never sees a zero step.
	let width = (len as f64).sqrt() as usize;
	let (row, column) = (val_index / width, val_index % width);
	let row_start = row * width;
	let row_end = len.min(row_start + width);

	// Captures only a `usize`, so the closure is `Copy` and can be reused for both filters.
	let not_self = move |pos: &usize| *pos != val_index;

	Some(MatrixNeighbors {
		row_neighbors: (row_start..row_end).filter(not_self),
		column_neighbors: (column..len).step_by(width).filter(not_self),
	})
}
/// The grid neighbors of one validator, split into two dimensions.
///
/// As filled in by `SessionGridTopology::compute_grid_neighbors_for`, the `*_x` sets hold
/// the validators sharing our matrix row and the `*_y` sets those sharing our column.
#[derive(Debug, Clone, PartialEq)]
pub struct GridNeighbors {
	/// Peer IDs of the X-dimension neighbors.
	pub peers_x: HashSet<PeerId>,
	/// Validator indices of the X-dimension neighbors.
	pub validator_indices_x: HashSet<ValidatorIndex>,
	/// Peer IDs of the Y-dimension neighbors.
	pub peers_y: HashSet<PeerId>,
	/// Validator indices of the Y-dimension neighbors.
	pub validator_indices_y: HashSet<ValidatorIndex>,
}

impl GridNeighbors {
	/// Creates a value with no peers and no validators in either dimension.
	pub fn empty() -> Self {
		GridNeighbors {
			peers_x: HashSet::new(),
			validator_indices_x: HashSet::new(),
			peers_y: HashSet::new(),
			validator_indices_y: HashSet::new(),
		}
	}

	/// Returns the routing required for a message originated by validator `originator`.
	///
	/// Locally originated messages (`local == true`) go to both dimensions. Otherwise a
	/// message from an X neighbor is forwarded along Y and vice versa, and a message from a
	/// validator that is not a grid neighbor needs no grid routing.
	pub fn required_routing_by_index(
		&self,
		originator: ValidatorIndex,
		local: bool,
	) -> RequiredRouting {
		if local {
			return RequiredRouting::GridXY
		}

		let grid_x = self.validator_indices_x.contains(&originator);
		let grid_y = self.validator_indices_y.contains(&originator);

		match (grid_x, grid_y) {
			(false, false) => RequiredRouting::None,
			// Messages from X neighbors are propagated along Y.
			(true, false) => RequiredRouting::GridY,
			// Messages from Y neighbors are propagated along X.
			(false, true) => RequiredRouting::GridX,
			// A well-formed grid never places a validator in both our row and our column;
			// route to both dimensions to be safe.
			(true, true) => RequiredRouting::GridXY,
		}
	}

	/// Same as [`Self::required_routing_by_index`], but identifies the originator by peer ID.
	///
	/// An originator found in both dimensions is logged at debug level and routed to both.
	pub fn required_routing_by_peer_id(&self, originator: PeerId, local: bool) -> RequiredRouting {
		if local {
			return RequiredRouting::GridXY
		}

		let grid_x = self.peers_x.contains(&originator);
		let grid_y = self.peers_y.contains(&originator);

		match (grid_x, grid_y) {
			(false, false) => RequiredRouting::None,
			// Messages from X neighbors are propagated along Y.
			(true, false) => RequiredRouting::GridY,
			// Messages from Y neighbors are propagated along X.
			(false, true) => RequiredRouting::GridX,
			(true, true) => {
				gum::debug!(
					target: LOG_TARGET,
					?originator,
					"Grid topology is unexpected, play it safe and send to X AND Y"
				);
				RequiredRouting::GridXY
			},
		}
	}

	/// Returns whether a message with `required_routing` should be sent to `peer`.
	///
	/// `All` accepts every peer. `GridX` / `GridY` accept only peers of that dimension, and
	/// `GridXY` accepts either. `None` and `PendingTopology` accept no peer.
	pub fn route_to_peer(&self, required_routing: RequiredRouting, peer: &PeerId) -> bool {
		match required_routing {
			RequiredRouting::All => true,
			RequiredRouting::GridX => self.peers_x.contains(peer),
			RequiredRouting::GridY => self.peers_y.contains(peer),
			RequiredRouting::GridXY => self.peers_x.contains(peer) || self.peers_y.contains(peer),
			RequiredRouting::None | RequiredRouting::PendingTopology => false,
		}
	}

	/// Returns the peers of this topology (in either dimension) that are absent from both
	/// dimensions of `other`.
	///
	/// Each peer is reported at most once, even if it appears in both `peers_x` and
	/// `peers_y` of `self`.
	pub fn peers_diff(&self, other: &Self) -> Vec<PeerId> {
		// `union` yields all of X followed by the Y peers not already in X. For disjoint
		// sets this is the same sequence as chaining X and Y, but without duplicates when a
		// malformed topology places one peer in both dimensions.
		self.peers_x
			.union(&self.peers_y)
			.filter(|peer_id| !(other.peers_x.contains(peer_id) || other.peers_y.contains(peer_id)))
			.copied()
			.collect::<Vec<_>>()
	}

	/// Total number of peers across both dimensions.
	///
	/// A peer present in both X and Y is counted twice.
	pub fn len(&self) -> usize {
		self.peers_x.len().saturating_add(self.peers_y.len())
	}

	/// Returns `true` when there are no peers in either dimension, i.e. `len() == 0`.
	pub fn is_empty(&self) -> bool {
		self.peers_x.is_empty() && self.peers_y.is_empty()
	}
}
/// A session's grid topology together with the grid neighbors computed for the local
/// validator. The neighbors are empty when no local validator index was given or when they
/// could not be computed from the topology.
#[derive(Debug)]
pub struct SessionGridTopologyEntry {
topology: SessionGridTopology,
local_neighbors: GridNeighbors,
}
impl SessionGridTopologyEntry {
/// Returns the grid neighbors of the local validator.
pub fn local_grid_neighbors(&self) -> &GridNeighbors {
&self.local_neighbors
}
/// Returns mutable access to the grid neighbors of the local validator.
pub fn local_grid_neighbors_mut(&mut self) -> &mut GridNeighbors {
&mut self.local_neighbors
}
/// Returns the underlying session grid topology.
pub fn get(&self) -> &SessionGridTopology {
&self.topology
}
}
/// Reference-counted grid topologies keyed by session.
///
/// Each session maps to an optional topology entry plus a reference count. The whole
/// session entry is dropped when the count is decremented to zero.
#[derive(Default)]
pub struct SessionGridTopologies {
	inner: HashMap<SessionIndex, (Option<SessionGridTopologyEntry>, usize)>,
}

impl SessionGridTopologies {
	/// Returns the topology stored for `session`, if one has been inserted.
	pub fn get_topology(&self, session: SessionIndex) -> Option<&SessionGridTopologyEntry> {
		let (entry, _refs) = self.inner.get(&session)?;
		entry.as_ref()
	}

	/// Increments the reference count of `session`, creating an empty slot if needed.
	pub fn inc_session_refs(&mut self, session: SessionIndex) {
		let (_, refs) = self.inner.entry(session).or_insert((None, 0));
		*refs += 1;
	}

	/// Decrements the reference count of `session` (never below zero) and removes the
	/// session once the count reaches zero. Unknown sessions are ignored.
	pub fn dec_session_refs(&mut self, session: SessionIndex) {
		match self.inner.entry(session) {
			hash_map::Entry::Occupied(mut occupied) => {
				let refs = &mut occupied.get_mut().1;
				*refs = refs.saturating_sub(1);
				if *refs == 0 {
					let _ = occupied.remove();
				}
			},
			hash_map::Entry::Vacant(_) => {},
		}
	}

	/// Stores `topology` for `session` unless a topology is already present for it.
	///
	/// The local validator's grid neighbors are computed from `local_index`; they are empty
	/// when there is no local index or the neighbors cannot be computed.
	pub fn insert_topology(
		&mut self,
		session: SessionIndex,
		topology: SessionGridTopology,
		local_index: Option<ValidatorIndex>,
	) {
		let (slot, _refs) = self.inner.entry(session).or_insert((None, 0));
		// First insertion wins: an existing topology is left untouched and the new one
		// (and the neighbor computation) is skipped entirely.
		slot.get_or_insert_with(move || {
			let local_neighbors = local_index
				.and_then(|index| topology.compute_grid_neighbors_for(index))
				.unwrap_or_else(GridNeighbors::empty);
			SessionGridTopologyEntry { topology, local_neighbors }
		});
	}
}
/// A [`SessionGridTopologyEntry`] tagged with the session it belongs to.
#[derive(Debug)]
struct GridTopologySessionBound {
entry: SessionGridTopologyEntry,
session_index: SessionIndex,
}
/// Holds the grid topology of the current session and, after the first
/// `update_topology` call, the topology that was current before that call.
///
/// NOTE: the `Default` value uses `SessionIndex::max_value()` as a placeholder session index
/// with an empty topology. After the first update that placeholder becomes `prev_topology`.
#[derive(Debug)]
pub struct SessionBoundGridTopologyStorage {
/// The most recently installed topology.
current_topology: GridTopologySessionBound,
/// The topology that was current before the last update, if any.
prev_topology: Option<GridTopologySessionBound>,
}
impl Default for SessionBoundGridTopologyStorage {
fn default() -> Self {
SessionBoundGridTopologyStorage {
current_topology: GridTopologySessionBound {
session_index: SessionIndex::max_value(),
entry: SessionGridTopologyEntry {
topology: SessionGridTopology {
shuffled_indices: Vec::new(),
canonical_shuffling: Vec::new(),
},
local_neighbors: GridNeighbors::empty(),
},
},
prev_topology: None,
}
}
}
impl SessionBoundGridTopologyStorage {
	/// Returns the topology for session `idx`, or the current topology if `idx` is unknown.
	pub fn get_topology_or_fallback(&self, idx: SessionIndex) -> &SessionGridTopologyEntry {
		self.get_topology(idx).unwrap_or_else(|| &self.current_topology.entry)
	}

	/// Returns the topology bound to session `idx`, if it is the previous or current one.
	///
	/// The previous topology is checked first, so it wins if both share the same index.
	pub fn get_topology(&self, idx: SessionIndex) -> Option<&SessionGridTopologyEntry> {
		self.prev_topology
			.iter()
			.chain(std::iter::once(&self.current_topology))
			.find(|bound| bound.session_index == idx)
			.map(|bound| &bound.entry)
	}

	/// Installs `topology` as the current topology for `session_index`, demoting the
	/// existing current topology to the previous slot.
	///
	/// The local validator's grid neighbors are computed from `local_index`; they are empty
	/// when there is no local index or the neighbors cannot be computed.
	pub fn update_topology(
		&mut self,
		session_index: SessionIndex,
		topology: SessionGridTopology,
		local_index: Option<ValidatorIndex>,
	) {
		let local_neighbors =
			match local_index.and_then(|index| topology.compute_grid_neighbors_for(index)) {
				Some(neighbors) => neighbors,
				None => GridNeighbors::empty(),
			};
		let incoming = GridTopologySessionBound {
			entry: SessionGridTopologyEntry { topology, local_neighbors },
			session_index,
		};
		self.prev_topology = Some(std::mem::replace(&mut self.current_topology, incoming));
	}

	/// Returns the current topology.
	pub fn get_current_topology(&self) -> &SessionGridTopologyEntry {
		&self.current_topology.entry
	}

	/// Returns mutable access to the current topology.
	pub fn get_current_topology_mut(&mut self) -> &mut SessionGridTopologyEntry {
		&mut self.current_topology.entry
	}
}
/// Budgeted random peer sampling: decides whether to send to another randomly chosen peer
/// until a target number of random sends has been recorded.
#[derive(Debug, Clone, Copy)]
pub struct RandomRouting {
	/// Number of random sends after which sampling always fails.
	target: usize,
	/// Random sends recorded so far via `inc_sent`.
	sent: usize,
	/// Numerator of the per-peer selection ratio `sample_rate / n_peers_total`.
	sample_rate: usize,
}

impl Default for RandomRouting {
	/// Uses `DEFAULT_RANDOM_CIRCULATION` as the target and `DEFAULT_RANDOM_SAMPLE_RATE` as
	/// the sample rate, with no sends recorded yet.
	fn default() -> Self {
		Self {
			sample_rate: DEFAULT_RANDOM_SAMPLE_RATE,
			target: DEFAULT_RANDOM_CIRCULATION,
			sent: 0,
		}
	}
}
impl RandomRouting {
	/// Decides whether the current peer (out of `n_peers_total`) should be sampled.
	///
	/// - Returns `false` when there are no peers or the `target` number of sends has already
	///   been recorded.
	/// - Returns `true`, without consuming randomness, when `sample_rate` exceeds the peer
	///   count.
	/// - Otherwise returns `true` with probability `sample_rate / n_peers_total`.
	pub fn sample(&self, n_peers_total: usize, rng: &mut (impl CryptoRng + Rng)) -> bool {
		let budget_exhausted = self.sent >= self.target;
		if n_peers_total == 0 || budget_exhausted {
			return false
		}
		if n_peers_total < self.sample_rate {
			return true
		}
		rng.gen_ratio(self.sample_rate as u32, n_peers_total as u32)
	}

	/// Records one more random send.
	pub fn inc_sent(&mut self) {
		self.sent += 1;
	}
}
/// Which peers a message must be routed to, as interpreted by
/// `GridNeighbors::route_to_peer`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RequiredRouting {
	/// Routing cannot be decided yet (presumably because the session topology is not known
	/// yet — confirm). No peer is selected.
	PendingTopology,
	/// Every peer.
	All,
	/// Peers in either grid dimension.
	GridXY,
	/// Peers in the X dimension only.
	GridX,
	/// Peers in the Y dimension only.
	GridY,
	/// No peer.
	None,
}

impl RequiredRouting {
	/// Returns `true` for the variants that select no peer: `PendingTopology` and `None`.
	pub fn is_empty(self) -> bool {
		// Exhaustive on purpose so that adding a variant forces a decision here.
		match self {
			RequiredRouting::All |
			RequiredRouting::GridXY |
			RequiredRouting::GridX |
			RequiredRouting::GridY => false,
			RequiredRouting::PendingTopology | RequiredRouting::None => true,
		}
	}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::SeedableRng;
use rand_chacha::ChaCha12Rng;
// Fixed-seed RNG so every run draws the same random sequence.
fn dummy_rng() -> ChaCha12Rng {
rand_chacha::ChaCha12Rng::seed_from_u64(12345)
}
// Steps through `sample`/`inc_sent` one call at a time. The exact true/false pattern is
// tied to the seeded RNG stream and to the order of the `sample` calls below.
// - With 16 peers and a sample rate of 8, each call performs one `gen_ratio(8, 16)` draw.
// - Once the target of 4 sends is reached, `sample` returns `false` without drawing.
// Reordering or adding calls changes the expected values.
#[test]
fn test_random_routing_sample() {
let mut rng = dummy_rng();
let mut random_routing = RandomRouting { target: 4, sent: 0, sample_rate: 8 };
assert_eq!(random_routing.sample(16, &mut rng), true);
random_routing.inc_sent();
assert_eq!(random_routing.sample(16, &mut rng), false);
assert_eq!(random_routing.sample(16, &mut rng), false);
assert_eq!(random_routing.sample(16, &mut rng), true);
random_routing.inc_sent();
assert_eq!(random_routing.sample(16, &mut rng), true);
random_routing.inc_sent();
assert_eq!(random_routing.sample(16, &mut rng), false);
assert_eq!(random_routing.sample(16, &mut rng), false);
assert_eq!(random_routing.sample(16, &mut rng), false);
assert_eq!(random_routing.sample(16, &mut rng), true);
random_routing.inc_sent();
// Target reached: every further sample fails.
for _ in 0..16 {
assert_eq!(random_routing.sample(16, &mut rng), false);
}
}
// Calls `sample` `iters` times, recording a send via `inc_sent` for every `true`.
// Returns the number of sends recorded.
fn run_random_routing(
random_routing: &mut RandomRouting,
rng: &mut (impl CryptoRng + Rng),
npeers: usize,
iters: usize,
) -> usize {
let mut ret = 0_usize;
for _ in 0..iters {
if random_routing.sample(npeers, rng) {
random_routing.inc_sent();
ret += 1;
}
}
ret
}
// With enough iterations the number of sends is capped at exactly `target`.
// This includes a zero target, and a sample rate equal to the peer count (10 of 10).
#[test]
fn test_random_routing_distribution() {
let mut rng = dummy_rng();
let mut random_routing = RandomRouting { target: 4, sent: 0, sample_rate: 8 };
assert_eq!(run_random_routing(&mut random_routing, &mut rng, 100, 10000), 4);
let mut random_routing = RandomRouting { target: 8, sent: 0, sample_rate: 100 };
assert_eq!(run_random_routing(&mut random_routing, &mut rng, 100, 10000), 8);
let mut random_routing = RandomRouting { target: 0, sent: 0, sample_rate: 100 };
assert_eq!(run_random_routing(&mut random_routing, &mut rng, 100, 10000), 0);
let mut random_routing = RandomRouting { target: 10, sent: 0, sample_rate: 10 };
assert_eq!(run_random_routing(&mut random_routing, &mut rng, 10, 100), 10);
}
// Each case is (our_index, len, expected_row, expected_column). Results are sorted before
// comparing, and `our_index` itself must never appear in either list.
#[test]
fn test_matrix_neighbors() {
for (our_index, len, expected_row, expected_column) in vec![
(0usize, 1usize, vec![], vec![]),
(1, 2, vec![], vec![0usize]),
(0, 9, vec![1, 2], vec![3, 6]),
(9, 10, vec![], vec![0, 3, 6]),
(10, 11, vec![9], vec![1, 4, 7]),
(7, 11, vec![6, 8], vec![1, 4, 10]),
]
.into_iter()
{
let matrix = matrix_neighbors(our_index, len).unwrap();
let mut row_result: Vec<_> = matrix.row_neighbors.collect();
let mut column_result: Vec<_> = matrix.column_neighbors.collect();
row_result.sort();
column_result.sort();
assert_eq!(row_result, expected_row);
assert_eq!(column_result, expected_column);
}
}
}