use crate::{
as_u64,
base::{Header, OpCode},
connection::Mode,
extension::{Extension, Param},
BoxedError, Storage,
};
use flate2::{write::DeflateDecoder, Compress, Compression, FlushCompress, Status};
use std::{
convert::TryInto,
io::{self, Write},
mem,
};
/// Extension parameter name: the server must not reuse its compression
/// context across messages (RFC 7692, section 7.1.1.1).
const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover";
/// Extension parameter name: LZ77 window size (in bits) the server may use
/// when compressing (RFC 7692, section 7.1.2.1).
const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits";
/// Extension parameter name: the client must not reuse its compression
/// context across messages (RFC 7692, section 7.1.1.2).
const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover";
/// Extension parameter name: LZ77 window size (in bits) the client may use
/// when compressing (RFC 7692, section 7.1.2.2).
const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits";
/// State of the `permessage-deflate` WebSocket extension.
///
/// The extension starts disabled and is enabled by `Extension::configure`
/// once negotiation succeeds. A new compressor is created for every encoded
/// message and a new decoder for every decoded message, so no LZ77 window
/// is carried over from one message to the next.
#[derive(Debug)]
pub struct Deflate {
/// Whether negotiation is done from the client or the server side.
mode: Mode,
/// Set once negotiation succeeded; `is_enabled` reports this flag.
enabled: bool,
/// Scratch buffer reused by `encode`/`decode`; swapped with the payload.
buffer: Vec<u8>,
/// Parameters we send: the offer (client mode) or the response (server mode).
params: Vec<Param<'static>>,
/// Window bits used for our own compressor (9 ..= 15).
our_max_window_bits: u8,
/// Window bits the peer was allowed or announced for its compressor.
their_max_window_bits: u8,
/// True while a fragmented compressed message awaits its final fragment.
await_last_fragment: bool,
}
impl Deflate {
    /// Create a new, not yet enabled, `permessage-deflate` extension for `mode`.
    ///
    /// In client mode the initial offer contains `server_no_context_takeover`,
    /// `client_no_context_takeover` and a value-less `client_max_window_bits`.
    /// In server mode the parameter list starts empty and is filled in by
    /// `Extension::configure` when a client offer is processed.
    pub fn new(mode: Mode) -> Self {
        let params = match mode {
            Mode::Server => Vec::new(),
            Mode::Client => vec![
                Param::new(SERVER_NO_CONTEXT_TAKEOVER),
                Param::new(CLIENT_NO_CONTEXT_TAKEOVER),
                Param::new(CLIENT_MAX_WINDOW_BITS),
            ],
        };
        Deflate {
            mode,
            enabled: false,
            buffer: Vec::new(),
            params,
            our_max_window_bits: 15,
            their_max_window_bits: 15,
            await_last_fragment: false,
        }
    }

    /// Limit the window bits the server may use, and offer this limit as
    /// `server_max_window_bits`.
    ///
    /// Calling this more than once replaces the previously offered value
    /// instead of adding another parameter.
    ///
    /// # Panics
    ///
    /// Panics if not in client mode or if `max` is outside `9 ..= 15`.
    pub fn set_max_server_window_bits(&mut self, max: u8) {
        assert!(self.mode == Mode::Client, "setting max. server window bits requires client mode");
        assert!(max > 8 && max <= 15, "max. server window bits have to be within 9 ..= 15");
        self.their_max_window_bits = max;
        self.upsert_param(SERVER_MAX_WINDOW_BITS, max)
    }

    /// Limit the window bits of our own compressor, and offer this limit as
    /// `client_max_window_bits`.
    ///
    /// Calling this more than once replaces the previously offered value.
    ///
    /// # Panics
    ///
    /// Panics if not in client mode or if `max` is outside `9 ..= 15`.
    pub fn set_max_client_window_bits(&mut self, max: u8) {
        assert!(self.mode == Mode::Client, "setting max. client window bits requires client mode");
        assert!(max > 8 && max <= 15, "max. client window bits have to be within 9 ..= 15");
        self.our_max_window_bits = max;
        self.upsert_param(CLIENT_MAX_WINDOW_BITS, max)
    }

    /// Set the value of parameter `name` to `value`, adding the parameter if
    /// it is not present yet (so the offer never contains duplicates).
    fn upsert_param(&mut self, name: &'static str, value: u8) {
        if let Some(p) = self.params.iter_mut().find(|p| p.name() == name) {
            p.set_value(Some(value.to_string()));
        } else {
            let mut p = Param::new(name);
            p.set_value(Some(value.to_string()));
            self.params.push(p)
        }
    }

