#![deny(missing_docs)]
use std::{fmt, io, sync};
#[cfg(not(feature = "openssl"))]
use self::not_openssl::SslErrorStack;
#[cfg(not(feature = "ring"))]
use self::not_ring::Unspecified;
#[cfg(feature = "backtrace")]
#[cfg_attr(docsrs, doc(cfg(feature = "backtrace")))]
pub use backtrace::Backtrace as ExtBacktrace;
use enum_as_inner::EnumAsInner;
#[cfg(feature = "backtrace")]
use lazy_static::lazy_static;
#[cfg(feature = "openssl")]
use openssl::error::ErrorStack as SslErrorStack;
#[cfg(feature = "ring")]
use ring::error::Unspecified;
use thiserror::Error;
use crate::op::Header;
use crate::rr::{Name, RecordType};
use crate::serialize::binary::DecodeError;
#[cfg(feature = "backtrace")]
#[cfg_attr(docsrs, doc(cfg(feature = "backtrace")))]
lazy_static! {
    /// Whether error construction should capture a backtrace.
    ///
    /// `true` only when the `RUST_BACKTRACE` environment variable is exactly
    /// `1` or `full`. An unset or non-UTF-8 variable, or any other value,
    /// yields `false`. Through `lazy_static!`, the value is computed on first
    /// access and then reused.
    pub static ref ENABLE_BACKTRACE: bool = {
        match std::env::var("RUST_BACKTRACE") {
            Ok(ref setting) => setting == "full" || setting == "1",
            Err(_) => false,
        }
    };
}
/// Evaluates to `Some(backtrace)` when [`ENABLE_BACKTRACE`] is set, and to
/// `None` otherwise.
#[cfg(feature = "backtrace")]
#[cfg_attr(docsrs, doc(cfg(feature = "backtrace")))]
#[macro_export]
macro_rules! trace {
    () => {{
        match *$crate::error::ENABLE_BACKTRACE {
            true => Some($crate::error::ExtBacktrace::new()),
            false => None,
        }
    }};
}
/// A `Result` whose error type is [`ProtoError`].
pub type ProtoResult<T> = std::result::Result<T, ProtoError>;
/// The specific cause carried by a [`ProtoError`].
#[derive(Debug, EnumAsInner, Error)]
#[non_exhaustive]
pub enum ProtoErrorKind {
    /// A request did not contain exactly one query; carries the count found.
    #[error("there should only be one query per request, got: {0}")]
    BadQueryCount(usize),

    /// The resource was too busy to service the request.
    #[error("resource too busy")]
    Busy,

    /// A oneshot channel future was canceled.
    #[error("future was canceled: {0:?}")]
    Canceled(futures_channel::oneshot::Canceled),

    /// Character data was longer than permitted.
    #[error("char data length exceeds {max}: {len}")]
    CharacterDataTooLong {
        /// The maximum permitted length.
        max: usize,
        /// The length that was encountered.
        len: usize,
    },

    /// Two labels overlapped while decoding a name.
    #[error("overlapping labels name {label} other {other}")]
    LabelOverlapsWithOther {
        /// The label that overlaps.
        label: usize,
        /// The label it overlaps with.
        other: usize,
    },

    /// A DNS key protocol value other than 3 was found; carries that value.
    #[error("dns key value unknown, must be 3: {0}")]
    DnsKeyProtocolNot3(u8),

    /// Name label data exceeded 255; carries the length found.
    #[error("name label data exceed 255: {0}")]
    DomainNameTooLong(usize),

    /// An EDNS resource record used a name other than the root label (`.`).
    #[error("edns resource record label must be the root label (.): {0}")]
    EdnsNameNotRoot(crate::rr::Name),

    /// A message was malformed.
    #[error("message format error: {error}")]
    FormError {
        /// The header of the offending message.
        header: Header,
        /// The underlying error.
        error: Box<ProtoError>,
    },

    /// An HMAC failed validation.
    #[error("hmac validation failure")]
    HmacInvalid(),

    /// The number of rdata bytes read did not match the expected length.
    #[error("incorrect rdata length read: {read} expected: {len}")]
    IncorrectRDataLengthRead {
        /// The number of bytes actually read.
        read: usize,
        /// The number of bytes expected.
        len: usize,
    },

