123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
- use crate::config::get_config;
- use crate::errors::Error;
- use arc_swap::ArcSwap;
- use log::{debug, error, info, warn};
- use once_cell::sync::Lazy;
- use std::collections::{HashMap, HashSet};
- use std::io;
- use std::net::IpAddr;
- use std::sync::Arc;
- use std::sync::RwLock;
- use tokio::time::{sleep, Duration};
- use trust_dns_resolver::error::{ResolveError, ResolveResult};
- use trust_dns_resolver::lookup_ip::LookupIp;
- use trust_dns_resolver::TokioAsyncResolver;
- /// Cached Resolver Globally available
- pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
- Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
/// The set of IP addresses a hostname resolved to.
///
/// Addresses are kept in a `HashSet` so that two resolutions can be compared
/// for equality regardless of the order in which the addresses were returned.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct AddrSet {
    // Distinct addresses obtained from a lookup.
    set: HashSet<IpAddr>,
}

impl AddrSet {
    /// Creates an empty `AddrSet`.
    fn new() -> AddrSet {
        AddrSet::default()
    }
}
- impl From<LookupIp> for AddrSet {
- fn from(lookup_ip: LookupIp) -> Self {
- let mut addr_set = AddrSet::new();
- for address in lookup_ip.iter() {
- addr_set.set.insert(address);
- }
- addr_set
- }
- }
- ///
- /// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
- ///
- /// The system works as follows:
- ///
- /// When a host is to be resolved, if we have not resolved it before, a new resolution is
- /// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
- /// cache is refreshed.
- ///
- /// # Example:
- ///
- /// ```
- /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
- ///
- /// # tokio_test::block_on(async {
- /// let config = CachedResolverConfig::default();
- /// let resolver = CachedResolver::new(config, None).await.unwrap();
- /// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
- /// # })
- /// ```
- ///
- /// // Now the ip resolution is stored in local cache and subsequent
- /// // calls will be returned from cache. Also, the cache is refreshed
- /// // and updated every 10 seconds.
- ///
- /// // You can now check if an 'old' lookup differs from what it's currently
- /// // store in cache by using `has_changed`.
- /// resolver.has_changed("www.example.com.", addrset)
- #[derive(Default)]
- pub struct CachedResolver {
- // The configuration of the cached_resolver.
- config: CachedResolverConfig,
- // This is the hash that contains the hash.
- data: Option<RwLock<HashMap<String, AddrSet>>>,
- // The resolver to be used for DNS queries.
- resolver: Option<TokioAsyncResolver>,
- // The RefreshLoop
- refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
- }
///
/// Configuration of a `CachedResolver`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CachedResolverConfig {
    /// Amount of time in seconds that a resolved dns address is considered
    /// fresh; it is used as the interval of the refresh loop.
    dns_max_ttl: u64,

    /// Enabled or disabled? (this is so we can reload config)
    enabled: bool,
}

