mod error;
pub(crate) mod pool;
pub use error::{
ConnectionError, PendingConnectionError, PendingInboundConnectionError,
PendingOutboundConnectionError,
};
use crate::handler::ConnectionHandler;
use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper};
use crate::{ConnectionHandlerEvent, ConnectionHandlerUpgrErr, KeepAlive, SubstreamProtocol};
use futures::stream::FuturesUnordered;
use futures::FutureExt;
use futures::StreamExt;
use futures_timer::Delay;
use instant::Instant;
use libp2p_core::connection::ConnectedPoint;
use libp2p_core::multiaddr::Multiaddr;
use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox};
use libp2p_core::upgrade::{InboundUpgradeApply, OutboundUpgradeApply};
use libp2p_core::PeerId;
use libp2p_core::{upgrade, UpgradeError};
use std::future::Future;
use std::task::Waker;
use std::time::Duration;
use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};
/// Describes a connection by the endpoint it runs over and the peer identity
/// associated with it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Connected {
/// The endpoint of the connection.
pub endpoint: ConnectedPoint,
/// The peer identity associated with the connection.
pub peer_id: PeerId,
}
/// Event produced by [`Connection::poll`].
#[derive(Debug, Clone)]
pub enum Event<T> {
/// A custom event emitted by the connection handler.
Handler(T),
/// The stream muxer reported an address change; carries the new address.
AddressChange(Multiaddr),
}
/// A multiplexed connection together with the [`ConnectionHandler`] that drives it.
///
/// Besides the muxer and the handler, the connection tracks every substream
/// that is still being requested or negotiated, and the shutdown plan derived
/// from the handler's keep-alive setting.
pub struct Connection<THandler>
where
THandler: ConnectionHandler,
{
/// The stream muxer; polled for muxer events and for new inbound/outbound substreams.
muxing: StreamMuxerBox,
/// The handler that receives negotiated substreams, upgrade errors and behaviour events.
handler: THandler,
/// Inbound substreams whose protocol upgrade is still in progress.
negotiating_in: FuturesUnordered<
SubstreamUpgrade<
THandler::InboundOpenInfo,
InboundUpgradeApply<SubstreamBox, SendWrapper<THandler::InboundProtocol>>,
>,
>,
/// Outbound substreams whose protocol upgrade is still in progress.
negotiating_out: FuturesUnordered<
SubstreamUpgrade<
THandler::OutboundOpenInfo,
OutboundUpgradeApply<SubstreamBox, SendWrapper<THandler::OutboundProtocol>>,
>,
>,
/// The current shutdown plan, recomputed from the handler's keep-alive on every `poll`.
shutdown: Shutdown,
/// Optional upgrade protocol version forwarded to every outbound substream upgrade.
substream_upgrade_protocol_override: Option<upgrade::Version>,
/// Upper bound on `negotiating_in.len()`; the muxer is only polled for new
/// inbound substreams while the number of negotiating ones is below this value.
max_negotiating_inbound_streams: usize,
/// Outbound substreams requested by the handler that the muxer has not opened yet.
/// Each entry carries its own timeout, started when the request was made.
requested_substreams: FuturesUnordered<
SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
>,
}
/// Debug representation of a [`Connection`].
///
/// Only the `handler` field is rendered; the muxer and the substream
/// bookkeeping collections are left out of the output.
impl<THandler> fmt::Debug for Connection<THandler>
where
    THandler: ConnectionHandler + fmt::Debug,
    THandler::OutboundOpenInfo: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Connection");
        repr.field("handler", &self.handler);
        repr.finish()
    }
}
// `Connection` is `Unpin` for every handler type; `Connection::poll` relies on
// this to obtain `&mut Self` from `Pin<&mut Self>` via `get_mut`.
impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
impl<THandler> Connection<THandler>
where
THandler: ConnectionHandler,
{
/// Builds a new `Connection` from the given muxer and handler.
///
/// The connection starts with no requested or negotiating substreams and no
/// planned shutdown. `substream_upgrade_protocol_override` is forwarded to
/// outbound substream upgrades, and `max_negotiating_inbound_streams` caps
/// how many inbound substreams may be negotiating at the same time.
pub fn new(
muxer: StreamMuxerBox,
handler: THandler,
substream_upgrade_protocol_override: Option<upgrade::Version>,
max_negotiating_inbound_streams: usize,
) -> Self {
Connection {
muxing: muxer,
handler,
negotiating_in: Default::default(),
negotiating_out: Default::default(),
shutdown: Shutdown::None,
substream_upgrade_protocol_override,
max_negotiating_inbound_streams,
requested_substreams: Default::default(),
}
}
/// Forwards an event to the connection handler.
pub fn on_behaviour_event(&mut self, event: THandler::InEvent) {
#[allow(deprecated)]
self.handler.inject_event(event);
}
/// Consumes the connection, returning the handler and the future produced by
/// closing the muxer.
pub fn close(self) -> (THandler, impl Future<Output = io::Result<()>>) {
(self.handler, self.muxing.close())
}
/// Drives the handler, the substream bookkeeping and the muxer.
///
/// Returns `Poll::Ready(Ok(event))` when the handler emits a custom event or
/// the muxer reports an address change, `Poll::Ready(Err(_))` when the
/// handler asks to close, the keep-alive expires, or the muxer fails, and
/// `Poll::Pending` when no part of the connection could make progress.
pub fn poll(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Event<THandler::OutEvent>, ConnectionError<THandler::Error>>> {
// Destructure so that each field can be borrowed independently below.
let Self {
requested_substreams,
muxing,
handler,
negotiating_out,
negotiating_in,
shutdown,
max_negotiating_inbound_streams,
substream_upgrade_protocol_override,
} = self.get_mut();
// Every step that makes progress `continue`s, restarting from the top so
// that earlier steps get a chance to react to the new state.
loop {
// 1. Pending outbound substream requests. A request resolves with `Ok(())`
//    once it was extracted (handed over to `negotiating_out`), or with
//    `Err(user_data)` if its timeout fired before the muxer opened a stream.
match requested_substreams.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(()))) => continue,
Poll::Ready(Some(Err(user_data))) => {
#[allow(deprecated)]
handler.inject_dial_upgrade_error(user_data, ConnectionHandlerUpgrErr::Timeout);
continue;
}
Poll::Ready(None) | Poll::Pending => {}
}
// 2. The handler itself.
match handler.poll(cx) {
Poll::Pending => {}
Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
// Queue the request; its timeout starts now, not when the muxer
// eventually opens the stream.
let timeout = *protocol.timeout();
let (upgrade, user_data) = protocol.into_upgrade();
requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
continue; }
Poll::Ready(ConnectionHandlerEvent::Custom(event)) => {
return Poll::Ready(Ok(Event::Handler(event)));
}
Poll::Ready(ConnectionHandlerEvent::Close(err)) => {
return Poll::Ready(Err(ConnectionError::Handler(err)));
}
}
// 3. Outbound substreams that are negotiating their upgrade.
match negotiating_out.poll_next_unpin(cx) {
Poll::Pending | Poll::Ready(None) => {}
Poll::Ready(Some((user_data, Ok(upgrade)))) => {
#[allow(deprecated)]
handler.inject_fully_negotiated_outbound(upgrade, user_data);
continue;
}
Poll::Ready(Some((user_data, Err(err)))) => {
#[allow(deprecated)]
handler.inject_dial_upgrade_error(user_data, err);
continue;
}
}
// 4. Inbound substreams that are negotiating their upgrade.
match negotiating_in.poll_next_unpin(cx) {
Poll::Pending | Poll::Ready(None) => {}
Poll::Ready(Some((user_data, Ok(upgrade)))) => {
#[allow(deprecated)]
handler.inject_fully_negotiated_inbound(upgrade, user_data);
continue;
}
Poll::Ready(Some((user_data, Err(err)))) => {
#[allow(deprecated)]
handler.inject_listen_upgrade_error(user_data, err);
continue;
}
}
// 5. Recompute the shutdown plan from the handler's keep-alive setting.
let keep_alive = handler.connection_keep_alive();
match (&mut *shutdown, keep_alive) {
// A timer already exists: only touch it when the deadline changed.
(Shutdown::Later(timer, deadline), KeepAlive::Until(t)) => {
if *deadline != t {
*deadline = t;
if let Some(dur) = deadline.checked_duration_since(Instant::now()) {
timer.reset(dur)
}
}
}
// No timer yet: create one if the deadline is still in the future.
// NOTE(review): when `t` is already in the past, `shutdown` is left
// unchanged here — confirm this is intended.
(_, KeepAlive::Until(t)) => {
if let Some(dur) = t.checked_duration_since(Instant::now()) {
*shutdown = Shutdown::Later(Delay::new(dur), t)
}
}
(_, KeepAlive::No) => *shutdown = Shutdown::Asap,
(_, KeepAlive::Yes) => *shutdown = Shutdown::None,
};
// 6. Act on the shutdown plan, but only while no substream is being
//    requested or negotiated.
if negotiating_in.is_empty()
&& negotiating_out.is_empty()
&& requested_substreams.is_empty()
{
match shutdown {
Shutdown::None => {}
Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
Shutdown::Later(delay, _) => match Future::poll(Pin::new(delay), cx) {
Poll::Ready(_) => {
return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
}
Poll::Pending => {}
},
}
}
// 7. Muxer events; `?` returns early with the muxer's error.
match muxing.poll_unpin(cx)? {
Poll::Pending => {}
Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
#[allow(deprecated)]
handler.inject_address_change(&address);
return Poll::Ready(Ok(Event::AddressChange(address)));
}
}
// 8. Only ask the muxer for an outbound substream when a request is
//    waiting; the first request yielded by the iterator gets the stream.
if let Some(requested_substream) = requested_substreams.iter_mut().next() {
match muxing.poll_outbound_unpin(cx)? {
Poll::Pending => {}
Poll::Ready(substream) => {
// The request's already-running timeout is carried over into the upgrade.
let (user_data, timeout, upgrade) = requested_substream.extract();
negotiating_out.push(SubstreamUpgrade::new_outbound(
substream,
user_data,
timeout,
upgrade,
*substream_upgrade_protocol_override,
));
continue; }
}
}
// 9. Accept new inbound substreams only while below the configured cap.
if negotiating_in.len() < *max_negotiating_inbound_streams {
match muxing.poll_inbound_unpin(cx)? {
Poll::Pending => {}
Poll::Ready(substream) => {
let protocol = handler.listen_protocol();
negotiating_in.push(SubstreamUpgrade::new_inbound(substream, protocol));
continue; }
}
}
// Nothing above made progress during this iteration.
return Poll::Pending; }
}
}
/// Borrowed address pair describing an incoming connection.
#[derive(Debug, Copy, Clone)]
pub struct IncomingInfo<'a> {
/// Local address; becomes `local_addr` of the `ConnectedPoint::Listener`.
pub local_addr: &'a Multiaddr,
/// Send-back address; becomes `send_back_addr` of the `ConnectedPoint::Listener`.
pub send_back_addr: &'a Multiaddr,
}
impl<'a> IncomingInfo<'a> {
    /// Builds an owned [`ConnectedPoint::Listener`] by cloning both borrowed
    /// addresses.
    pub fn create_connected_point(&self) -> ConnectedPoint {
        let local_addr = self.local_addr.clone();
        let send_back_addr = self.send_back_addr.clone();
        ConnectedPoint::Listener {
            local_addr,
            send_back_addr,
        }
    }
}
/// Error describing an exceeded connection limit.
#[derive(Debug, Clone)]
pub struct ConnectionLimit {
    /// The configured limit.
    pub limit: u32,
    /// The current count, as reported alongside the limit.
    pub current: u32,
}