    /// A single label exceeded 63 bytes; carries the length found.
    #[error("label bytes exceed 63: {0}")]
    LabelBytesTooLong(usize),

    /// A compression pointer did not point to data before the label.
    #[error("label points to data not prior to idx: {idx} ptr: {ptr}")]
    PointerNotPriorToLabel {
        /// The index of the label being read.
        idx: usize,
        /// The pointer value that was found.
        ptr: u16,
    },

    /// A buffer grew beyond its maximum size; carries that size.
    #[error("maximum buffer size exceeded: {0}")]
    MaxBufferSizeExceeded(usize),

    /// A free-form error message with a static lifetime.
    #[error("{0}")]
    Message(&'static str),

    /// A free-form, owned error message.
    #[error("{0}")]
    Msg(String),

    /// No error was specified.
    #[error("no error specified")]
    NoError,

    /// Only some of the records could be written.
    #[error("not all records could be written, wrote: {count}")]
    NotAllRecordsWritten {
        /// The number of records that were written.
        count: usize,
    },

    /// No RRSIGs were present for a record set.
    #[error("rrsigs are not present for record set name: {name} record_type: {record_type}")]
    RrsigsNotPresent {
        /// The name of the record set.
        name: Name,
        /// The type of the record set.
        record_type: RecordType,
    },

    /// An algorithm type value was not recognized.
    #[error("algorithm type value unknown: {0}")]
    UnknownAlgorithmTypeValue(u8),

    /// A DNS class string was not recognized.
    #[error("dns class string unknown: {0}")]
    UnknownDnsClassStr(String),

    /// A DNS class value was not recognized.
    #[error("dns class value unknown: {0}")]
    UnknownDnsClassValue(u16),

    /// A record type string was not recognized.
    #[error("record type string unknown: {0}")]
    UnknownRecordTypeStr(String),

    /// A record type value was not recognized.
    #[error("record type value unknown: {0}")]
    UnknownRecordTypeValue(u16),

    /// A label code was not recognized; displayed in binary.
    #[error("unrecognized label code: {0:b}")]
    UnrecognizedLabelCode(u8),

    /// NSEC3 flags had bits set other than the lowest; displayed in binary.
    #[error("nsec3 flags should be 0b0000000*: {0:b}")]
    UnrecognizedNsec3Flags(u8),

    /// CSYNC flags had bits set other than the lowest two; displayed in binary.
    #[error("csync flags should be 0b000000**: {0:b}")]
    UnrecognizedCsyncFlags(u16),

    /// An I/O error.
    #[error("io error: {0}")]
    Io(io::Error),

    /// A lock was poisoned.
    #[error("lock poisoned error")]
    Poisoned,

    /// An error from `ring` (or its placeholder when the feature is disabled).
    #[error("ring error: {0}")]
    Ring(#[from] Unspecified),

    /// An error from `openssl` (or its placeholder when the feature is disabled).
    #[error("ssl error: {0}")]
    SSL(#[from] SslErrorStack),

    /// A timer error.
    #[error("timer error")]
    Timer,

    /// A request timed out.
    #[error("request timed out")]
    Timeout,

    /// A URL failed to parse.
    #[error("url parsing error")]
    UrlParsing(#[from] url::ParseError),

    /// A byte slice was not valid UTF-8.
    #[error("error parsing utf8 string")]
    Utf8(#[from] std::str::Utf8Error),

    /// A byte vector was not valid UTF-8.
    #[error("error parsing utf8 string")]
    FromUtf8(#[from] std::string::FromUtf8Error),

    /// An integer failed to parse.
    #[error("error parsing int")]
    ParseInt(#[from] std::num::ParseIntError),

    /// A QUIC connection could not be created.
    #[cfg(feature = "quinn")]
    #[error("error creating quic connection: {0}")]
    QuinnConnect(#[from] quinn::ConnectError),

    /// An established QUIC connection failed.
    #[cfg(feature = "quinn")]
    #[error("error with quic connection: {0}")]
    QuinnConnection(#[from] quinn::ConnectionError),

    /// Writing to a QUIC stream failed.
    #[cfg(feature = "quinn")]
    #[error("error writing to quic connection: {0}")]
    QuinnWriteError(#[from] quinn::WriteError),

    /// Reading an exact amount from a QUIC stream failed.
    #[cfg(feature = "quinn")]
    #[error("error reading from quic stream: {0}")]
    QuinnReadError(#[from] quinn::ReadExactError),

    /// The QUIC configuration could not be constructed.
    #[cfg(feature = "quinn")]
    #[error("error constructing quic configuration: {0}")]
    QuinnConfigError(#[from] quinn::ConfigError),

    /// An unknown QUIC stream was used.
    #[cfg(feature = "quinn")]
    #[error("an unknown quic stream was used")]
    QuinnUnknownStreamError,

    /// A QUIC DNS message had a non-zero message id; carries the id found.
    #[cfg(feature = "quinn")]
    #[error("quic messages should always be 0, got: {0}")]
    QuicMessageIdNot0(u16),

    /// An error from `rustls`.
    #[cfg(feature = "rustls")]
    #[error("rustls construction error: {0}")]
    RustlsError(#[from] rustls::Error),
}
#[derive(Error, Clone, Debug)]
#[non_exhaustive]
pub struct ProtoError {
pub kind: Box<ProtoErrorKind>,
#[cfg(feature = "backtrace")]
pub backtrack: Option<ExtBacktrace>,
}
impl ProtoError {
    /// Borrows the specific cause of this error.
    pub fn kind(&self) -> &ProtoErrorKind {
        &*self.kind
    }

    /// Returns `true` when the cause is [`ProtoErrorKind::Busy`].
    pub fn is_busy(&self) -> bool {
        matches!(self.kind(), ProtoErrorKind::Busy)
    }
}
impl fmt::Display for ProtoError {
    // Writes the kind's message. When the `backtrace` feature is enabled and a
    // backtrace was captured, the Debug form of that backtrace follows on the
    // next line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        cfg_if::cfg_if! {
            if #[cfg(feature = "backtrace")] {
                if let Some(ref backtrace) = self.backtrack {
                    fmt::Display::fmt(&self.kind, f)?;
                    // Without a separator, the first backtrace frame would be
                    // glued to the end of the error message.
                    f.write_str("\n")?;
                    fmt::Debug::fmt(backtrace, f)
                } else {
                    fmt::Display::fmt(&self.kind, f)
                }
            } else {
                fmt::Display::fmt(&self.kind, f)
            }
        }
    }
}
impl<E> From<E> for ProtoError
where
E: Into<ProtoErrorKind>,
{
fn from(error: E) -> Self {
let kind: ProtoErrorKind = error.into();
Self {
kind: Box::new(kind),
#[cfg(feature = "backtrace")]
backtrack: trace!(),
}
}
}
/// Maps decoding failures onto their dedicated [`ProtoErrorKind`] variants.
/// Any other decode error becomes a [`ProtoErrorKind::Msg`] holding its text.
impl From<DecodeError> for ProtoError {
    fn from(err: DecodeError) -> Self {
        let kind = match err {
            DecodeError::PointerNotPriorToLabel { idx, ptr } => {
                ProtoErrorKind::PointerNotPriorToLabel { idx, ptr }
            }
            DecodeError::LabelBytesTooLong(len) => ProtoErrorKind::LabelBytesTooLong(len),
            DecodeError::UnrecognizedLabelCode(code) => ProtoErrorKind::UnrecognizedLabelCode(code),
            DecodeError::DomainNameTooLong(len) => ProtoErrorKind::DomainNameTooLong(len),
            DecodeError::LabelOverlapsWithOther { label, other } => {
                ProtoErrorKind::LabelOverlapsWithOther { label, other }
            }
            other => ProtoErrorKind::Msg(other.to_string()),
        };
        Self::from(kind)
    }
}
/// Wraps a static message as a [`ProtoErrorKind::Message`].
impl From<&'static str> for ProtoError {
    fn from(msg: &'static str) -> Self {
        Self::from(ProtoErrorKind::Message(msg))
    }
}
/// Wraps an owned message as a [`ProtoErrorKind::Msg`].
impl From<String> for ProtoError {
    fn from(msg: String) -> Self {
        Self::from(ProtoErrorKind::Msg(msg))
    }
}
/// Converts an I/O error into an error kind. I/O errors whose kind is
/// `TimedOut` become [`ProtoErrorKind::Timeout`]; all others are wrapped in
/// [`ProtoErrorKind::Io`].
impl From<io::Error> for ProtoErrorKind {
    fn from(e: io::Error) -> Self {
        if e.kind() == io::ErrorKind::TimedOut {
            Self::Timeout
        } else {
            Self::Io(e)
        }
    }
}
/// Maps any poisoned-lock error to [`ProtoErrorKind::Poisoned`], discarding
/// the guard it carries.
impl<T> From<sync::PoisonError<T>> for ProtoError {
    fn from(_: sync::PoisonError<T>) -> Self {
        Self::from(ProtoErrorKind::Poisoned)
    }
}
/// Stand-in types used when the `openssl` feature is disabled.
#[cfg(not(feature = "openssl"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "openssl"))))]
pub mod not_openssl {
    use std::{error, fmt};

    // Shared by `Display` and the legacy `description`, so both report the
    // same text.
    const MESSAGE: &str = "openssl feature not enabled";

    /// Placeholder for the openssl error stack type.
    #[derive(Clone, Copy, Debug)]
    pub struct SslErrorStack;

    impl fmt::Display for SslErrorStack {
        // Previously this wrote nothing, which left messages such as
        // "ssl error: " with no explanation.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(MESSAGE)
        }
    }

    impl error::Error for SslErrorStack {
        fn description(&self) -> &str {
            MESSAGE
        }
    }
}
/// Stand-in types used when the `ring` feature is disabled.
#[cfg(not(feature = "ring"))]
#[cfg_attr(docsrs, doc(cfg(not(feature = "ring"))))]
pub mod not_ring {
    use std::{error, fmt};

    // Shared by `Display` and the legacy `description`, so both report the
    // same text.
    const MESSAGE: &str = "ring feature not enabled";

    /// Placeholder for the ring `Unspecified` error type.
    #[derive(Clone, Copy, Debug)]
    pub struct Unspecified;

    impl fmt::Display for Unspecified {
        // Previously this wrote nothing, which left messages such as
        // "ring error: " with no explanation.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(MESSAGE)
        }
    }

    impl error::Error for Unspecified {
        fn description(&self) -> &str {
            MESSAGE
        }
    }
}
/// Converts into an `io::Error`. A [`ProtoErrorKind::Timeout`] becomes
/// `TimedOut`; everything else becomes `Other`. The `ProtoError` is kept as
/// the payload.
impl From<ProtoError> for io::Error {
    fn from(e: ProtoError) -> Self {
        let io_kind = if matches!(*e.kind, ProtoErrorKind::Timeout) {
            io::ErrorKind::TimedOut
        } else {
            io::ErrorKind::Other
        };
        Self::new(io_kind, e)
    }
}
/// Renders the error through its `Display` implementation.
impl From<ProtoError> for String {
    fn from(e: ProtoError) -> Self {
        ToString::to_string(&e)
    }
}
/// Converts into a JavaScript `Error` whose message is this error's
/// `Display` text.
#[cfg(feature = "wasm-bindgen")]
#[cfg_attr(docsrs, doc(cfg(feature = "wasm-bindgen")))]
impl From<ProtoError> for wasm_bindgen_crate::JsValue {
    fn from(e: ProtoError) -> Self {
        let message = e.to_string();
        js_sys::Error::new(&message).into()
    }
}
impl Clone for ProtoErrorKind {
    // Most variants are copied field by field. A few wrapped errors cannot be
    // cloned directly, so those are rebuilt:
    // * `Io` keeps the raw OS error code when there is one. Otherwise it keeps
    //   the `io::ErrorKind` and the original error's message text.
    // * `Ring` is rebuilt as `Ring(Unspecified)`.
    // * `SSL` becomes a `Msg` carrying the formatted SSL error text.
    fn clone(&self) -> Self {
        use self::ProtoErrorKind::*;
        match *self {
            BadQueryCount(count) => BadQueryCount(count),
            Busy => Busy,
            Canceled(ref c) => Canceled(*c),
            CharacterDataTooLong { max, len } => CharacterDataTooLong { max, len },
            LabelOverlapsWithOther { label, other } => LabelOverlapsWithOther { label, other },
            DnsKeyProtocolNot3(protocol) => DnsKeyProtocolNot3(protocol),
            DomainNameTooLong(len) => DomainNameTooLong(len),
            EdnsNameNotRoot(ref found) => EdnsNameNotRoot(found.clone()),
            FormError { header, ref error } => FormError {
                header,
                error: error.clone(),
            },
            HmacInvalid() => HmacInvalid(),
            IncorrectRDataLengthRead { read, len } => IncorrectRDataLengthRead { read, len },
            LabelBytesTooLong(len) => LabelBytesTooLong(len),
            PointerNotPriorToLabel { idx, ptr } => PointerNotPriorToLabel { idx, ptr },
            MaxBufferSizeExceeded(max) => MaxBufferSizeExceeded(max),
            Message(msg) => Message(msg),
            Msg(ref msg) => Msg(msg.clone()),
            NoError => NoError,
            NotAllRecordsWritten { count } => NotAllRecordsWritten { count },
            RrsigsNotPresent {
                ref name,
                ref record_type,
            } => RrsigsNotPresent {
                name: name.clone(),
                record_type: *record_type,
            },
            UnknownAlgorithmTypeValue(value) => UnknownAlgorithmTypeValue(value),
            UnknownDnsClassStr(ref value) => UnknownDnsClassStr(value.clone()),
            UnknownDnsClassValue(value) => UnknownDnsClassValue(value),
            UnknownRecordTypeStr(ref value) => UnknownRecordTypeStr(value.clone()),
            UnknownRecordTypeValue(value) => UnknownRecordTypeValue(value),
            UnrecognizedLabelCode(value) => UnrecognizedLabelCode(value),
            UnrecognizedNsec3Flags(flags) => UnrecognizedNsec3Flags(flags),
            UnrecognizedCsyncFlags(flags) => UnrecognizedCsyncFlags(flags),
            Io(ref e) => Io(match e.raw_os_error() {
                Some(raw) => io::Error::from_raw_os_error(raw),
                // `io::Error::from(kind)` would drop any custom message or
                // payload. Keep the kind and the full message text instead.
                None => io::Error::new(e.kind(), e.to_string()),
            }),
            Poisoned => Poisoned,
            Ring(ref _e) => Ring(Unspecified),
            SSL(ref e) => Msg(format!("there was an SSL error: {}", e)),
            Timeout => Timeout,
            Timer => Timer,
            UrlParsing(ref e) => UrlParsing(*e),
            Utf8(ref e) => Utf8(*e),
            FromUtf8(ref e) => FromUtf8(e.clone()),
            ParseInt(ref e) => ParseInt(e.clone()),
            #[cfg(feature = "quinn")]
            QuinnConnect(ref e) => QuinnConnect(e.clone()),
            #[cfg(feature = "quinn")]
            QuinnConnection(ref e) => QuinnConnection(e.clone()),
            #[cfg(feature = "quinn")]
            QuinnWriteError(ref e) => QuinnWriteError(e.clone()),
            #[cfg(feature = "quinn")]
            QuicMessageIdNot0(val) => QuicMessageIdNot0(val),
            #[cfg(feature = "quinn")]
            QuinnReadError(ref e) => QuinnReadError(e.clone()),
            #[cfg(feature = "quinn")]
            QuinnConfigError(ref e) => QuinnConfigError(e.clone()),
            #[cfg(feature = "quinn")]
            QuinnUnknownStreamError => QuinnUnknownStreamError,
            #[cfg(feature = "rustls")]
            RustlsError(ref e) => RustlsError(e.clone()),
        }
    }
}
pub trait FromProtoError: From<ProtoError> + std::error::Error + Clone {}
impl<E> FromProtoError for E where E: From<ProtoError> + std::error::Error + Clone {}