use asynchronous_codec::Framed;
use bytes::BytesMut;
use futures::prelude::*;
use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use log::{error, warn};
use sc_network_common::protocol::ProtocolName;
use std::{
convert::Infallible,
io, mem,
pin::Pin,
task::{Context, Poll},
vec,
};
use unsigned_varint::codec::UviBytes;
/// Maximum size, in bytes, of a handshake.
///
/// A received handshake announcing a larger length is rejected by both upgrades with
/// `NotificationsHandshakeError::TooLarge`. `NotificationsOut::new` only logs an error when our
/// own initial message exceeds it.
const MAX_HANDSHAKE_SIZE: usize = 1024;
/// Upgrade that accepts an inbound substream, reads the remote's handshake, and turns the
/// substream into a [`NotificationsInSubstream`].
#[derive(Debug, Clone)]
pub struct NotificationsIn {
/// Protocol names to accept. The first entry is the main protocol; the others are fallbacks.
protocol_names: Vec<ProtocolName>,
/// Maximum size, in bytes, configured on the framing codec of the resulting substream.
max_notification_size: u64,
}
/// Upgrade that sends our initial message (handshake) on an outbound substream, reads the
/// remote's handshake, and turns the substream into a [`NotificationsOutSubstream`].
#[derive(Debug, Clone)]
pub struct NotificationsOut {
/// Protocol names to propose. The first entry is the main protocol; the others are fallbacks.
protocol_names: Vec<ProtocolName>,
/// Handshake written (length-prefixed) to the remote right after protocol negotiation.
initial_message: Vec<u8>,
/// Maximum size, in bytes, configured on the framing codec of the resulting substream.
max_notification_size: u64,
}
/// Inbound notifications substream, produced by the [`NotificationsIn`] upgrade.
///
/// Our own handshake must be provided with `send_handshake`. Notifications from the remote are
/// then yielded through the `Stream` implementation once that handshake has been flushed.
#[pin_project::pin_project]
pub struct NotificationsInSubstream<TSubstream> {
/// Underlying substream, framed with the `UviBytes` codec (varint length-prefixed frames).
#[pin]
socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
/// Progress of sending our handshake and of closing the substream.
handshake: NotificationsInSubstreamHandshake,
}
/// State machine of the handshake we send back on an inbound notifications substream.
enum NotificationsInSubstreamHandshake {
/// `send_handshake` has not been called yet; nothing is written and no notification is read.
NotSent,
/// Handshake provided by `send_handshake`, waiting for the socket to be ready to accept it.
PendingSend(Vec<u8>),
/// Handshake handed to the socket; waiting for the flush to complete.
Flush,
/// Handshake flushed; notifications are now read from the socket.
Sent,
/// The remote ended its side of the substream; we are closing ours in response.
ClosingInResponseToRemote,
/// Both sides are closed; the stream reports its end (`None`).
BothSidesClosed,
}
/// Outbound notifications substream, produced by the [`NotificationsOut`] upgrade.
///
/// Notifications are sent through its `Sink<Vec<u8>>` implementation.
#[pin_project::pin_project]
pub struct NotificationsOutSubstream<TSubstream> {
/// Underlying substream, framed with the `UviBytes` codec (varint length-prefixed frames).
#[pin]
socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
}
impl NotificationsIn {
pub fn new(
main_protocol_name: impl Into<ProtocolName>,
fallback_names: Vec<ProtocolName>,
max_notification_size: u64,
) -> Self {
let mut protocol_names = fallback_names;
protocol_names.insert(0, main_protocol_name.into());
Self { protocol_names, max_notification_size }
}
}
impl UpgradeInfo for NotificationsIn {
	type Info = ProtocolName;
	type InfoIter = vec::IntoIter<Self::Info>;

