use crate::behaviour::{inject_from_swarm, FromSwarm};
use crate::handler::{
ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, ConnectionHandlerUpgrErr,
DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, IntoConnectionHandler,
KeepAlive, ListenUpgradeError, SubstreamProtocol,
};
use crate::upgrade::SendWrapper;
use crate::{NetworkBehaviour, NetworkBehaviourAction, PollParameters};
use either::Either;
use libp2p_core::{
either::{EitherError, EitherOutput},
upgrade::{DeniedUpgrade, EitherUpgrade},
ConnectedPoint, Multiaddr, PeerId,
};
use std::{task::Context, task::Poll};
/// A [`NetworkBehaviour`] wrapper whose wrapped behaviour may be absent.
///
/// Whether it is enabled is decided when the value is built via
/// `From<Option<TBehaviour>>`; this type offers no method to change it afterwards.
pub struct Toggle<TBehaviour> {
    /// `Some` while enabled, `None` while disabled.
    inner: Option<TBehaviour>,
}

impl<TBehaviour> Toggle<TBehaviour> {
    /// Returns `true` when a wrapped behaviour is present.
    pub fn is_enabled(&self) -> bool {
        matches!(self.inner, Some(_))
    }

    /// Borrows the wrapped behaviour, or returns `None` when disabled.
    pub fn as_ref(&self) -> Option<&TBehaviour> {
        Option::as_ref(&self.inner)
    }

    /// Mutably borrows the wrapped behaviour, or returns `None` when disabled.
    pub fn as_mut(&mut self) -> Option<&mut TBehaviour> {
        Option::as_mut(&mut self.inner)
    }
}

impl<TBehaviour> From<Option<TBehaviour>> for Toggle<TBehaviour> {
    /// `Some(behaviour)` yields an enabled toggle; `None` yields a disabled one.
    fn from(inner: Option<TBehaviour>) -> Self {
        Self { inner }
    }
}
/// Delegates to the wrapped behaviour when enabled.
///
/// When disabled, handlers are created empty, no addresses are reported, swarm and handler
/// events are dropped, and `poll` always returns `Poll::Pending`.
impl<TBehaviour> NetworkBehaviour for Toggle<TBehaviour>
where
    TBehaviour: NetworkBehaviour,
{
    type ConnectionHandler = ToggleIntoConnectionHandler<TBehaviour::ConnectionHandler>;
    type OutEvent = TBehaviour::OutEvent;

    /// Wraps the inner behaviour's handler prototype, or produces an empty one when disabled.
    fn new_handler(&mut self) -> Self::ConnectionHandler {
        let inner = self.inner.as_mut().map(TBehaviour::new_handler);
        ToggleIntoConnectionHandler { inner }
    }

    /// Asks the inner behaviour for addresses; a disabled toggle knows none.
    fn addresses_of_peer(&mut self, peer_id: &PeerId) -> Vec<Multiaddr> {
        match self.inner.as_mut() {
            Some(behaviour) => behaviour.addresses_of_peer(peer_id),
            None => Vec::new(),
        }
    }

    /// Forwards a swarm event to the inner behaviour, if enabled.
    fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
        let behaviour = match self.inner.as_mut() {
            Some(behaviour) => behaviour,
            None => return,
        };
        // The two closures strip our wrapper types off any handler carried by the event,
        // yielding `Option`s; if `maybe_map_handler` returns `None`, the event is dropped.
        if let Some(event) = event.maybe_map_handler(|h| h.inner, |h| h.inner) {
            inject_from_swarm(behaviour, event);
        }
    }

    /// Forwards a handler event to the inner behaviour; dropped when disabled.
    fn on_connection_handler_event(
        &mut self,
        peer_id: PeerId,
        connection_id: libp2p_core::connection::ConnectionId,
        event: crate::THandlerOutEvent<Self>,
    ) {
        if let Some(behaviour) = self.inner.as_mut() {
            // `inject_event` is marked deprecated, hence the `allow`.
            #[allow(deprecated)]
            behaviour.inject_event(peer_id, connection_id, event)
        }
    }

    /// Polls the inner behaviour, re-wrapping any handler in its action.
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
        params: &mut impl PollParameters,
    ) -> Poll<NetworkBehaviourAction<Self::OutEvent, Self::ConnectionHandler>> {
        // A disabled toggle never has anything to report.
        let behaviour = match self.inner.as_mut() {
            Some(behaviour) => behaviour,
            None => return Poll::Pending,
        };
        behaviour.poll(cx, params).map(|action| {
            action.map_handler(|handler| ToggleIntoConnectionHandler {
                inner: Some(handler),
            })
        })
    }
}
/// Handler prototype produced by [`Toggle`].
///
/// Holds the inner behaviour's prototype when enabled, or nothing when disabled.
pub struct ToggleIntoConnectionHandler<TInner> {
    inner: Option<TInner>,
}