/// Renders the error as `connection limit exceeded (<current>/<limit>)`.
impl fmt::Display for ConnectionLimit {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { limit, current } = self;
        write!(f, "connection limit exceeded ({}/{})", current, limit)
    }
}

impl std::error::Error for ConnectionLimit {}
/// Future pairing an in-progress substream upgrade with a timeout and the
/// user data that is handed back when the future resolves.
struct SubstreamUpgrade<UserData, Upgrade> {
/// User data returned with the result; `Some` until the future resolves, at which point it is taken.
user_data: Option<UserData>,
/// Timer bounding the upgrade; when it fires the future resolves with a timeout error.
timeout: Delay,
/// The upgrade future itself.
upgrade: Upgrade,
}
impl<UserData, Upgrade>
    SubstreamUpgrade<UserData, OutboundUpgradeApply<SubstreamBox, SendWrapper<Upgrade>>>
where
    Upgrade: Send + OutboundUpgradeSend,
{
    /// Starts an outbound upgrade on `substream`.
    ///
    /// `timeout` is an already-constructed timer that is stored as-is.
    /// When `version_override` is set and differs from the default upgrade
    /// version, it is used (and logged at debug level); otherwise the default
    /// version applies.
    fn new_outbound(
        substream: SubstreamBox,
        user_data: UserData,
        timeout: Delay,
        upgrade: Upgrade,
        version_override: Option<upgrade::Version>,
    ) -> Self {
        let default_version = upgrade::Version::default();
        let effective_version = match version_override {
            Some(requested) if requested != default_version => {
                log::debug!(
                    "Substream upgrade protocol override: {:?} -> {:?}",
                    default_version,
                    requested
                );
                requested
            }
            Some(_) | None => default_version,
        };

        let upgrade = upgrade::apply_outbound(substream, SendWrapper(upgrade), effective_version);
        Self {
            user_data: Some(user_data),
            timeout,
            upgrade,
        }
    }
}
impl<UserData, Upgrade>
    SubstreamUpgrade<UserData, InboundUpgradeApply<SubstreamBox, SendWrapper<Upgrade>>>
