use crate::handler::{
ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr, IntoConnectionHandler,
KeepAlive, SubstreamProtocol,
};
use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend};
use crate::NegotiatedSubstream;
use futures::{future::BoxFuture, prelude::*};
use libp2p_core::upgrade::{NegotiationError, ProtocolError, ProtocolName, UpgradeError};
use libp2p_core::{ConnectedPoint, Multiaddr, PeerId};
use rand::Rng;
use std::{
cmp,
collections::{HashMap, HashSet},
error,
fmt::{self, Debug},
hash::Hash,
iter::{self, FromIterator},
task::{Context, Poll},
time::Duration,
};
#[derive(Clone)]
pub struct MultiHandler<K, H> {
handlers: HashMap<K, H>,
}
impl<K, H> fmt::Debug for MultiHandler<K, H>
where
K: fmt::Debug + Eq + Hash,
H: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MultiHandler")
.field("handlers", &self.handlers)
.finish()
}
}
impl<K, H> MultiHandler<K, H>
where
K: Hash + Eq,
H: ConnectionHandler,
{
pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
where
I: IntoIterator<Item = (K, H)>,
{
let m = MultiHandler {
handlers: HashMap::from_iter(iter),
};
uniq_proto_names(
m.handlers
.values()
.map(|h| h.listen_protocol().into_upgrade().0),
)?;
Ok(m)
}
}
impl<K, H> ConnectionHandler for MultiHandler<K, H>
where
K: Clone + Debug + Hash + Eq + Send + 'static,
H: ConnectionHandler,
H::InboundProtocol: InboundUpgradeSend,
H::OutboundProtocol: OutboundUpgradeSend,
{
type InEvent = (K, <H as ConnectionHandler>::InEvent);
type OutEvent = (K, <H as ConnectionHandler>::OutEvent);
type Error = <H as ConnectionHandler>::Error;
type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
let (upgrade, info, timeout) = self
.handlers
.iter()
.map(|(key, handler)| {
let proto = handler.listen_protocol();
let timeout = *proto.timeout();
let (upgrade, info) = proto.into_upgrade();
(key.clone(), (upgrade, info, timeout))
})
.fold(
(Upgrade::new(), Info::new(), Duration::from_secs(0)),
|(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
upg.upgrades.push((k.clone(), u));
inf.infos.push((k, i));
timeout = cmp::max(timeout, t);
(upg, inf, timeout)
},
);
SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
}
fn inject_fully_negotiated_outbound(
&mut self,
protocol: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
(key, arg): Self::OutboundOpenInfo,
) {
if let Some(h) = self.handlers.get_mut(&key) {
#[allow(deprecated)]
h.inject_fully_negotiated_outbound(protocol, arg)
} else {
log::error!("inject_fully_negotiated_outbound: no handler for key")
}
}
fn inject_fully_negotiated_inbound(
&mut self,
(key, arg): <Self::InboundProtocol as InboundUpgradeSend>::Output,
mut info: Self::InboundOpenInfo,
) {
if let Some(h) = self.handlers.get_mut(&key) {
if let Some(i) = info.take(&key) {
#[allow(deprecated)]
h.inject_fully_negotiated_inbound(arg, i)
}
} else {
log::error!("inject_fully_negotiated_inbound: no handler for key")
}
}
fn on_behaviour_event(&mut self, (key, event): Self::InEvent) {
if let Some(h) = self.handlers.get_mut(&key) {
#[allow(deprecated)]
h.inject_event(event)
} else {
log::error!("inject_event: no handler for key")
}
}
fn inject_address_change(&mut self, addr: &Multiaddr) {
for h in self.handlers.values_mut() {
#[allow(deprecated)]
h.inject_address_change(addr)
}
}
fn inject_dial_upgrade_error(
&mut self,
(key, arg): Self::OutboundOpenInfo,
error: ConnectionHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>,
) {
if let Some(h) = self.handlers.get_mut(&key) {
#[allow(deprecated)]
h.inject_dial_upgrade_error(arg, error)
} else {
log::error!("inject_dial_upgrade_error: no handler for protocol")
}
}
fn inject_listen_upgrade_error(
&mut self,
mut info: Self::InboundOpenInfo,
error: ConnectionHandlerUpgrErr<<Self::InboundProtocol as InboundUpgradeSend>::Error>,
) {
match error {
ConnectionHandlerUpgrErr::Timer => {
for (k, h) in &mut self.handlers {
if let Some(i) = info.take(k) {
#[allow(deprecated)]
h.inject_listen_upgrade_error(i, ConnectionHandlerUpgrErr::Timer)
}
}
}
ConnectionHandlerUpgrErr::Timeout => {
for (k, h) in &mut self.handlers {
if let Some(i) = info.take(k) {
#[allow(deprecated)]
h.inject_listen_upgrade_error(i, ConnectionHandlerUpgrErr::Timeout)
}
}
}
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => {
for (k, h) in &mut self.handlers {
if let Some(i) = info.take(k) {
#[allow(deprecated)]
h.inject_listen_upgrade_error(
i,
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(
NegotiationError::Failed,
)),
)
}
}
}
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(
NegotiationError::ProtocolError(e),
)) => match e {
ProtocolError::IoError(e) => {
for (k, h) in &mut self.handlers {
if let Some(i) = info.take(k) {
let e = NegotiationError::ProtocolError(ProtocolError::IoError(
e.kind().into(),
));
#[allow(deprecated)]
h.inject_listen_upgrade_error(
i,
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(e)),
)
}
}
}
ProtocolError::InvalidMessage => {
for (k, h) in &mut self.handlers {
if let Some(i) = info.take(k) {
let e = NegotiationError::ProtocolError(ProtocolError::InvalidMessage);
#[allow(deprecated)]
h.inject_listen_upgrade_error(
i,
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(e)),
)
}
}
}
ProtocolError::InvalidProtocol => {
for (k, h) in &mut self.handlers {
if let Some(i) = info.take(k) {
let e = NegotiationError::ProtocolError(ProtocolError::InvalidProtocol);
#[allow(deprecated)]
h.inject_listen_upgrade_error(
i,
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(e)),
)
}
}
}
ProtocolError::TooManyProtocols => {
for (k, h) in &mut self.handlers {
if let Some(i) = info.take(k) {
let e =
NegotiationError::ProtocolError(ProtocolError::TooManyProtocols);
#[allow(deprecated)]
h.inject_listen_upgrade_error(
i,
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Select(e)),
)
}
}
}
},
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Apply((k, e))) => {
if let Some(h) = self.handlers.get_mut(&k) {
if let Some(i) = info.take(&k) {
#[allow(deprecated)]
h.inject_listen_upgrade_error(
i,
ConnectionHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)),
)
}
}
}
}
}
fn connection_keep_alive(&self) -> KeepAlive {
self.handlers
.values()
.map(|h| h.connection_keep_alive())
.max()
.unwrap_or(KeepAlive::No)
}
fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<
ConnectionHandlerEvent<
Self::OutboundProtocol,
Self::OutboundOpenInfo,
Self::OutEvent,
Self::Error,
>,
> {
if self.handlers.is_empty() {
return Poll::Pending;
}
let pos = rand::thread_rng().gen_range(0..self.handlers.len());
for (k, h) in self.handlers.iter_mut().skip(pos) {
if let Poll::Ready(e) = h.poll(cx) {
let e = e
.map_outbound_open_info(|i| (k.clone(), i))
.map_custom(|p| (k.clone(), p));
return Poll::Ready(e);
}
}
for (k, h) in self.handlers.iter_mut().take(pos) {
if let Poll::Ready(e) = h.poll(cx) {
let e = e
.map_outbound_open_info(|i| (k.clone(), i))
.map_custom(|p| (k.clone(), p));
return Poll::Ready(e);
}
}
Poll::Pending
}
}
impl<K, H> IntoIterator for MultiHandler<K, H> {
type Item = <Self::IntoIter as Iterator>::Item;
type IntoIter = std::collections::hash_map::IntoIter<K, H>;
fn into_iter(self) -> Self::IntoIter {
self.handlers.into_iter()
}
}
#[derive(Clone)]
pub struct IntoMultiHandler<K, H> {
handlers: HashMap<K, H>,
}
impl<K, H> fmt::Debug for IntoMultiHandler<K, H>
where
K: fmt::Debug + Eq + Hash,
H: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IntoMultiHandler")
.field("handlers", &self.handlers)
.finish()
}
}
impl<K, H> IntoMultiHandler<K, H>
where
K: Hash + Eq,
H: IntoConnectionHandler,
{
pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
where
I: IntoIterator<Item = (K, H)>,
{
let m = IntoMultiHandler {
handlers: HashMap::from_iter(iter),
};
uniq_proto_names(m.handlers.values().map(|h| h.inbound_protocol()))?;
Ok(m)
}
}
impl<K, H> IntoConnectionHandler for IntoMultiHandler<K, H>
where
K: Debug + Clone + Eq + Hash + Send + 'static,
H: IntoConnectionHandler,
{
type Handler = MultiHandler<K, H::Handler>;
fn into_handler(self, p: &PeerId, c: &ConnectedPoint) -> Self::Handler {
MultiHandler {
handlers: self
.handlers
.into_iter()
.map(|(k, h)| (k, h.into_handler(p, c)))
.collect(),
}
}
fn inbound_protocol(&self) -> <Self::Handler as ConnectionHandler>::InboundProtocol {
Upgrade {
upgrades: self
.handlers
.iter()
.map(|(k, h)| (k.clone(), h.inbound_protocol()))
.collect(),
}
}
}
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);
impl<H: ProtocolName> ProtocolName for IndexedProtoName<H> {
fn protocol_name(&self) -> &[u8] {
self.1.protocol_name()
}
}
#[derive(Clone)]
pub struct Info<K, I> {
infos: Vec<(K, I)>,
}
impl<K: Eq, I> Info<K, I> {
fn new() -> Self {
Info { infos: Vec::new() }
}
pub fn take(&mut self, k: &K) -> Option<I> {
if let Some(p) = self.infos.iter().position(|(key, _)| key == k) {
return Some(self.infos.remove(p).1);
}
None
}
}
#[derive(Clone)]
pub struct Upgrade<K, H> {
upgrades: Vec<(K, H)>,
}
impl<K, H> Upgrade<K, H> {
fn new() -> Self {
Upgrade {
upgrades: Vec::new(),
}
}
}
impl<K, H> fmt::Debug for Upgrade<K, H>
where
K: fmt::Debug + Eq + Hash,
H: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Upgrade")
.field("upgrades", &self.upgrades)
.finish()
}
}
impl<K, H> UpgradeInfoSend for Upgrade<K, H>
where
H: UpgradeInfoSend,
K: Send + 'static,
{
type Info = IndexedProtoName<H::Info>;
type InfoIter = std::vec::IntoIter<Self::Info>;
fn protocol_info(&self) -> Self::InfoIter {
self.upgrades
.iter()
.enumerate()
.flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
.map(|(i, h)| IndexedProtoName(i, h))
.collect::<Vec<_>>()
.into_iter()
}
}
impl<K, H> InboundUpgradeSend for Upgrade<K, H>
where
H: InboundUpgradeSend,
K: Send + 'static,
{
type Output = (K, <H as InboundUpgradeSend>::Output);
type Error = (K, <H as InboundUpgradeSend>::Error);
type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
fn upgrade_inbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
let IndexedProtoName(index, info) = info;
let (key, upgrade) = self.upgrades.remove(index);
upgrade
.upgrade_inbound(resource, info)
.map(move |out| match out {
Ok(o) => Ok((key, o)),
Err(e) => Err((key, e)),
})
.boxed()
}
}
impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
where
H: OutboundUpgradeSend,
K: Send + 'static,
{
type Output = (K, <H as OutboundUpgradeSend>::Output);
type Error = (K, <H as OutboundUpgradeSend>::Error);
type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
fn upgrade_outbound(mut self, resource: NegotiatedSubstream, info: Self::Info) -> Self::Future {
let IndexedProtoName(index, info) = info;
let (key, upgrade) = self.upgrades.remove(index);
upgrade
.upgrade_outbound(resource, info)
.map(move |out| match out {
Ok(o) => Ok((key, o)),
Err(e) => Err((key, e)),
})
.boxed()
}
}
fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
where
I: Iterator<Item = T>,
T: UpgradeInfoSend,
{
let mut set = HashSet::new();
for infos in iter {
for i in infos.protocol_info() {
let v = Vec::from(i.protocol_name());
if set.contains(&v) {
return Err(DuplicateProtonameError(v));
} else {
set.insert(v);
}
}
}
Ok(())
}
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);
impl DuplicateProtonameError {
pub fn protocol_name(&self) -> &[u8] {
&self.0
}
}
impl fmt::Display for DuplicateProtonameError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Ok(s) = std::str::from_utf8(&self.0) {
write!(f, "duplicate protocol name: {}", s)
} else {
write!(f, "duplicate protocol name: {:?}", self.0)
}
}
}
impl error::Error for DuplicateProtonameError {}