impl CachedResolverConfig {
    /// Builds a configuration from its two settings.
    fn new(dns_max_ttl: u64, enabled: bool) -> Self {
        Self {
            dns_max_ttl,
            enabled,
        }
    }
}
- impl From<crate::config::Config> for CachedResolverConfig {
- fn from(config: crate::config::Config) -> Self {
- CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
- }
- }
- impl CachedResolver {
- ///
- /// Returns a new Arc<CachedResolver> based on passed configuration.
- /// It also starts the loop that will refresh cache entries.
- ///
- /// # Arguments:
- ///
- /// * `config` - The `CachedResolverConfig` to be used to create the resolver.
- ///
- /// # Example:
- ///
- /// ```
- /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
- ///
- /// # tokio_test::block_on(async {
- /// let config = CachedResolverConfig::default();
- /// let resolver = CachedResolver::new(config, None).await.unwrap();
- /// # })
- /// ```
- ///
- pub async fn new(
- config: CachedResolverConfig,
- data: Option<HashMap<String, AddrSet>>,
- ) -> Result<Arc<Self>, io::Error> {
- // Construct a new Resolver with default configuration options
- let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
- let data = if let Some(hash) = data {
- Some(RwLock::new(hash))
- } else {
- Some(RwLock::new(HashMap::new()))
- };
- let instance = Arc::new(Self {
- config,
- resolver,
- data,
- refresh_loop: RwLock::new(None),
- });
- if instance.enabled() {
- info!("Scheduling DNS refresh loop");
- let refresh_loop = tokio::task::spawn({
- let instance = instance.clone();
- async move {
- instance.refresh_dns_entries_loop().await;
- }
- });
- *(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
- }
- Ok(instance)
- }
- pub fn enabled(&self) -> bool {
- self.config.enabled
- }
- // Schedules the refresher
- async fn refresh_dns_entries_loop(&self) {
- let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
- let interval = Duration::from_secs(self.config.dns_max_ttl);
- loop {
- debug!("Begin refreshing cached DNS addresses.");
- // To minimize the time we hold the lock, we first create
- // an array with keys.
- let mut hostnames: Vec<String> = Vec::new();
- {
- if let Some(ref data) = self.data {
- for hostname in data.read().unwrap().keys() {
- hostnames.push(hostname.clone());
- }
- }
- }
- for hostname in hostnames.iter() {
- let addrset = self
- .fetch_from_cache(hostname.as_str())
- .expect("Could not obtain expected address from cache, this should not happen");
- match resolver.lookup_ip(hostname).await {
- Ok(lookup_ip) => {
- let new_addrset = AddrSet::from(lookup_ip);
- debug!(
- "Obtained address for host ({}) -> ({:?})",
- hostname, new_addrset
- );
- if addrset != new_addrset {
- debug!(
- "Addr changed from {:?} to {:?} updating cache.",
- addrset, new_addrset
- );
- self.store_in_cache(hostname, new_addrset);
- }
- }
- Err(err) => {
- error!(
- "There was an error trying to resolv {}: ({}).",
- hostname, err
- );
- }
- }
- }
- debug!("Finished refreshing cached DNS addresses.");
- sleep(interval).await;
- }
- }
- /// Returns a `AddrSet` given the specified hostname.
- ///
- /// This method first tries to fetch the value from the cache, if it misses
- /// then it is resolved and stored in the cache. TTL from records is ignored.
- ///
- /// # Arguments
- ///
- /// * `host` - A string slice referencing the hostname to be resolved.
- ///
- /// # Example:
- ///
- /// ```
- /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
- ///
- /// # tokio_test::block_on(async {
- /// let config = CachedResolverConfig::default();
- /// let resolver = CachedResolver::new(config, None).await.unwrap();
- /// let response = resolver.lookup_ip("www.google.com.");
- /// # })
- /// ```
- ///
- pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
- debug!("Lookup up {} in cache", host);
- match self.fetch_from_cache(host) {
- Some(addr_set) => {
- debug!("Cache hit!");
- Ok(addr_set)
- }
- None => {
- debug!("Not found, executing a dns query!");
- if let Some(ref resolver) = self.resolver {
- let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
- debug!("Obtained: {:?}", addr_set);
- self.store_in_cache(host, addr_set.clone());
- Ok(addr_set)
- } else {
- Err(ResolveError::from("No resolver available"))
- }
- }
- }
- }
- //
- // Returns true if the stored host resolution differs from the AddrSet passed.
- pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
- if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
- return fetched_addr_set != *addr_set;
- }
- false
- }
- // Fetches an AddrSet from the inner cache adquiring the read lock.
- fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
- if let Some(ref hash) = self.data {
- if let Some(addr_set) = hash.read().unwrap().get(key) {
- return Some(addr_set.clone());
- }
- }
- None
- }
- // Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
- // cache.
- pub async fn from_config() -> Result<(), Error> {
- let cached_resolver = CACHED_RESOLVER.load();
- let desired_config = CachedResolverConfig::from(get_config());
- if cached_resolver.config != desired_config {
- if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
- warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
- refresh_loop.abort()
- }
- let new_resolver = if let Some(ref data) = cached_resolver.data {
- let data = Some(data.read().unwrap().clone());
- CachedResolver::new(desired_config, data).await
- } else {
- CachedResolver::new(desired_config, None).await
- };
- match new_resolver {
- Ok(ok) => {
- CACHED_RESOLVER.store(ok);
- Ok(())
- }
- Err(err) => {
- let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
- Err(Error::DNSCachedError(message))
- }
- }
- } else {
- Ok(())
- }
- }
- // Stores the AddrSet in cache adquiring the write lock.
- fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
- if let Some(ref data) = self.data {
- data.write().unwrap().insert(host.to_string(), addr_set);
- } else {
- error!("Could not insert, Hash not initialized");
- }
- }
- }
#[cfg(test)]
mod tests {
    use super::*;
    use trust_dns_resolver::error::ResolveError;

    // Configuration shared by every test: cache enabled, `dns_max_ttl` of 10.
    fn enabled_config() -> CachedResolverConfig {
        CachedResolverConfig {
            dns_max_ttl: 10,
            enabled: true,
        }
    }

    // Builds a resolver from `enabled_config` with an empty cache.
    async fn enabled_resolver() -> std::sync::Arc<CachedResolver> {
        CachedResolver::new(enabled_config(), None).await.unwrap()
    }

    #[tokio::test]
    async fn new() {
        let resolver = CachedResolver::new(enabled_config(), None).await;
        assert!(resolver.is_ok());
    }

    #[tokio::test]
    async fn lookup_ip() {
        let resolver = enabled_resolver().await;
        let response = resolver.lookup_ip("www.google.com.").await;
        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn has_changed() {
        let resolver = enabled_resolver().await;
        let hostname = "www.google.com.";
        let addr_set = resolver.lookup_ip(hostname).await.unwrap();
        assert!(!resolver.has_changed(hostname, &addr_set));
    }

    #[tokio::test]
    async fn unknown_host() {
        let resolver = enabled_resolver().await;
        let hostname = "www.idontexists.";
        let response = resolver.lookup_ip(hostname).await;
        assert!(matches!(response, Err(ResolveError { .. })));
    }

    #[tokio::test]
    async fn incorrect_address() {
        let resolver = enabled_resolver().await;
        let hostname = "w  ww.idontexists.";
        let response = resolver.lookup_ip(hostname).await;
        assert!(matches!(response, Err(ResolveError { .. })));
        // A failed lookup stores nothing, so the host is reported as unchanged.
        assert!(!resolver.has_changed(hostname, &AddrSet::new()));
    }

    // Ok, this test is based on the fact that google does DNS RR
    // and does not responds with every available ip everytime, so
    // if I cache here, it will miss after one cache iteration or two.
    #[tokio::test]
    async fn thread() {
        let resolver = enabled_resolver().await;
        let hostname = "www.google.com.";
        let addr_set = resolver.lookup_ip(hostname).await.unwrap();
        assert!(!resolver.has_changed(hostname, &addr_set));

        let resolver_for_refresher = resolver.clone();
        let _thread_handle = tokio::task::spawn(async move {
            resolver_for_refresher.refresh_dns_entries_loop().await;
        });
        assert!(!resolver.has_changed(hostname, &addr_set));
    }
}
|