	/// Lists every accepted protocol name, main protocol first.
	fn protocol_info(&self) -> Self::InfoIter {
		let advertised: Vec<Self::Info> = self.protocol_names.to_vec();
		advertised.into_iter()
	}
}
impl<TSubstream> InboundUpgrade<TSubstream> for NotificationsIn
where
TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = NotificationsInOpen<TSubstream>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
type Error = NotificationsHandshakeError;
fn upgrade_inbound(self, mut socket: TSubstream, negotiated_name: Self::Info) -> Self::Future {
Box::pin(async move {
let handshake_len = unsigned_varint::aio::read_usize(&mut socket).await?;
if handshake_len > MAX_HANDSHAKE_SIZE {
return Err(NotificationsHandshakeError::TooLarge {
requested: handshake_len,
max: MAX_HANDSHAKE_SIZE,
})
}
let mut handshake = vec![0u8; handshake_len];
if !handshake.is_empty() {
socket.read_exact(&mut handshake).await?;
}
let mut codec = UviBytes::default();
codec.set_max_len(usize::try_from(self.max_notification_size).unwrap_or(usize::MAX));
let substream = NotificationsInSubstream {
socket: Framed::new(socket, codec),
handshake: NotificationsInSubstreamHandshake::NotSent,
};
Ok(NotificationsInOpen {
handshake,
negotiated_fallback: if negotiated_name == self.protocol_names[0] {
None
} else {
Some(negotiated_name)
},
substream,
})
})
}
}
/// Output of a successful [`NotificationsIn`] upgrade.
pub struct NotificationsInOpen<TSubstream> {
/// Handshake received from the remote (at most `MAX_HANDSHAKE_SIZE` bytes).
pub handshake: Vec<u8>,
/// `Some(name)` when a fallback protocol was negotiated instead of the main one.
pub negotiated_fallback: Option<ProtocolName>,
/// Substream on which our handshake must be sent and notifications are then received.
pub substream: NotificationsInSubstream<TSubstream>,
}
impl<TSubstream> NotificationsInSubstream<TSubstream>
where
	TSubstream: AsyncRead + AsyncWrite + Unpin,
{
	/// Schedules `message` to be sent to the remote as our handshake.
	///
	/// Nothing is written immediately. The handshake is pushed into the socket the next time the
	/// substream is polled, through [`Self::poll_process`] or `Stream::poll_next`. Only the first
	/// call has an effect; later calls are logged as errors and ignored.
	pub fn send_handshake(&mut self, message: impl Into<Vec<u8>>) {
		if !matches!(self.handshake, NotificationsInSubstreamHandshake::NotSent) {
			error!(target: "sub-libp2p", "Tried to send handshake twice");
			return
		}

		self.handshake = NotificationsInSubstreamHandshake::PendingSend(message.into());
	}

	/// Drives the sending and flushing of our handshake without reading any notification.
	///
	/// The success type is uninhabited, so this only ever returns:
	/// - `Poll::Ready(Err(_))` if making the socket ready, queueing, or flushing the handshake
	///   fails;
	/// - `Poll::Pending` while the socket is busy, or once there is nothing to do for the
	///   handshake (not yet provided, already sent, or substream closing/closed). In the latter
	///   case no waker is registered, so the caller must arrange to poll again itself.
	///
	/// On error the state is left at the step that failed.
	pub fn poll_process(
		self: Pin<&mut Self>,
		cx: &mut Context,
	) -> Poll<Result<Infallible, io::Error>> {
		let mut this = self.project();

		loop {
			// Move the current state out (leaving a placeholder); every arm writes back the
			// state it wants to leave behind.
			match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
				NotificationsInSubstreamHandshake::PendingSend(msg) =>
					match Sink::poll_ready(this.socket.as_mut(), cx) {
						Poll::Ready(Ok(())) => {
							*this.handshake = NotificationsInSubstreamHandshake::Flush;
							if let Err(err) =
								Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg))
							{
								return Poll::Ready(Err(err))
							}
						},
						// `start_send` may only follow a successful `poll_ready`. Report the
						// error instead of pushing the handshake into a failed sink, and keep
						// the message so the state stays coherent.
						Poll::Ready(Err(err)) => {
							*this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
							return Poll::Ready(Err(err))
						},
						Poll::Pending => {
							*this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
							return Poll::Pending
						},
					},
				NotificationsInSubstreamHandshake::Flush =>
					match Sink::poll_flush(this.socket.as_mut(), cx) {
						Poll::Ready(Ok(())) =>
							*this.handshake = NotificationsInSubstreamHandshake::Sent,
						// Do not pretend the handshake was delivered when the flush failed.
						Poll::Ready(Err(err)) => {
							*this.handshake = NotificationsInSubstreamHandshake::Flush;
							return Poll::Ready(Err(err))
						},
						Poll::Pending => {
							*this.handshake = NotificationsInSubstreamHandshake::Flush;
							return Poll::Pending
						},
					},

				st @ NotificationsInSubstreamHandshake::NotSent |
				st @ NotificationsInSubstreamHandshake::Sent |
				st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote |
				st @ NotificationsInSubstreamHandshake::BothSidesClosed => {
					*this.handshake = st;
					return Poll::Pending
				},
			}
		}
	}
}
impl<TSubstream> Stream for NotificationsInSubstream<TSubstream>
where
	TSubstream: AsyncRead + AsyncWrite + Unpin,
{
	type Item = Result<BytesMut, io::Error>;

	/// Drives our handshake and yields the notifications sent by the remote.
	///
	/// - Until [`NotificationsInSubstream::send_handshake`] has been called, returns
	///   `Poll::Pending` without reading anything and without registering a waker.
	/// - Once a handshake is provided, it is written and flushed before any notification is read.
	/// - When the remote closes its side, our side is closed too and the stream then ends with
	///   `None`. It keeps returning `None` on every later poll.
	/// - Errors while writing, flushing, or closing are returned as `Some(Err(_))`, with the
	///   state left at the step that failed.
	fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
		let mut this = self.project();

		loop {
			// Move the current state out (leaving a placeholder); every arm writes back the
			// state it wants to leave behind.
			match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
				NotificationsInSubstreamHandshake::NotSent => {
					*this.handshake = NotificationsInSubstreamHandshake::NotSent;
					return Poll::Pending
				},
				NotificationsInSubstreamHandshake::PendingSend(msg) =>
					match Sink::poll_ready(this.socket.as_mut(), cx) {
						Poll::Ready(Ok(())) => {
							*this.handshake = NotificationsInSubstreamHandshake::Flush;
							if let Err(err) =
								Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg))
							{
								return Poll::Ready(Some(Err(err)))
							}
						},
						// `start_send` may only follow a successful `poll_ready`. Report the
						// error instead of pushing the handshake into a failed sink, and keep
						// the message so the state stays coherent.
						Poll::Ready(Err(err)) => {
							*this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
							return Poll::Ready(Some(Err(err)))
						},
						Poll::Pending => {
							*this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
							return Poll::Pending
						},
					},
				NotificationsInSubstreamHandshake::Flush =>
					match Sink::poll_flush(this.socket.as_mut(), cx) {
						Poll::Ready(Ok(())) =>
							*this.handshake = NotificationsInSubstreamHandshake::Sent,
						// Do not start reading notifications when the handshake flush failed.
						Poll::Ready(Err(err)) => {
							*this.handshake = NotificationsInSubstreamHandshake::Flush;
							return Poll::Ready(Some(Err(err)))
						},
						Poll::Pending => {
							*this.handshake = NotificationsInSubstreamHandshake::Flush;
							return Poll::Pending
						},
					},
				NotificationsInSubstreamHandshake::Sent =>
					match Stream::poll_next(this.socket.as_mut(), cx) {
						// The remote closed its side: close ours in response.
						Poll::Ready(None) =>
							*this.handshake =
								NotificationsInSubstreamHandshake::ClosingInResponseToRemote,
						Poll::Ready(Some(msg)) => {
							*this.handshake = NotificationsInSubstreamHandshake::Sent;
							return Poll::Ready(Some(msg))
						},
						Poll::Pending => {
							*this.handshake = NotificationsInSubstreamHandshake::Sent;
							return Poll::Pending
						},
					},
				NotificationsInSubstreamHandshake::ClosingInResponseToRemote =>
					match Sink::poll_close(this.socket.as_mut(), cx) {
						Poll::Ready(Ok(())) =>
							*this.handshake = NotificationsInSubstreamHandshake::BothSidesClosed,
						Poll::Ready(Err(err)) => {
							*this.handshake =
								NotificationsInSubstreamHandshake::ClosingInResponseToRemote;
							return Poll::Ready(Some(Err(err)))
						},
						Poll::Pending => {
							*this.handshake =
								NotificationsInSubstreamHandshake::ClosingInResponseToRemote;
							return Poll::Pending
						},
					},
				NotificationsInSubstreamHandshake::BothSidesClosed => {
					// Stay terminated. Without writing the state back, the placeholder `Sent`
					// would make the next poll read from the already-closed socket again.
					*this.handshake = NotificationsInSubstreamHandshake::BothSidesClosed;
					return Poll::Ready(None)
				},
			}
		}
	}
}
impl NotificationsOut {
pub fn new(
main_protocol_name: impl Into<ProtocolName>,
fallback_names: Vec<ProtocolName>,
initial_message: impl Into<Vec<u8>>,
max_notification_size: u64,
) -> Self {
let initial_message = initial_message.into();
if initial_message.len() > MAX_HANDSHAKE_SIZE {
error!(target: "sub-libp2p", "Outbound networking handshake is above allowed protocol limit");
}
let mut protocol_names = fallback_names;
protocol_names.insert(0, main_protocol_name.into());
Self { protocol_names, initial_message, max_notification_size }
}
}
impl UpgradeInfo for NotificationsOut {
	type Info = ProtocolName;
	type InfoIter = vec::IntoIter<Self::Info>;

