use std::cmp::Ordering;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use futures_util::future::FutureExt;
use futures_util::stream::{once, FuturesUnordered, Stream, StreamExt};
use smallvec::SmallVec;
use proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
use proto::Time;
use tracing::debug;
use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
use crate::error::{ResolveError, ResolveErrorKind};
#[cfg(feature = "mdns")]
use crate::name_server;
use crate::name_server::{ConnectionProvider, NameServer};
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
use crate::name_server::{TokioConnection, TokioConnectionProvider, TokioHandle};
/// A pool of [`NameServer`]s grouped by transport.
///
/// The groups are held behind `Arc<[_]>` so that cloning the pool, and the per-request copies
/// taken in `send`, share the same `NameServer` instances rather than duplicating them.
#[derive(Clone)]
pub struct NameServerPool<
    C: DnsHandle<Error = ResolveError> + Send + Sync + 'static,
    P: ConnectionProvider<Conn = C> + Send + 'static,
> {
    /// Servers whose configured protocol reports `is_datagram()`; queried first.
    datagram_conns: Arc<[NameServer<C, P>]>,
    /// Servers whose configured protocol is stream-based; used when the datagram attempt
    /// is truncated or fails in a retryable way.
    stream_conns: Arc<[NameServer<C, P>]>,
    /// Server used for queries that fall within the local (mDNS) zone.
    #[cfg(feature = "mdns")]
    mdns_conns: NameServer<C, P>,
    /// Resolver options copied into every request sent through this pool.
    options: ResolverOpts,
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
impl NameServerPool<TokioConnection, TokioConnectionProvider> {
    /// Test helper: builds a pool from `config` whose connections run on the given Tokio handle.
    pub(crate) fn tokio_from_config(
        config: &ResolverConfig,
        options: &ResolverOpts,
        runtime: TokioHandle,
    ) -> Self {
        let provider = TokioConnectionProvider::new(runtime);
        Self::from_config_with_provider(config, options, provider)
    }
}
impl<C, P> NameServerPool<C, P>
where
C: DnsHandle<Error = ResolveError> + Sync + 'static,
P: ConnectionProvider<Conn = C> + 'static,
{
pub(crate) fn from_config_with_provider(
config: &ResolverConfig,
options: &ResolverOpts,
conn_provider: P,
) -> Self {
let datagram_conns: Vec<NameServer<C, P>> = config
.name_servers()
.iter()
.filter(|ns_config| ns_config.protocol.is_datagram())
.map(|ns_config| {
#[cfg(feature = "dns-over-rustls")]
let ns_config = {
let mut ns_config = ns_config.clone();
ns_config.tls_config = config.client_config().clone();
ns_config
};
#[cfg(not(feature = "dns-over-rustls"))]
let ns_config = { ns_config.clone() };
NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
})
.collect();
let stream_conns: Vec<NameServer<C, P>> = config
.name_servers()
.iter()
.filter(|ns_config| ns_config.protocol.is_stream())
.map(|ns_config| {
#[cfg(feature = "dns-over-rustls")]
let ns_config = {
let mut ns_config = ns_config.clone();
ns_config.tls_config = config.client_config().clone();
ns_config
};
#[cfg(not(feature = "dns-over-rustls"))]
let ns_config = { ns_config.clone() };
NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
})
.collect();
Self {
datagram_conns: Arc::from(datagram_conns),
stream_conns: Arc::from(stream_conns),
#[cfg(feature = "mdns")]
mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
options: *options,
}
}
pub fn from_config(
name_servers: NameServerConfigGroup,
options: &ResolverOpts,
conn_provider: P,
) -> Self {
let map_config_to_ns = |ns_config| {
NameServer::<C, P>::new_with_provider(ns_config, *options, conn_provider.clone())
};
let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
.into_inner()
.into_iter()
.partition(|ns| ns.protocol.is_datagram());
let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
Self {
datagram_conns: Arc::from(datagram_conns),
stream_conns: Arc::from(stream_conns),
#[cfg(feature = "mdns")]
mdns_conns: name_server::mdns_nameserver(*options, conn_provider.clone(), false),
options: *options,
}
}
#[doc(hidden)]
#[cfg(not(feature = "mdns"))]
pub fn from_nameservers(
options: &ResolverOpts,
datagram_conns: Vec<NameServer<C, P>>,
stream_conns: Vec<NameServer<C, P>>,
) -> Self {
Self {
datagram_conns: Arc::from(datagram_conns),
stream_conns: Arc::from(stream_conns),
options: *options,
}
}
#[doc(hidden)]
#[cfg(feature = "mdns")]
pub fn from_nameservers(
options: &ResolverOpts,
datagram_conns: Vec<NameServer<C, P>>,
stream_conns: Vec<NameServer<C, P>>,
mdns_conns: NameServer<C, P>,
) -> Self {
NameServerPool {
datagram_conns: Arc::from(datagram_conns),
stream_conns: Arc::from(stream_conns),
mdns_conns,
options: *options,
}
}
#[cfg(test)]
#[cfg(not(feature = "mdns"))]
#[allow(dead_code)]
fn from_nameservers_test(
options: &ResolverOpts,
datagram_conns: Arc<[NameServer<C, P>]>,
stream_conns: Arc<[NameServer<C, P>]>,
) -> Self {
Self {
datagram_conns,
stream_conns,
options: *options,
}
}
#[cfg(test)]
#[cfg(feature = "mdns")]
fn from_nameservers_test(
options: &ResolverOpts,
datagram_conns: Arc<[NameServer<C, P>]>,
stream_conns: Arc<[NameServer<C, P>]>,
mdns_conns: NameServer<C, P>,
) -> Self {
NameServerPool {
datagram_conns,
stream_conns,
mdns_conns,
options: *options,
conn_provider,
}
}
async fn try_send(
opts: ResolverOpts,
conns: Arc<[NameServer<C, P>]>,
request: DnsRequest,
) -> Result<DnsResponse, ResolveError> {
let mut conns: Vec<NameServer<C, P>> = conns.to_vec();
match opts.server_ordering_strategy {
ServerOrderingStrategy::QueryStatistics => conns.sort_unstable(),
ServerOrderingStrategy::UserProvidedOrder => {}
}
let request_loop = request.clone();
parallel_conn_loop(conns, request_loop, opts).await
}
}
impl<C, P> DnsHandle for NameServerPool<C, P>
where
    C: DnsHandle<Error = ResolveError> + Sync + 'static,
    P: ConnectionProvider<Conn = C> + 'static,
{
    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>;
    type Error = ResolveError;

    /// Dispatches a request through the pool.
    ///
    /// With the `mdns` feature, queries in the local zone go straight to the mDNS server.
    /// Every other request is sent to the datagram servers first. If that answer is truncated,
    /// or it fails while `try_tcp_on_error` is set or no connections were available, the
    /// request is repeated on the stream servers. A successful stream answer wins. Otherwise a
    /// truncated datagram answer is returned. Failing both, the more specific of the two
    /// errors is returned.
    fn send<R: Into<DnsRequest>>(&mut self, request: R) -> Self::Response {
        let request: DnsRequest = request.into();
        let opts = self.options;
        let udp_conns = Arc::clone(&self.datagram_conns);
        let tcp_conns = Arc::clone(&self.stream_conns);

        #[cfg(feature = "mdns")]
        let local = mdns::maybe_local(&mut self.mdns_conns, request);
        #[cfg(not(feature = "mdns"))]
        let local = Local::NotMdns(request);

        if local.is_local() {
            return local.take_stream();
        }

        let udp_request = local.take_request();
        // The stream fallback needs its own copy of the request.
        let tcp_request = udp_request.clone();

        Box::pin(once(async move {
            debug!("sending request: {:?}", udp_request.queries());

            let udp_res = match Self::try_send(opts, udp_conns, udp_request).await {
                Ok(response) if response.truncated() => {
                    debug!("truncated response received, retrying over TCP");
                    Ok(response)
                }
                Err(e) if opts.try_tcp_on_error || e.is_no_connections() => {
                    debug!("error from UDP, retrying over TCP: {}", e);
                    Err(e)
                }
                settled => return settled,
            };

            if tcp_conns.is_empty() {
                debug!("no TCP connections available");
                return udp_res;
            }

            let tcp_err = match Self::try_send(opts, tcp_conns, tcp_request).await {
                Ok(response) => return Ok(response),
                Err(e) => e,
            };

            // A truncated UDP answer is still preferable to a TCP failure.
            let udp_err = match udp_res {
                Ok(response) => return Ok(response),
                Err(e) => e,
            };

            // Surface the more specific error; ties go to the TCP error.
            if udp_err.cmp_specificity(&tcp_err) == Ordering::Greater {
                Err(udp_err)
            } else {
                Err(tcp_err)
            }
        }))
    }
}
async fn parallel_conn_loop<C, P>(
mut conns: Vec<NameServer<C, P>>,
request: DnsRequest,
opts: ResolverOpts,
) -> Result<DnsResponse, ResolveError>
where
C: DnsHandle<Error = ResolveError> + 'static,
P: ConnectionProvider<Conn = C> + 'static,
{
let mut err = ResolveError::no_connections();
let mut backoff = Duration::from_millis(20);
let mut busy = SmallVec::<[NameServer<C, P>; 2]>::new();
loop {
let request_cont = request.clone();
let mut par_conns = SmallVec::<[NameServer<C, P>; 2]>::new();
let count = conns.len().min(opts.num_concurrent_reqs.max(1));
for conn in conns.drain(..count) {
par_conns.push(conn);
}
if par_conns.is_empty() {
if !busy.is_empty() && backoff < Duration::from_millis(300) {
P::Time::delay_for(backoff).await;
conns.extend(busy.drain(..));
backoff *= 2;
continue;
}
return Err(err);
}
let mut requests = par_conns
.into_iter()
.map(move |mut conn| {
conn.send(request_cont.clone())
.first_answer()
.map(|result| result.map_err(|e| (conn, e)))
})
.collect::<FuturesUnordered<_>>();
while let Some(result) = requests.next().await {
let (conn, e) = match result {
Ok(sent) => return Ok(sent),
Err((conn, e)) => (conn, e),
};
match e.kind() {
ResolveErrorKind::NoRecordsFound { trusted, .. } if *trusted => {
return Err(e);
}
ResolveErrorKind::Proto(e) if e.is_busy() => {
busy.push(conn);
}
_ if err.cmp_specificity(&e) == Ordering::Less => {
err = e;
}
_ => {}
}
}
}
}
#[cfg(feature = "mdns")]
mod mdns {
    use super::*;

    use proto::rr::domain::usage;
    use proto::DnsHandle;

    /// Routes `request` to the mDNS `name_server` if any of its queries fall within the
    /// `usage::LOCAL` zone. In that case the returned [`Local::ResolveStream`] holds the
    /// server's response stream. Otherwise the request is handed back unchanged as
    /// [`Local::NotMdns`].
    pub(crate) fn maybe_local<C, P>(
        name_server: &mut NameServer<C, P>,
        request: DnsRequest,
    ) -> Local
    where
        C: DnsHandle<Error = ResolveError> + 'static,
        // `ConnectionProvider<Conn = C>` already implies `ConnectionProvider`; no second bound.
        P: ConnectionProvider<Conn = C> + 'static,
    {
        let targets_local_zone = request
            .queries()
            .iter()
            .any(|query| usage::LOCAL.name().zone_of(query.name()));

        if targets_local_zone {
            Local::ResolveStream(name_server.send(request))
        } else {
            Local::NotMdns(request)
        }
    }
}
/// Result of checking whether a request should be answered through mDNS.
pub(crate) enum Local {
    /// The request was sent to the mDNS name server; holds that server's response stream.
    // Only constructed in the `mdns` module, so it is never built without the `mdns` feature.
    #[allow(dead_code)]
    ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>>),
    /// The request is not for mDNS and must be sent through the regular name servers.
    NotMdns(DnsRequest),
}
impl Local {
    /// Returns `true` if the request was dispatched to the mDNS name server.
    fn is_local(&self) -> bool {
        matches!(*self, Self::ResolveStream(..))
    }