where
    Upgrade: Send + InboundUpgradeSend,
{
    /// Starts an inbound upgrade on `substream` using the upgrade, open info
    /// and timeout carried by `protocol`. The timeout timer is created here.
    fn new_inbound(
        substream: SubstreamBox,
        protocol: SubstreamProtocol<Upgrade, UserData>,
    ) -> Self {
        // Read the timeout before `into_upgrade` consumes the protocol.
        let negotiation_timeout = *protocol.timeout();
        let (inbound_upgrade, open_info) = protocol.into_upgrade();

        let timeout = Delay::new(negotiation_timeout);
        let upgrade = upgrade::apply_inbound(substream, SendWrapper(inbound_upgrade));
        Self {
            user_data: Some(open_info),
            timeout,
            upgrade,
        }
    }
}
// Declared `Unpin` unconditionally so that the `Future` impl can mutate the
// fields directly through `Pin<&mut Self>`.
impl<UserData, Upgrade> Unpin for SubstreamUpgrade<UserData, Upgrade> {}
/// Resolves with the stored user data and either the upgrade's output or a
/// [`ConnectionHandlerUpgrErr`]. The timeout is checked first, so an expired
/// timer wins over an upgrade that would also be ready.
impl<UserData, Upgrade, UpgradeOutput, TUpgradeError> Future for SubstreamUpgrade<UserData, Upgrade>
where
    Upgrade: Future<Output = Result<UpgradeOutput, UpgradeError<TUpgradeError>>> + Unpin,
{
    type Output = (
        UserData,
        Result<UpgradeOutput, ConnectionHandlerUpgrErr<TUpgradeError>>,
    );

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        if self.timeout.poll_unpin(cx).is_ready() {
            let user_data = self
                .user_data
                .take()
                .expect("Future not to be polled again once ready.");
            return Poll::Ready((user_data, Err(ConnectionHandlerUpgrErr::Timeout)));
        }

        let outcome = match self.upgrade.poll_unpin(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(result) => result.map_err(ConnectionHandlerUpgrErr::Upgrade),
        };

        // The user data is only taken once the upgrade has actually finished.
        let user_data = self
            .user_data
            .take()
            .expect("Future not to be polled again once ready.");
        Poll::Ready((user_data, outcome))
    }
}
/// An outbound substream request that is waiting for the muxer to open a stream.
///
/// As a future it resolves with `Err(user_data)` when the timeout fires while
/// still waiting, and with `Ok(())` once its contents have been extracted.
enum SubstreamRequested<UserData, Upgrade> {
/// The request has neither been extracted nor timed out.
Waiting {
/// User data handed back on extraction or on timeout.
user_data: UserData,
/// Timer started when the request was created.
timeout: Delay,
/// The upgrade to apply once a stream is available.
upgrade: Upgrade,
/// Waker stored by the most recent pending `poll`; woken by `extract` so
/// the future gets polled again and can resolve.
extracted_waker: Option<Waker>,
},
/// The contents were extracted, or the request already resolved.
Done,
}
impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
    /// Creates a waiting request and starts its timeout timer immediately.
    fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
        let timeout = Delay::new(timeout);
        Self::Waiting {
            user_data,
            timeout,
            upgrade,
            extracted_waker: None,
        }
    }

    /// Takes the request's contents, leaving `Done` behind, and wakes the task
    /// that last polled this future (if any).
    ///
    /// # Panics
    ///
    /// Panics if the request is already `Done`.
    fn extract(&mut self) -> (UserData, Delay, Upgrade) {
        let previous = mem::replace(self, Self::Done);
        let (user_data, timeout, upgrade, parked) = match previous {
            Self::Done => panic!("cannot extract twice"),
            Self::Waiting {
                user_data,
                timeout,
                upgrade,
                extracted_waker,
            } => (user_data, timeout, upgrade, extracted_waker),
        };

        if let Some(waker) = parked {
            waker.wake();
        }
        (user_data, timeout, upgrade)
    }
}
// Declared `Unpin` unconditionally; the `Future` impl calls `get_mut` on the pinned reference.
impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
/// Resolves with `Ok(())` once the request is `Done`, or with
/// `Err(user_data)` when the timeout fires while the request is still waiting.
impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
    type Output = Result<(), UserData>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let slot = self.get_mut();

        // Move the state out (leaving `Done`) so the fields are owned here.
        let (user_data, upgrade, mut timeout) = match mem::replace(slot, Self::Done) {
            Self::Done => return Poll::Ready(Ok(())),
            Self::Waiting {
                user_data,
                upgrade,
                timeout,
                ..
            } => (user_data, upgrade, timeout),
        };

        if timeout.poll_unpin(cx).is_ready() {
            // Timed out: the slot stays `Done` and the user data goes back to the caller.
            return Poll::Ready(Err(user_data));
        }

        // Still waiting: restore the state and remember the current waker so
        // that `extract` can wake this task.
        *slot = Self::Waiting {
            user_data,
            upgrade,
            timeout,
            extracted_waker: Some(cx.waker().clone()),
        };
        Poll::Pending
    }
}
/// The shutdown plan of a connection, derived from the handler's keep-alive.
#[derive(Debug)]
enum Shutdown {
/// No shutdown is planned.
None,
/// Shut down as soon as no substream is being requested or negotiated.
Asap,
/// Shut down once the timer fires; the `Instant` is the deadline the timer was set for.
Later(Delay, Instant),
}
#[cfg(test)]
mod tests {
use super::*;
use crate::keep_alive;
use futures::AsyncRead;
use futures::AsyncWrite;
use libp2p_core::upgrade::DeniedUpgrade;
use libp2p_core::StreamMuxer;
use quickcheck::*;
use std::sync::{Arc, Weak};
use void::Void;
// Property: with a muxer that always has an inbound substream ready, a single
// `poll` accepts exactly `max_negotiating_inbound_streams` substreams.
#[test]
fn max_negotiating_inbound_streams() {
fn prop(max_negotiating_inbound_streams: u8) {
let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();
// Every substream handed out by the muxer holds a `Weak` to this `Arc`, so
// the weak count equals the number of substreams that are still alive.
let alive_substream_counter = Arc::new(());
let mut connection = Connection::new(
StreamMuxerBox::new(DummyStreamMuxer {
counter: alive_substream_counter.clone(),
}),
keep_alive::ConnectionHandler,
None,
max_negotiating_inbound_streams,
);
let result = Pin::new(&mut connection)
.poll(&mut Context::from_waker(futures::task::noop_waker_ref()));
assert!(result.is_pending());
assert_eq!(
Arc::weak_count(&alive_substream_counter),
max_negotiating_inbound_streams,
"Expect no more than the maximum number of allowed streams"
);
}
QuickCheck::new().quickcheck(prop as fn(_));
}
// The timeout of an outbound substream request must start when the handler
// makes the request, even if the muxer never opens a stream.
#[test]
fn outbound_stream_timeout_starts_on_request() {
let upgrade_timeout = Duration::from_secs(1);
let mut connection = Connection::new(
StreamMuxerBox::new(PendingStreamMuxer),
MockConnectionHandler::new(upgrade_timeout),
None,
2,
);
connection.handler.open_new_outbound();
// First poll: the handler emits the outbound request and its timer starts.
let _ = Pin::new(&mut connection)
.poll(&mut Context::from_waker(futures::task::noop_waker_ref()));
// Wait past the timeout, then poll again so the timeout is observed.
std::thread::sleep(upgrade_timeout + Duration::from_secs(1));
let _ = Pin::new(&mut connection)
.poll(&mut Context::from_waker(futures::task::noop_waker_ref()));
assert!(matches!(
connection.handler.error.unwrap(),
ConnectionHandlerUpgrErr::Timeout
))
}
/// Muxer that always has a new inbound substream ready and never opens an outbound one.
struct DummyStreamMuxer {
/// Each substream created by `poll_inbound` holds a `Weak` to this `Arc`.
counter: Arc<()>,
}
impl StreamMuxer for DummyStreamMuxer {
type Substream = PendingSubstream;
type Error = Void;
fn poll_inbound(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
Poll::Ready(Ok(PendingSubstream(Arc::downgrade(&self.counter))))
}
fn poll_outbound(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
Poll::Pending
}
}
/// Muxer on which every operation stays pending forever.
struct PendingStreamMuxer;
impl StreamMuxer for PendingStreamMuxer {
type Substream = PendingSubstream;
type Error = Void;
fn poll_inbound(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
Poll::Pending
}
fn poll_outbound(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Pending
}
fn poll(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
Poll::Pending
}
}
/// Substream whose reads and writes never complete; the `Weak` lets tests
/// count how many substreams are alive.
struct PendingSubstream(Weak<()>);
impl AsyncRead for PendingSubstream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
}
impl AsyncWrite for PendingSubstream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Pending
}
}
/// Handler that can request one outbound substream on demand and records the
/// last dial upgrade error it was given.
struct MockConnectionHandler {
/// Set by `open_new_outbound`; consumed by the next `poll`.
outbound_requested: bool,
/// The most recent error passed to `inject_dial_upgrade_error`, if any.
error: Option<ConnectionHandlerUpgrErr<Void>>,
/// Timeout applied to both the listen protocol and outbound requests.
upgrade_timeout: Duration,
}
impl MockConnectionHandler {
fn new(upgrade_timeout: Duration) -> Self {
Self {
outbound_requested: false,
error: None,
upgrade_timeout,
}
}
// Makes the next `poll` emit an `OutboundSubstreamRequest`.
fn open_new_outbound(&mut self) {
self.outbound_requested = true;
}
}
impl ConnectionHandler for MockConnectionHandler {
type InEvent = Void;
type OutEvent = Void;
type Error = Void;
type InboundProtocol = DeniedUpgrade;
type OutboundProtocol = DeniedUpgrade;
type InboundOpenInfo = ();
type OutboundOpenInfo = ();
fn listen_protocol(
&self,
) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
}
fn inject_fully_negotiated_inbound(
&mut self,
protocol: <Self::InboundProtocol as InboundUpgradeSend>::Output,
_: Self::InboundOpenInfo,
) {
void::unreachable(protocol)
}
fn inject_fully_negotiated_outbound(
&mut self,
protocol: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
_: Self::OutboundOpenInfo,
) {
void::unreachable(protocol)
}
fn inject_event(&mut self, event: Self::InEvent) {
void::unreachable(event)
}
// Records the error so that tests can inspect it afterwards.
fn inject_dial_upgrade_error(
&mut self,
_: Self::OutboundOpenInfo,
error: ConnectionHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>,
) {
self.error = Some(error)
}
fn connection_keep_alive(&self) -> KeepAlive {
KeepAlive::Yes
}
fn poll(
&mut self,
_: &mut Context<'_>,
) -> Poll<
ConnectionHandlerEvent<
Self::OutboundProtocol,
Self::OutboundOpenInfo,
Self::OutEvent,
Self::Error,
>,
> {
// Emit at most one outbound request per `open_new_outbound` call.
if self.outbound_requested {
self.outbound_requested = false;
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(DeniedUpgrade, ())
.with_timeout(self.upgrade_timeout),
});
}
Poll::Pending
}
}
}