	/// Lists every proposed protocol name, main protocol first.
	fn protocol_info(&self) -> Self::InfoIter {
		let proposed: Vec<Self::Info> = self.protocol_names.to_vec();
		proposed.into_iter()
	}
}
impl<TSubstream> OutboundUpgrade<TSubstream> for NotificationsOut
where
TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = NotificationsOutOpen<TSubstream>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
type Error = NotificationsHandshakeError;
fn upgrade_outbound(self, mut socket: TSubstream, negotiated_name: Self::Info) -> Self::Future {
Box::pin(async move {
upgrade::write_length_prefixed(&mut socket, &self.initial_message).await?;
let handshake_len = unsigned_varint::aio::read_usize(&mut socket).await?;
if handshake_len > MAX_HANDSHAKE_SIZE {
return Err(NotificationsHandshakeError::TooLarge {
requested: handshake_len,
max: MAX_HANDSHAKE_SIZE,
})
}
let mut handshake = vec![0u8; handshake_len];
if !handshake.is_empty() {
socket.read_exact(&mut handshake).await?;
}
let mut codec = UviBytes::default();
codec.set_max_len(usize::try_from(self.max_notification_size).unwrap_or(usize::MAX));
Ok(NotificationsOutOpen {
handshake,
negotiated_fallback: if negotiated_name == self.protocol_names[0] {
None
} else {
Some(negotiated_name)
},
substream: NotificationsOutSubstream { socket: Framed::new(socket, codec) },
})
})
}
}
/// Output of a successful [`NotificationsOut`] upgrade.
pub struct NotificationsOutOpen<TSubstream> {
/// Handshake received from the remote (at most `MAX_HANDSHAKE_SIZE` bytes).
pub handshake: Vec<u8>,
/// `Some(name)` when a fallback protocol was negotiated instead of the main one.
pub negotiated_fallback: Option<ProtocolName>,
/// Substream on which notifications can be sent.
pub substream: NotificationsOutSubstream<TSubstream>,
}
/// Sends notifications on the outbound substream.
///
/// Every method forwards to the framed socket and wraps its I/O errors into
/// [`NotificationsOutError::Io`].
impl<TSubstream> Sink<Vec<u8>> for NotificationsOutSubstream<TSubstream>
where
	TSubstream: AsyncRead + AsyncWrite + Unpin,
{
	type Error = NotificationsOutError;

	fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
		let framed = self.project().socket;
		Sink::poll_ready(framed, cx).map_err(NotificationsOutError::Io)
	}

	fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
		// The codec works on cursors, so the raw bytes are wrapped before being queued.
		let frame = io::Cursor::new(item);
		let framed = self.project().socket;
		Sink::start_send(framed, frame).map_err(NotificationsOutError::Io)
	}

	fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
		let framed = self.project().socket;
		Sink::poll_flush(framed, cx).map_err(NotificationsOutError::Io)
	}

	fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
		let framed = self.project().socket;
		Sink::poll_close(framed, cx).map_err(NotificationsOutError::Io)
	}
}
/// Error that can happen while exchanging handshakes on a notifications substream.
#[derive(Debug, thiserror::Error)]
pub enum NotificationsHandshakeError {
/// I/O error on the underlying substream.
#[error(transparent)]
Io(#[from] io::Error),
/// The handshake length read from the remote exceeds `MAX_HANDSHAKE_SIZE`.
#[error("Initial message or handshake was too large: {requested}")]
TooLarge {
/// Length, in bytes, announced by the remote.
requested: usize,
/// Maximum length, in bytes, that is accepted.
max: usize,
},
/// The varint length prefix of the handshake could not be decoded.
#[error(transparent)]
VarintDecode(#[from] unsigned_varint::decode::Error),
}
impl From<unsigned_varint::io::ReadError> for NotificationsHandshakeError {
fn from(err: unsigned_varint::io::ReadError) -> Self {
match err {
unsigned_varint::io::ReadError::Io(err) => Self::Io(err),
unsigned_varint::io::ReadError::Decode(err) => Self::VarintDecode(err),
_ => {
warn!("Unrecognized varint decoding error");
Self::Io(From::from(io::ErrorKind::InvalidData))
},
}
}
}
/// Error returned by the `Sink` implementation of [`NotificationsOutSubstream`].
#[derive(Debug, thiserror::Error)]
pub enum NotificationsOutError {
/// I/O error on the underlying substream.
#[error(transparent)]
Io(#[from] io::Error),
}
#[cfg(test)]
mod tests {
	use super::{NotificationsIn, NotificationsInOpen, NotificationsOut, NotificationsOutOpen};
	use futures::{channel::oneshot, prelude::*};
	use libp2p::core::upgrade;
	use std::net::SocketAddr;
	use tokio::net::{TcpListener, TcpStream};
	use tokio_util::compat::TokioAsyncReadCompatExt;

	/// Protocol name negotiated by every test.
	const PROTO_NAME: &str = "/test/proto/1";
	/// Notification size limit used on both ends by every test.
	const MAX_NOTIFICATION_SIZE: u64 = 1024 * 1024;

	/// Waits for the listener's address and opens a TCP connection to it.
	async fn dial(addr_rx: oneshot::Receiver<SocketAddr>) -> TcpStream {
		let addr = addr_rx.await.unwrap();
		TcpStream::connect(addr).await.unwrap()
	}

	/// Binds a listener on an ephemeral localhost port, publishes its address through
	/// `addr_tx`, and returns the first accepted connection.
	async fn accept_one(addr_tx: oneshot::Sender<SocketAddr>) -> TcpStream {
		let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
		addr_tx.send(listener.local_addr().unwrap()).unwrap();
		let (stream, _) = listener.accept().await.unwrap();
		stream
	}

	#[tokio::test]
	async fn basic_works() {
		let (addr_tx, addr_rx) = oneshot::channel();

		let dialer = tokio::spawn(async move {
			let stream = dial(addr_rx).await;
			let NotificationsOutOpen { handshake, mut substream, .. } = upgrade::apply_outbound(
				stream.compat(),
				NotificationsOut::new(
					PROTO_NAME,
					Vec::new(),
					&b"initial message"[..],
					MAX_NOTIFICATION_SIZE,
				),
				upgrade::Version::V1,
			)
			.await
			.unwrap();

			assert_eq!(handshake, b"hello world");
			substream.send(b"test message".to_vec()).await.unwrap();
		});

		let stream = accept_one(addr_tx).await;
		let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound(
			stream.compat(),
			NotificationsIn::new(PROTO_NAME, Vec::new(), MAX_NOTIFICATION_SIZE),
		)
		.await
		.unwrap();

		assert_eq!(handshake, b"initial message");
		substream.send_handshake(&b"hello world"[..]);

		let received = substream.next().await.unwrap().unwrap();
		assert_eq!(received.as_ref(), b"test message");

		dialer.await.unwrap();
	}

	#[tokio::test]
	async fn empty_handshake() {
		let (addr_tx, addr_rx) = oneshot::channel();

		let dialer = tokio::spawn(async move {
			let stream = dial(addr_rx).await;
			let NotificationsOutOpen { handshake, mut substream, .. } = upgrade::apply_outbound(
				stream.compat(),
				NotificationsOut::new(PROTO_NAME, Vec::new(), Vec::new(), MAX_NOTIFICATION_SIZE),
				upgrade::Version::V1,
			)
			.await
			.unwrap();

			assert!(handshake.is_empty());
			substream.send(Vec::new()).await.unwrap();
		});

		let stream = accept_one(addr_tx).await;
		let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound(
			stream.compat(),
			NotificationsIn::new(PROTO_NAME, Vec::new(), MAX_NOTIFICATION_SIZE),
		)
		.await
		.unwrap();

		assert!(handshake.is_empty());
		substream.send_handshake(Vec::new());

		let received = substream.next().await.unwrap().unwrap();
		assert!(received.as_ref().is_empty());

		dialer.await.unwrap();
	}

	#[tokio::test]
	async fn refused() {
		let (addr_tx, addr_rx) = oneshot::channel();

		let dialer = tokio::spawn(async move {
			let stream = dial(addr_rx).await;
			let outcome = upgrade::apply_outbound(
				stream.compat(),
				NotificationsOut::new(PROTO_NAME, Vec::new(), &b"hello"[..], MAX_NOTIFICATION_SIZE),
				upgrade::Version::V1,
			)
			.await;

			// The listener drops the substream without answering, so the upgrade must fail.
			assert!(outcome.is_err());
		});

		let stream = accept_one(addr_tx).await;
		let NotificationsInOpen { handshake, substream, .. } = upgrade::apply_inbound(
			stream.compat(),
			NotificationsIn::new(PROTO_NAME, Vec::new(), MAX_NOTIFICATION_SIZE),
		)
		.await
		.unwrap();

		assert_eq!(handshake, b"hello");
		drop(substream);

		dialer.await.unwrap();
	}

	#[tokio::test]
	async fn large_initial_message_refused() {
		let (addr_tx, addr_rx) = oneshot::channel();

		let dialer = tokio::spawn(async move {
			let stream = dial(addr_rx).await;
			let outcome = upgrade::apply_outbound(
				stream.compat(),
				NotificationsOut::new(
					PROTO_NAME,
					Vec::new(),
					vec![0u8; 32768],
					MAX_NOTIFICATION_SIZE,
				),
				upgrade::Version::V1,
			)
			.await;

			assert!(outcome.is_err());
		});

		let stream = accept_one(addr_tx).await;
		let outcome = upgrade::apply_inbound(
			stream.compat(),
			NotificationsIn::new(PROTO_NAME, Vec::new(), MAX_NOTIFICATION_SIZE),
		)
		.await;

		assert!(outcome.is_err());

		dialer.await.unwrap();
	}

	#[tokio::test]
	async fn large_handshake_refused() {
		let (addr_tx, addr_rx) = oneshot::channel();

		let dialer = tokio::spawn(async move {
			let stream = dial(addr_rx).await;
			let outcome = upgrade::apply_outbound(
				stream.compat(),
				NotificationsOut::new(
					PROTO_NAME,
					Vec::new(),
					&b"initial message"[..],
					MAX_NOTIFICATION_SIZE,
				),
				upgrade::Version::V1,
			)
			.await;

			assert!(outcome.is_err());
		});

		let stream = accept_one(addr_tx).await;
		let NotificationsInOpen { handshake, mut substream, .. } = upgrade::apply_inbound(
			stream.compat(),
			NotificationsIn::new(PROTO_NAME, Vec::new(), MAX_NOTIFICATION_SIZE),
		)
		.await
		.unwrap();

		assert_eq!(handshake, b"initial message");

		// Send a handshake above the limit; polling the stream pushes it onto the wire.
		substream.send_handshake(vec![0u8; 32768]);
		let _ = substream.next().await;

		dialer.await.unwrap();
	}
}