    /// Update `their_max_window_bits` from the value of parameter `p`.
    ///
    /// A missing or non-numeric value is ignored (`Ok(())`, nothing changes).
    /// A value outside `8 ..= 15`, or greater than `expected` when one is
    /// given, is rejected with `Err(())`. An accepted value of 8 is stored as 9,
    /// the smallest window this implementation works with.
    fn set_their_max_window_bits(&mut self, p: &Param, expected: Option<u8>) -> Result<(), ()> {
        if let Some(Ok(v)) = p.value().map(|s| s.parse::<u8>()) {
            if !(8..=15).contains(&v) {
                log::debug!("invalid {}: {} (expected range: 8 ..= 15)", p.name(), v);
                return Err(());
            }
            if let Some(x) = expected {
                if v > x {
                    log::debug!("invalid {}: {} (expected: {} <= {})", p.name(), v, v, x);
                    return Err(());
                }
            }
            self.their_max_window_bits = std::cmp::max(9, v);
        }
        Ok(())
    }
}
impl Extension for Deflate {
fn name(&self) -> &str {
"permessage-deflate"
}
fn is_enabled(&self) -> bool {
self.enabled
}
fn params(&self) -> &[Param] {
&self.params
}
fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError> {
match self.mode {
Mode::Server => {
self.params.clear();
for p in params {
log::trace!("configure server with: {}", p);
match p.name() {
CLIENT_MAX_WINDOW_BITS => {
if self.set_their_max_window_bits(&p, None).is_err() {
return Ok(());
}
}
SERVER_MAX_WINDOW_BITS => {
if let Some(Ok(v)) = p.value().map(|s| s.parse::<u8>()) {
if v < 9 || v > 15 {
log::debug!("unacceptable server_max_window_bits: {}", v);
return Ok(());
}
let mut x = Param::new(SERVER_MAX_WINDOW_BITS);
x.set_value(Some(v.to_string()));
self.params.push(x);
self.our_max_window_bits = v;
} else {
log::debug!("invalid server_max_window_bits: {:?}", p.value());
return Ok(());
}
}
CLIENT_NO_CONTEXT_TAKEOVER => self.params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)),
SERVER_NO_CONTEXT_TAKEOVER => self.params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)),
_ => {
log::debug!("{}: unknown parameter: {}", self.name(), p.name());
return Ok(());
}
}
}
}
Mode::Client => {
let mut server_no_context_takeover = false;
for p in params {
log::trace!("configure client with: {}", p);
match p.name() {
SERVER_NO_CONTEXT_TAKEOVER => server_no_context_takeover = true,
CLIENT_NO_CONTEXT_TAKEOVER => {} SERVER_MAX_WINDOW_BITS => {
let expected = Some(self.their_max_window_bits);
if self.set_their_max_window_bits(&p, expected).is_err() {
return Ok(());
}
}
CLIENT_MAX_WINDOW_BITS => {
if let Some(Ok(v)) = p.value().map(|s| s.parse::<u8>()) {
if v < 8 || v > 15 {
log::debug!("unacceptable client_max_window_bits: {}", v);
return Ok(());
}
use std::cmp::{max, min};
self.our_max_window_bits = min(self.our_max_window_bits, max(9, v));
}
}
_ => {
log::debug!("{}: unknown parameter: {}", self.name(), p.name());
return Ok(());
}
}
}
if !server_no_context_takeover {
log::debug!("{}: server did not confirm no context takeover", self.name());
return Ok(());
}
}
}
self.enabled = true;
Ok(())
}
fn reserved_bits(&self) -> (bool, bool, bool) {
(true, false, false)
}
fn decode(&mut self, header: &mut Header, data: &mut Vec<u8>) -> Result<(), BoxedError> {
if data.is_empty() {
return Ok(());
}
match header.opcode() {
OpCode::Binary | OpCode::Text if header.is_rsv1() => {
if !header.is_fin() {
self.await_last_fragment = true;
log::trace!("deflate: not decoding {}; awaiting last fragment", header);
return Ok(());
}
log::trace!("deflate: decoding {}", header)
}
OpCode::Continue if header.is_fin() && self.await_last_fragment => {
self.await_last_fragment = false;
log::trace!("deflate: decoding {}", header)
}
_ => {
log::trace!("deflate: not decoding {}", header);
return Ok(());
}
}
data.extend_from_slice(&[0, 0, 0xFF, 0xFF]); self.buffer.clear();
let mut decoder = DeflateDecoder::new(&mut self.buffer);
decoder.write_all(&data)?;
decoder.finish()?;
mem::swap(data, &mut self.buffer);
header.set_rsv1(false);
header.set_payload_len(data.len());
Ok(())
}
fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> {
if data.as_ref().is_empty() {
return Ok(());
}
if let OpCode::Binary | OpCode::Text = header.opcode() {
log::trace!("deflate: encoding {}", header)
} else {
log::trace!("deflate: not encoding {}", header);
return Ok(());
}
self.buffer.clear();
self.buffer.reserve(data.as_ref().len());
let mut encoder = Compress::new_with_window_bits(Compression::fast(), false, self.our_max_window_bits);
while encoder.total_in() < as_u64(data.as_ref().len()) {
let i: usize = encoder.total_in().try_into()?;
match encoder.compress_vec(&data.as_ref()[i..], &mut self.buffer, FlushCompress::None)? {
Status::BufError => self.buffer.reserve(4096),
Status::Ok => continue,
Status::StreamEnd => break,
}
}
while !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) {
self.buffer.reserve(5); match encoder.compress_vec(&[], &mut self.buffer, FlushCompress::Sync)? {
Status::Ok => continue,
Status::BufError => continue, Status::StreamEnd => break,
}
}
if !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) {
log::error!("missing 00 00 FF FF");
return Err(io::Error::new(io::ErrorKind::Other, "missing 00 00 FF FF").into());
}
self.buffer.truncate(self.buffer.len() - 4); if let Storage::Owned(d) = data {
mem::swap(d, &mut self.buffer)
} else {
*data = Storage::Owned(mem::take(&mut self.buffer))
}
header.set_rsv1(true);
header.set_payload_len(data.as_ref().len());
Ok(())
}
}