use std::collections::VecDeque;
use std::io;
use crate::error::Error;
use crate::msgs::codec;
use crate::msgs::message::{MessageError, OpaqueMessage};
pub struct MessageDeframer {
frames: VecDeque<OpaqueMessage>,
desynced: bool,
buf: Box<[u8; OpaqueMessage::MAX_WIRE_SIZE]>,
used: usize,
}
impl Default for MessageDeframer {
fn default() -> Self {
Self::new()
}
}
impl MessageDeframer {
pub fn new() -> Self {
Self {
frames: VecDeque::new(),
desynced: false,
buf: Box::new([0u8; OpaqueMessage::MAX_WIRE_SIZE]),
used: 0,
}
}
pub fn pop(&mut self) -> Result<Option<OpaqueMessage>, Error> {
if self.desynced {
return Err(Error::CorruptMessage);
} else if let Some(msg) = self.frames.pop_front() {
return Ok(Some(msg));
}
let mut taken = 0;
loop {
let mut rd = codec::Reader::init(&self.buf[taken..self.used]);
let m = match OpaqueMessage::read(&mut rd) {
Ok(m) => m,
Err(MessageError::TooShortForHeader | MessageError::TooShortForLength) => break,
Err(_) => {
self.desynced = true;
return Err(Error::CorruptMessage);
}
};
taken += rd.used();
self.frames.push_back(m);
}
#[allow(clippy::comparison_chain)]
if taken < self.used {
self.buf
.copy_within(taken..self.used, 0);
self.used -= taken;
} else if taken == self.used {
self.used = 0;
}
Ok(self.frames.pop_front())
}
#[allow(clippy::comparison_chain)]
pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
if self.used == OpaqueMessage::MAX_WIRE_SIZE {
return Err(io::Error::new(io::ErrorKind::Other, "message buffer full"));
}
debug_assert!(self.used <= OpaqueMessage::MAX_WIRE_SIZE);
let new_bytes = rd.read(&mut self.buf[self.used..])?;
self.used += new_bytes;
Ok(new_bytes)
}
pub fn has_pending(&self) -> bool {
!self.frames.is_empty() || self.used > 0
}
}
#[cfg(test)]
mod tests {
use super::MessageDeframer;
use crate::msgs::message::{Message, OpaqueMessage};
use crate::{msgs, Error};
use std::convert::TryFrom;
use std::io;
const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");
const EMPTY_APPLICATIONDATA_MESSAGE: &[u8] =
include_bytes!("../testdata/deframer-empty-applicationdata.bin");
const INVALID_EMPTY_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-empty.bin");
const INVALID_CONTENTTYPE_MESSAGE: &[u8] =
include_bytes!("../testdata/deframer-invalid-contenttype.bin");
const INVALID_VERSION_MESSAGE: &[u8] =
include_bytes!("../testdata/deframer-invalid-version.bin");
const INVALID_LENGTH_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-length.bin");
struct ByteRead<'a> {
buf: &'a [u8],
offs: usize,
}
impl<'a> ByteRead<'a> {
fn new(bytes: &'a [u8]) -> Self {
ByteRead {
buf: bytes,
offs: 0,
}
}
}
impl<'a> io::Read for ByteRead<'a> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut len = 0;
while len < buf.len() && len < self.buf.len() - self.offs {
buf[len] = self.buf[self.offs + len];
len += 1;
}
self.offs += len;
Ok(len)
}
}
fn input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize> {
let mut rd = ByteRead::new(bytes);
d.read(&mut rd)
}
fn input_bytes_concat(
d: &mut MessageDeframer,
bytes1: &[u8],
bytes2: &[u8],
) -> io::Result<usize> {
let mut bytes = vec![0u8; bytes1.len() + bytes2.len()];
bytes[..bytes1.len()].clone_from_slice(bytes1);
bytes[bytes1.len()..].clone_from_slice(bytes2);
let mut rd = ByteRead::new(&bytes);
d.read(&mut rd)
}
struct ErrorRead {
error: Option<io::Error>,
}
impl ErrorRead {
fn new(error: io::Error) -> Self {
Self { error: Some(error) }
}
}
impl io::Read for ErrorRead {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
for (i, b) in buf.iter_mut().enumerate() {
*b = i as u8;
}
let error = self.error.take().unwrap();
Err(error)
}
}
fn input_error(d: &mut MessageDeframer) {
let error = io::Error::from(io::ErrorKind::TimedOut);
let mut rd = ErrorRead::new(error);
d.read(&mut rd)
.expect_err("error not propagated");
}
fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
let before = d.used;
for i in 0..bytes.len() {
assert_len(1, input_bytes(d, &bytes[i..i + 1]));
assert!(d.has_pending());
}
assert_eq!(before + bytes.len(), d.used);
}
fn assert_len(want: usize, got: io::Result<usize>) {
if let Ok(gotval) = got {
assert_eq!(gotval, want);
} else {
panic!("read failed, expected {:?} bytes", want);
}
}
fn pop_first(d: &mut MessageDeframer) {
let m = d.pop().unwrap().unwrap();
assert_eq!(m.typ, msgs::enums::ContentType::Handshake);
Message::try_from(m.into_plain_message()).unwrap();
}
fn pop_second(d: &mut MessageDeframer) {
let m = d.pop().unwrap().unwrap();
assert_eq!(m.typ, msgs::enums::ContentType::Alert);
Message::try_from(m.into_plain_message()).unwrap();
}
#[test]
fn check_incremental() {
let mut d = MessageDeframer::new();
assert!(!d.has_pending());
input_whole_incremental(&mut d, FIRST_MESSAGE);
assert!(d.has_pending());
assert_eq!(0, d.frames.len());
pop_first(&mut d);
assert!(!d.has_pending());
assert!(!d.desynced);
}
#[test]
fn check_incremental_2() {
let mut d = MessageDeframer::new();
assert!(!d.has_pending());
input_whole_incremental(&mut d, FIRST_MESSAGE);
assert!(d.has_pending());
input_whole_incremental(&mut d, SECOND_MESSAGE);
assert!(d.has_pending());
assert_eq!(0, d.frames.len());
pop_first(&mut d);
assert!(d.has_pending());
assert_eq!(1, d.frames.len());
pop_second(&mut d);
assert!(!d.has_pending());
assert!(!d.desynced);
}
#[test]
fn check_whole() {
let mut d = MessageDeframer::new();
assert!(!d.has_pending());
assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
assert!(d.has_pending());
assert_eq!(d.frames.len(), 0);
pop_first(&mut d);
assert!(!d.has_pending());
assert!(!d.desynced);
}
#[test]
fn check_whole_2() {
let mut d = MessageDeframer::new();
assert!(!d.has_pending());
assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE));
assert_eq!(d.frames.len(), 0);
pop_first(&mut d);
assert_eq!(d.frames.len(), 1);
pop_second(&mut d);
assert!(!d.has_pending());
assert!(!d.desynced);
}
#[test]
fn test_two_in_one_read() {
let mut d = MessageDeframer::new();
assert!(!d.has_pending());
assert_len(
FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE),
);
assert_eq!(d.frames.len(), 0);
pop_first(&mut d);
assert_eq!(d.frames.len(), 1);
pop_second(&mut d);
assert!(!d.has_pending());
assert!(!d.desynced);
}
#[test]
fn test_two_in_one_read_shortest_first() {
let mut d = MessageDeframer::new();
assert!(!d.has_pending());
assert_len(
FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE),
);
assert_eq!(d.frames.len(), 0);
pop_second(&mut d);
assert_eq!(d.frames.len(), 1);
pop_first(&mut d);
assert!(!d.has_pending());
assert!(!d.desynced);
}
#[test]
fn test_incremental_with_nonfatal_read_error() {
let mut d = MessageDeframer::new();
assert_len(3, input_bytes(&mut d, &FIRST_MESSAGE[..3]));
input_error(&mut d);
assert_len(
FIRST_MESSAGE.len() - 3,
input_bytes(&mut d, &FIRST_MESSAGE[3..]),
);
assert_eq!(d.frames.len(), 0);
pop_first(&mut d);
assert!(!d.has_pending());
assert!(!d.desynced);
}
#[test]
fn test_invalid_contenttype_errors() {
let mut d = MessageDeframer::new();
assert_len(
INVALID_CONTENTTYPE_MESSAGE.len(),
input_bytes(&mut d, INVALID_CONTENTTYPE_MESSAGE),
);
assert_eq!(d.pop().unwrap_err(), Error::CorruptMessage);
}
#[test]
fn test_invalid_version_errors() {
let mut d = MessageDeframer::new();
assert_len(
INVALID_VERSION_MESSAGE.len(),
input_bytes(&mut d, INVALID_VERSION_MESSAGE),
);
assert_eq!(d.pop().unwrap_err(), Error::CorruptMessage);
}
#[test]
fn test_invalid_length_errors() {
let mut d = MessageDeframer::new();
assert_len(
INVALID_LENGTH_MESSAGE.len(),
input_bytes(&mut d, INVALID_LENGTH_MESSAGE),
);
assert_eq!(d.pop().unwrap_err(), Error::CorruptMessage);
}
#[test]
fn test_empty_applicationdata() {
let mut d = MessageDeframer::new();
assert_len(
EMPTY_APPLICATIONDATA_MESSAGE.len(),
input_bytes(&mut d, EMPTY_APPLICATIONDATA_MESSAGE),
);
let m = d.pop().unwrap().unwrap();
assert_eq!(m.typ, msgs::enums::ContentType::ApplicationData);
assert_eq!(m.payload.0.len(), 0);
assert!(!d.has_pending());
assert!(!d.desynced);
}
#[test]
fn test_invalid_empty_errors() {
let mut d = MessageDeframer::new();
assert_len(
INVALID_EMPTY_MESSAGE.len(),
input_bytes(&mut d, INVALID_EMPTY_MESSAGE),
);
assert_eq!(d.pop().unwrap_err(), Error::CorruptMessage);
assert_eq!(d.pop().unwrap_err(), Error::CorruptMessage);
}
#[test]
fn test_limited_buffer() {
const PAYLOAD_LEN: usize = 16_384;
let mut message = Vec::with_capacity(8192);
message.push(0x17); message.extend(&[0x03, 0x04]); message.extend((PAYLOAD_LEN as u16).to_be_bytes()); message.extend(&[0; PAYLOAD_LEN]);
let mut d = MessageDeframer::new();
assert_len(message.len(), input_bytes(&mut d, &message));
assert_len(
OpaqueMessage::MAX_WIRE_SIZE - 16_389,
input_bytes(&mut d, &message),
);
assert!(input_bytes(&mut d, &message).is_err());
}
}