impl<TInner> IntoConnectionHandler for ToggleIntoConnectionHandler<TInner>
where
    TInner: IntoConnectionHandler,
{
    type Handler = ToggleConnectionHandler<TInner::Handler>;

    /// Builds the connection handler; an empty prototype yields an empty handler.
    fn into_handler(
        self,
        remote_peer_id: &PeerId,
        connected_point: &ConnectedPoint,
    ) -> Self::Handler {
        let Self { inner } = self;
        ToggleConnectionHandler {
            inner: inner.map(|prototype| prototype.into_handler(remote_peer_id, connected_point)),
        }
    }

    /// Advertises the inner prototype's inbound protocol, or a [`DeniedUpgrade`] when empty.
    fn inbound_protocol(&self) -> <Self::Handler as ConnectionHandler>::InboundProtocol {
        match &self.inner {
            Some(prototype) => EitherUpgrade::A(SendWrapper(prototype.inbound_protocol())),
            None => EitherUpgrade::B(SendWrapper(DeniedUpgrade)),
        }
    }
}
pub struct ToggleConnectionHandler<TInner> {
inner: Option<TInner>,
}
impl<TInner> ToggleConnectionHandler<TInner>
where
TInner: ConnectionHandler,
{
fn on_fully_negotiated_inbound(
&mut self,
FullyNegotiatedInbound {
protocol: out,
info,
}: FullyNegotiatedInbound<
<Self as ConnectionHandler>::InboundProtocol,
<Self as ConnectionHandler>::InboundOpenInfo,
>,
) {
let out = match out {
EitherOutput::First(out) => out,
EitherOutput::Second(v) => void::unreachable(v),
};
if let Either::Left(info) = info {
#[allow(deprecated)]
self.inner
.as_mut()
.expect("Can't receive an inbound substream if disabled; QED")
.inject_fully_negotiated_inbound(out, info)
} else {
panic!("Unexpected Either::Right in enabled `inject_fully_negotiated_inbound`.")
}
}
fn on_listen_upgrade_error(
&mut self,
ListenUpgradeError { info, error: err }: ListenUpgradeError<
<Self as ConnectionHandler>::InboundOpenInfo,
<Self as ConnectionHandler>::InboundProtocol,
>,
) {
let (inner, info) = match (self.inner.as_mut(), info) {
(Some(inner), Either::Left(info)) => (inner, info),
(None, Either::Right(())) => return,
(Some(_), Either::Right(())) => panic!(
"Unexpected `Either::Right` inbound info through \
`inject_listen_upgrade_error` in enabled state.",
),
(None, Either::Left(_)) => panic!(
"Unexpected `Either::Left` inbound info through \
`inject_listen_upgrade_error` in disabled state.",
),
};
let err = match err {
ConnectionHandlerUpgrErr::Timeout => ConnectionHandlerUpgrErr::Timeout,
ConnectionHandlerUpgrErr::Timer => ConnectionHandlerUpgrErr::Timer,
ConnectionHandlerUpgrErr::Upgrade(err) => {
ConnectionHandlerUpgrErr::Upgrade(err.map_err(|err| match err {
EitherError::A(e) => e,
EitherError::B(v) => void::unreachable(v),
}))
}
};
#[allow(deprecated)]
inner.inject_listen_upgrade_error(info, err)
}
}
impl<TInner> ConnectionHandler for ToggleConnectionHandler<TInner>
where
TInner: ConnectionHandler,
{
type InEvent = TInner::InEvent;
type OutEvent = TInner::OutEvent;
type Error = TInner::Error;
type InboundProtocol =
EitherUpgrade<SendWrapper<TInner::InboundProtocol>, SendWrapper<DeniedUpgrade>>;
type OutboundProtocol = TInner::OutboundProtocol;
type OutboundOpenInfo = TInner::OutboundOpenInfo;
type InboundOpenInfo = Either<TInner::InboundOpenInfo, ()>;
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
if let Some(inner) = self.inner.as_ref() {
inner
.listen_protocol()
.map_upgrade(|u| EitherUpgrade::A(SendWrapper(u)))
.map_info(Either::Left)
} else {
SubstreamProtocol::new(
EitherUpgrade::B(SendWrapper(DeniedUpgrade)),
Either::Right(()),
)
}
}
fn on_behaviour_event(&mut self, event: Self::InEvent) {
#[allow(deprecated)]
self.inner
.as_mut()
.expect("Can't receive events if disabled; QED")
.inject_event(event)
}
fn connection_keep_alive(&self) -> KeepAlive {
self.inner
.as_ref()
.map(|h| h.connection_keep_alive())
.unwrap_or(KeepAlive::No)
}
fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<
ConnectionHandlerEvent<
Self::OutboundProtocol,
Self::OutboundOpenInfo,
Self::OutEvent,
Self::Error,
>,
> {
if let Some(inner) = self.inner.as_mut() {
inner.poll(cx)
} else {
Poll::Pending
}
}
fn on_connection_event(
&mut self,
event: ConnectionEvent<
Self::InboundProtocol,
Self::OutboundProtocol,
Self::InboundOpenInfo,
Self::OutboundOpenInfo,
>,
) {
match event {
ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
self.on_fully_negotiated_inbound(fully_negotiated_inbound)
}
ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
protocol: out,
info,
}) =>
{
#[allow(deprecated)]
self.inner
.as_mut()
.expect("Can't receive an outbound substream if disabled; QED")
.inject_fully_negotiated_outbound(out, info)
}
ConnectionEvent::AddressChange(address_change) => {
if let Some(inner) = self.inner.as_mut() {
#[allow(deprecated)]
inner.inject_address_change(address_change.new_address)
}
}
ConnectionEvent::DialUpgradeError(DialUpgradeError { info, error: err }) =>
{
#[allow(deprecated)]
self.inner
.as_mut()
.expect("Can't receive an outbound substream if disabled; QED")
.inject_dial_upgrade_error(info, err)
}
ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
self.on_listen_upgrade_error(listen_upgrade_error)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dummy;

    /// A disabled handler has no inner handler to notify, so a listen upgrade error carrying
    /// `Either::Right(())` info must be dropped rather than panicking.
    #[test]
    fn ignore_listen_upgrade_error_when_disabled() {
        let mut handler = ToggleConnectionHandler::<dummy::ConnectionHandler> { inner: None };
        handler.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
            info: Either::Right(()),
            error: ConnectionHandlerUpgrErr::Timeout,
        }));
    }

    /// A disabled handler advertises the denied upgrade and never keeps the connection alive.
    #[test]
    fn disabled_handler_denies_inbound_and_does_not_keep_alive() {
        let handler = ToggleConnectionHandler::<dummy::ConnectionHandler> { inner: None };

        let protocol = handler.listen_protocol();
        assert!(matches!(protocol.info(), Either::Right(())));
        assert!(matches!(protocol.upgrade(), EitherUpgrade::B(_)));

        assert!(matches!(handler.connection_keep_alive(), KeepAlive::No));
    }

    /// A disabled toggle reports itself as disabled and knows no peer addresses.
    #[test]
    fn disabled_toggle_reports_no_addresses() {
        let mut toggle = Toggle::<dummy::Behaviour>::from(None);
        assert!(!toggle.is_enabled());
        assert!(toggle.as_ref().is_none());
        assert!(toggle.addresses_of_peer(&PeerId::random()).is_empty());
    }
}