    /// Takes the mDNS response stream.
    ///
    /// # Panics
    ///
    /// Panics on `Local::NotMdns`; check [`Local::is_local`] first.
    fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ResolveError>> + Send>> {
        match self {
            Self::ResolveStream(stream) => stream,
            Self::NotMdns(..) => panic!("non Local queries have no stream, see take_request()"),
        }
    }

    /// Takes back the request that was not handled by mDNS.
    ///
    /// # Panics
    ///
    /// Panics on `Local::ResolveStream`; check [`Local::is_local`] first.
    fn take_request(self) -> DnsRequest {
        match self {
            Self::NotMdns(request) => request,
            Self::ResolveStream(..) => panic!("Local queries must be polled, see take_stream()"),
        }
    }
}
impl Stream for Local {
    type Item = Result<DnsResponse, ResolveError>;

    /// Forwards polling to the wrapped mDNS response stream.
    ///
    /// # Panics
    ///
    /// Panics when polled as `Local::NotMdns`, which holds a request rather than a stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Local` is `Unpin` (its fields are a boxed pinned stream and a request), so a
        // plain mutable borrow is fine here.
        match self.get_mut() {
            Self::ResolveStream(stream) => stream.as_mut().poll_next(cx),
            Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"),
        }
    }
}
#[cfg(test)]
#[cfg(feature = "tokio-runtime")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::str::FromStr;

    use tokio::runtime::Runtime;

    use proto::op::Query;
    use proto::rr::{Name, RecordType};
    use proto::xfer::{DnsHandle, DnsRequestOptions};
    use trust_dns_proto::rr::RData;

    use super::*;
    use crate::config::NameServerConfig;
    use crate::config::Protocol;

    /// Builds a name server config with no TLS name, no TLS config and no bind address.
    fn plain_ns_config(socket_addr: SocketAddr, protocol: Protocol) -> NameServerConfig {
        NameServerConfig {
            socket_addr,
            protocol,
            tls_dns_name: None,
            trust_nx_responses: false,
            #[cfg(feature = "dns-over-rustls")]
            tls_config: None,
            bind_addr: None,
        }
    }

    #[ignore]
    #[test]
    fn test_failed_then_success_pool() {
        let config1 = plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
            Protocol::Udp,
        );
        let config2 = plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            Protocol::Udp,
        );

        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(config1);
        resolver_config.add_name_server(config2);

        let io_loop = Runtime::new().unwrap();
        let mut pool = NameServerPool::<_, TokioConnectionProvider>::tokio_from_config(
            &resolver_config,
            &ResolverOpts::default(),
            TokioHandle,
        );

        let name = Name::parse("www.example.com.", None).unwrap();

        for i in 0..2 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_err(),
                "iter: {}",
                i
            );
        }

        for i in 0..10 {
            assert!(
                io_loop
                    .block_on(
                        pool.lookup(
                            Query::query(name.clone(), RecordType::A),
                            DnsRequestOptions::default()
                        )
                        .first_answer()
                    )
                    .is_ok(),
                "iter: {}",
                i
            );
        }
    }

    #[test]
    fn test_multi_use_conns() {
        let io_loop = Runtime::new().unwrap();
        let conn_provider = TokioConnectionProvider::new(TokioHandle);

        let tcp = plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            Protocol::Tcp,
        );

        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };
        let name_server = NameServer::new_with_provider(tcp, opts, conn_provider);
        let name_servers: Arc<[_]> = Arc::from([name_server]);

        let mut pool = NameServerPool::from_nameservers_test(
            &opts,
            Arc::from([]),
            Arc::clone(&name_servers),
            // Same three-argument form (options, provider, trust_nx_responses) as the pool
            // constructors use.
            #[cfg(feature = "mdns")]
            name_server::mdns_nameserver(opts, TokioConnectionProvider::new(TokioHandle), false),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_a)
                .expect("no a record available"),
            Ipv4Addr::new(93, 184, 216, 34)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );

        let response = io_loop
            .block_on(
                pool.lookup(
                    Query::query(name, RecordType::AAAA),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
            .expect("lookup failed");

        assert_eq!(
            *response.answers()[0]
                .data()
                .and_then(RData::as_aaaa)
                .expect("no aaaa record available"),
            Ipv6Addr::new(0x2606, 0x2800, 0x0220, 0x0001, 0x0248, 0x1893, 0x25c8, 0x1946)
        );

        assert!(
            name_servers[0].is_connected(),
            "if this is failing then the NameServers aren't being properly shared."
        );
    }
}