use crate::upgrade::{InboundUpgrade, OutboundUpgrade, ProtocolName, UpgradeError};
use crate::{connection::ConnectedPoint, Negotiated};
use futures::{future::Either, prelude::*};
use log::debug;
use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture};
use std::{iter, mem, pin::Pin, task::Context, task::Poll};
pub use multistream_select::Version;
/// Applies an upgrade to a connection, choosing the direction from `cp`.
///
/// When `cp` is a `ConnectedPoint::Dialer` whose `role_override` reports
/// `is_dialer()`, the upgrade is applied outbound (using `v` as the
/// negotiation version) and the result is `Either::Right`. In every other
/// case the upgrade is applied inbound and the result is `Either::Left`;
/// `v` is not used on that path.
pub fn apply<C, U>(
    conn: C,
    up: U,
    cp: ConnectedPoint,
    v: Version,
) -> Either<InboundUpgradeApply<C, U>, OutboundUpgradeApply<C, U>>
where
    C: AsyncRead + AsyncWrite + Unpin,
    U: InboundUpgrade<Negotiated<C>> + OutboundUpgrade<Negotiated<C>>,
{
    let acts_as_dialer = matches!(
        cp,
        ConnectedPoint::Dialer { role_override, .. } if role_override.is_dialer()
    );

    if acts_as_dialer {
        Either::Right(apply_outbound(conn, up, v))
    } else {
        Either::Left(apply_inbound(conn, up))
    }
}
/// Starts applying `up` on `conn` as the listening side.
///
/// The protocols advertised by `up.protocol_info()` are wrapped in
/// [`NameWrap`] and handed to `multistream_select::listener_select_proto`.
/// The returned future starts in the `Init` state, holding both the
/// negotiation future and the upgrade itself.
pub fn apply_inbound<C, U>(conn: C, up: U) -> InboundUpgradeApply<C, U>
where
    C: AsyncRead + AsyncWrite + Unpin,
    U: InboundUpgrade<Negotiated<C>>,
{
    // The explicit cast to a `fn` pointer gives the mapped iterator a nameable type.
    let protocols = up
        .protocol_info()
        .into_iter()
        .map(NameWrap as fn(_) -> NameWrap<_>);

    let initial = InboundUpgradeApplyState::Init {
        future: multistream_select::listener_select_proto(conn, protocols),
        upgrade: up,
    };

    InboundUpgradeApply { inner: initial }
}
/// Starts applying `up` on `conn` as the dialing side.
///
/// The protocols advertised by `up.protocol_info()` are wrapped in
/// [`NameWrap`] and handed, together with the version `v`, to
/// `multistream_select::dialer_select_proto`. The returned future starts in
/// the `Init` state, holding both the negotiation future and the upgrade.
pub fn apply_outbound<C, U>(conn: C, up: U, v: Version) -> OutboundUpgradeApply<C, U>
where
    C: AsyncRead + AsyncWrite + Unpin,
    U: OutboundUpgrade<Negotiated<C>>,
{
    // The explicit cast to a `fn` pointer makes the iterator type match `NameWrapIter`.
    let protocols = up
        .protocol_info()
        .into_iter()
        .map(NameWrap as fn(_) -> NameWrap<_>);

    let initial = OutboundUpgradeApplyState::Init {
        future: multistream_select::dialer_select_proto(conn, protocols, v),
        upgrade: up,
    };

    OutboundUpgradeApply { inner: initial }
}
/// Future returned by [`apply_inbound`].
///
/// It first drives listener-side protocol negotiation on the connection and
/// then drives the future produced by `U::upgrade_inbound`. It resolves to
/// the upgrade's output or to an `UpgradeError`.
///
/// # Panics
///
/// Polling this future again after it has returned `Poll::Ready` panics.
pub struct InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
{
// Current stage of the state machine; see `InboundUpgradeApplyState`.
inner: InboundUpgradeApplyState<C, U>,
}
/// Internal state machine of [`InboundUpgradeApply`].
enum InboundUpgradeApplyState<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
{
/// Protocol negotiation is in progress; `upgrade` is kept until a
/// protocol has been agreed on.
Init {
future: ListenerSelectFuture<C, NameWrap<U::Info>>,
upgrade: U,
},
/// Negotiation finished; the future returned by `upgrade_inbound` is
/// being driven. It is boxed and pinned on the heap.
Upgrade {
future: Pin<Box<U::Future>>,
},
/// Placeholder swapped in while `poll` temporarily owns the real state.
/// It is also what remains once the future has completed.
Undefined,
}
// `InboundUpgradeApply` is declared `Unpin` regardless of `U`. `poll` relies
// on this to mutate `self.inner` through `Pin<&mut Self>`; the upgrade's own
// future is stored as `Pin<Box<_>>`, so it is pinned on the heap rather than
// inline in this struct.
impl<C, U> Unpin for InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
{
}
impl<C, U> Future for InboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: InboundUpgrade<Negotiated<C>>,
{
type Output = Result<U::Output, UpgradeError<U::Error>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match mem::replace(&mut self.inner, InboundUpgradeApplyState::Undefined) {
InboundUpgradeApplyState::Init {
mut future,
upgrade,
} => {
let (info, io) = match Future::poll(Pin::new(&mut future), cx)? {
Poll::Ready(x) => x,
Poll::Pending => {
self.inner = InboundUpgradeApplyState::Init { future, upgrade };
return Poll::Pending;
}
};
self.inner = InboundUpgradeApplyState::Upgrade {
future: Box::pin(upgrade.upgrade_inbound(io, info.0)),
};
}
InboundUpgradeApplyState::Upgrade { mut future } => {
match Future::poll(Pin::new(&mut future), cx) {
Poll::Pending => {
self.inner = InboundUpgradeApplyState::Upgrade { future };
return Poll::Pending;
}
Poll::Ready(Ok(x)) => {
debug!("Successfully applied negotiated protocol");
return Poll::Ready(Ok(x));
}
Poll::Ready(Err(e)) => {
debug!("Failed to apply negotiated protocol");
return Poll::Ready(Err(UpgradeError::Apply(e)));
}
}
}
InboundUpgradeApplyState::Undefined => {
panic!("InboundUpgradeApplyState::poll called after completion")
}
}
}
}
}
/// Future returned by [`apply_outbound`].
///
/// It first drives dialer-side protocol negotiation on the connection and
/// then drives the future produced by `U::upgrade_outbound`. It resolves to
/// the upgrade's output or to an `UpgradeError`.
///
/// # Panics
///
/// Polling this future again after it has returned `Poll::Ready` panics.
pub struct OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
{
// Current stage of the state machine; see `OutboundUpgradeApplyState`.
inner: OutboundUpgradeApplyState<C, U>,
}
/// Internal state machine of [`OutboundUpgradeApply`].
enum OutboundUpgradeApplyState<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
{
/// Protocol negotiation is in progress; `upgrade` is kept until a
/// protocol has been agreed on.
Init {
future: DialerSelectFuture<C, NameWrapIter<<U::InfoIter as IntoIterator>::IntoIter>>,
upgrade: U,
},
/// Negotiation finished; the future returned by `upgrade_outbound` is
/// being driven. It is boxed and pinned on the heap.
Upgrade {
future: Pin<Box<U::Future>>,
},
/// Placeholder swapped in while `poll` temporarily owns the real state.
/// It is also what remains once the future has completed.
Undefined,
}
// `OutboundUpgradeApply` is declared `Unpin` regardless of `U`. `poll` relies
// on this to mutate `self.inner` through `Pin<&mut Self>`; the upgrade's own
// future is stored as `Pin<Box<_>>`, so it is pinned on the heap rather than
// inline in this struct.
impl<C, U> Unpin for OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
{
}
impl<C, U> Future for OutboundUpgradeApply<C, U>
where
C: AsyncRead + AsyncWrite + Unpin,
U: OutboundUpgrade<Negotiated<C>>,
{
type Output = Result<U::Output, UpgradeError<U::Error>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
match mem::replace(&mut self.inner, OutboundUpgradeApplyState::Undefined) {
OutboundUpgradeApplyState::Init {
mut future,
upgrade,
} => {
let (info, connection) = match Future::poll(Pin::new(&mut future), cx)? {
Poll::Ready(x) => x,
Poll::Pending => {
self.inner = OutboundUpgradeApplyState::Init { future, upgrade };
return Poll::Pending;
}
};
self.inner = OutboundUpgradeApplyState::Upgrade {
future: Box::pin(upgrade.upgrade_outbound(connection, info.0)),
};
}
OutboundUpgradeApplyState::Upgrade { mut future } => {
match Future::poll(Pin::new(&mut future), cx) {
Poll::Pending => {
self.inner = OutboundUpgradeApplyState::Upgrade { future };
return Poll::Pending;
}
Poll::Ready(Ok(x)) => {
debug!("Successfully applied negotiated protocol");
return Poll::Ready(Ok(x));
}
Poll::Ready(Err(e)) => {
debug!("Failed to apply negotiated protocol");
return Poll::Ready(Err(UpgradeError::Apply(e)));
}
}
}
OutboundUpgradeApplyState::Undefined => {
panic!("OutboundUpgradeApplyState::poll called after completion")
}
}
}
}
}
/// An iterator `I` whose items are each wrapped in [`NameWrap`].
///
/// The mapping function is a plain `fn` pointer rather than a closure so
/// that the resulting type can be named in struct fields.
type NameWrapIter<I> = iter::Map<I, fn(<I as Iterator>::Item) -> NameWrap<<I as Iterator>::Item>>;
/// Newtype around a protocol-name value `N`; for `N: ProtocolName` it
/// exposes the protocol name through `AsRef<[u8]>`.
#[derive(Clone)]
struct NameWrap<N>(N);
impl<N: ProtocolName> AsRef<[u8]> for NameWrap<N> {
    /// Returns the wrapped value's protocol name as raw bytes.
    fn as_ref(&self) -> &[u8] {
        N::protocol_name(&self.0)
    }
}