use crate::frame::{self, Frame, Kind, Reason};
use crate::frame::{
DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
};
use crate::proto::Error;
use crate::hpack;
use futures_core::Stream;
use bytes::BytesMut;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
use tokio_util::codec::FramedRead as InnerFramedRead;
use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
/// Header list size limit used by `FramedRead::new` until
/// `set_max_header_list_size` overrides it: 16 MiB.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
/// Reads length-delimited chunks from `T` and decodes each one into an HTTP/2
/// frame, running HPACK over HEADERS / PUSH_PROMISE / CONTINUATION payloads.
#[derive(Debug)]
pub struct FramedRead<T> {
    /// Length-delimited reader; every item it yields is passed to `decode_frame`.
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    /// HPACK decoder state, reused for every header block this reader decodes.
    hpack: hpack::Decoder,

    /// Limit forwarded to HPACK decoding of header blocks.
    max_header_list_size: usize,

    /// Header block still waiting for CONTINUATION frames, if any.
    partial: Option<Partial>,
}
#[derive(Debug)]
struct Partial {
frame: Continuable,
buf: BytesMut,
}
#[derive(Debug)]
enum Continuable {
Headers(frame::Headers),
PushPromise(frame::PushPromise),
}
impl<T> FramedRead<T> {
pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
FramedRead {
inner,
hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
partial: None,
}
}
pub fn get_ref(&self) -> &T {
self.inner.get_ref()
}
pub fn get_mut(&mut self) -> &mut T {
self.inner.get_mut()
}
#[cfg(feature = "unstable")]
#[inline]
pub fn max_frame_size(&self) -> usize {
self.inner.decoder().max_frame_length()
}
#[inline]
pub fn set_max_frame_size(&mut self, val: usize) {
assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
self.inner.decoder_mut().set_max_frame_length(val)
}
#[inline]
pub fn set_max_header_list_size(&mut self, val: usize) {
self.max_header_list_size = val;
}
}
fn decode_frame(
hpack: &mut hpack::Decoder,
max_header_list_size: usize,
partial_inout: &mut Option<Partial>,
mut bytes: BytesMut,
) -> Result<Option<Frame>, Error> {
let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
let _e = span.enter();
tracing::trace!("decoding frame from {}B", bytes.len());
let head = frame::Head::parse(&bytes);
if partial_inout.is_some() && head.kind() != Kind::Continuation {
proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
let kind = head.kind();
tracing::trace!(frame.kind = ?kind);
macro_rules! header_block {
($frame:ident, $head:ident, $bytes:ident) => ({
let _ = $bytes.split_to(frame::HEADER_LEN);
let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
Ok(res) => res,
Err(frame::Error::InvalidDependencyId) => {
proto_err!(stream: "invalid HEADERS dependency ID");
return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR));
},
Err(e) => {
proto_err!(conn: "failed to load frame; err={:?}", e);
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
};
let is_end_headers = frame.is_end_headers();
match frame.load_hpack(&mut payload, max_header_list_size, hpack) {
Ok(_) => {},
Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
Err(frame::Error::MalformedMessage) => {
let id = $head.stream_id();
proto_err!(stream: "malformed header block; stream={:?}", id);
return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
},
Err(e) => {
proto_err!(conn: "failed HPACK decoding; err={:?}", e);
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
}
if is_end_headers {
frame.into()
} else {
tracing::trace!("loaded partial header block");
*partial_inout = Some(Partial {
frame: Continuable::$frame(frame),
buf: payload,
});
return Ok(None);
}
});
}
let frame = match kind {
Kind::Settings => {
let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Ping => {
let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load PING frame; err={:?}", e);
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::WindowUpdate => {
let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Data => {
let _ = bytes.split_to(frame::HEADER_LEN);
let res = frame::Data::load(head, bytes.freeze());
res.map_err(|e| {
proto_err!(conn: "failed to load DATA frame; err={:?}", e);
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Headers => header_block!(Headers, head, bytes),
Kind::Reset => {
let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load RESET frame; err={:?}", e);
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::GoAway => {
let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
Error::library_go_away(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::PushPromise => header_block!(PushPromise, head, bytes),
Kind::Priority => {
if head.stream_id() == 0 {
proto_err!(conn: "invalid stream ID 0");
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
Ok(frame) => frame.into(),
Err(frame::Error::InvalidDependencyId) => {
let id = head.stream_id();
proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
}
Err(e) => {
proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
}
}
Kind::Continuation => {
let is_end_headers = (head.flag() & 0x4) == 0x4;
let mut partial = match partial_inout.take() {
Some(partial) => partial,
None => {
proto_err!(conn: "received unexpected CONTINUATION frame");
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
};
if partial.frame.stream_id() != head.stream_id() {
proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
if partial.buf.is_empty() {
partial.buf = bytes.split_off(frame::HEADER_LEN);
} else {
if partial.frame.is_over_size() {
if partial.buf.len() + bytes.len() > max_header_list_size {
proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
return Err(Error::library_go_away(Reason::COMPRESSION_ERROR));
}
}
partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
}
match partial
.frame
.load_hpack(&mut partial.buf, max_header_list_size, hpack)
{
Ok(_) => {}
Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}
Err(frame::Error::MalformedMessage) => {
let id = head.stream_id();
proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
}
Err(e) => {
proto_err!(conn: "failed HPACK decoding; err={:?}", e);
return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
}
}
if is_end_headers {
partial.frame.into()
} else {
*partial_inout = Some(partial);
return Ok(None);
}
}
Kind::Unknown => {
return Ok(None);
}
};
Ok(Some(frame))
}
impl<T> Stream for FramedRead<T>
where
    T: AsyncRead + Unpin,
{
    type Item = Result<Frame, Error>;

    /// Pulls length-delimited chunks from the inner reader and decodes them.
    /// Chunks that do not yield a frame (buffered header-block pieces, unknown
    /// kinds) are skipped. Reader errors go through `map_err`; decode errors are
    /// yielded as-is.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let span = tracing::trace_span!("FramedRead::poll_next");
        let _e = span.enter();

        loop {
            tracing::trace!("poll");

            let chunk = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
                None => return Poll::Ready(None),
                Some(Err(io_err)) => return Poll::Ready(Some(Err(map_err(io_err)))),
                Some(Ok(chunk)) => chunk,
            };

            tracing::trace!(read.bytes = chunk.len());

            // Reborrow so the decoder state fields can be borrowed disjointly.
            let this = &mut *self;
            let decoded =
                decode_frame(&mut this.hpack, this.max_header_list_size, &mut this.partial, chunk);

            match decoded {
                Ok(Some(frame)) => {
                    tracing::debug!(?frame, "received");
                    return Poll::Ready(Some(Ok(frame)));
                }
                Ok(None) => {}
                Err(err) => return Poll::Ready(Some(Err(err))),
            }
        }
    }
}
fn map_err(err: io::Error) -> Error {
if let io::ErrorKind::InvalidData = err.kind() {
if let Some(custom) = err.get_ref() {
if custom.is::<LengthDelimitedCodecError>() {
return Error::library_go_away(Reason::FRAME_SIZE_ERROR);
}
}
}
err.into()
}
impl Continuable {
    /// Stream identifier of the wrapped HEADERS / PUSH_PROMISE frame.
    fn stream_id(&self) -> frame::StreamId {
        match self {
            Self::Headers(headers) => headers.stream_id(),
            Self::PushPromise(push) => push.stream_id(),
        }
    }

    /// Whether the wrapped frame reports its header list as over size.
    fn is_over_size(&self) -> bool {
        match self {
            Self::Headers(headers) => headers.is_over_size(),
            Self::PushPromise(push) => push.is_over_size(),
        }
    }

    /// Runs HPACK decoding of `src` into the wrapped frame.
    /// `max_header_list_size` and `decoder` are forwarded, and the frame's
    /// result is returned unchanged.
    fn load_hpack(
        &mut self,
        src: &mut BytesMut,
        max_header_list_size: usize,
        decoder: &mut hpack::Decoder,
    ) -> Result<(), frame::Error> {
        match self {
            Self::Headers(headers) => headers.load_hpack(src, max_header_list_size, decoder),
            Self::PushPromise(push) => push.load_hpack(src, max_header_list_size, decoder),
        }
    }
}
impl<T> From<Continuable> for Frame<T> {
    /// Finishes a buffered header block by setting END_HEADERS on the wrapped
    /// frame, then converts it into a `Frame`.
    fn from(cont: Continuable) -> Self {
        match cont {
            Continuable::Headers(mut headers) => {
                headers.set_end_headers();
                Self::from(headers)
            }
            Continuable::PushPromise(mut push) => {
                push.set_end_headers();
                Self::from(push)
            }
        }
    }
}