| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 | use crate::config::get_config;use crate::errors::Error;use arc_swap::ArcSwap;use log::{debug, error, info, warn};use once_cell::sync::Lazy;use std::collections::{HashMap, HashSet};use std::io;use std::net::IpAddr;use std::sync::Arc;use std::sync::RwLock;use tokio::time::{sleep, Duration};use trust_dns_resolver::error::{ResolveError, ResolveResult};use trust_dns_resolver::lookup_ip::LookupIp;use trust_dns_resolver::TokioAsyncResolver;/// Cached Resolver Globally availablepub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =    Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));// Ip addressed are returned as a set of addresses// so we can compare.#[derive(Clone, PartialEq, Debug)]pub struct AddrSet {    set: HashSet<IpAddr>,}impl AddrSet {    fn new() -> AddrSet {        AddrSet {            set: HashSet::new(),        }    }}impl From<LookupIp> for AddrSet {    fn from(lookup_ip: LookupIp) -> Self {        let mut addr_set = AddrSet::new();        for address in lookup_ip.iter() {            addr_set.set.insert(address);        }        addr_set    }}////// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.////// The system works as follows:////// When a host is to be resolved, if we have not resolved it before, a new resolution is/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the/// cache is refreshed.////// # Example:////// ```/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};////// # tokio_test::block_on(async {/// let config = CachedResolverConfig::default();/// let resolver = CachedResolver::new(config, None).await.unwrap();/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();/// # })/// ```////// // Now the ip resolution is stored in local cache and subsequent/// // calls will be returned from cache. Also, the cache is refreshed/// // and updated every 10 seconds.////// // You can now check if an 'old' lookup differs from what it's currently/// // store in cache by using `has_changed`./// resolver.has_changed("www.example.com.", addrset)#[derive(Default)]pub struct CachedResolver {    // The configuration of the cached_resolver.    config: CachedResolverConfig,    // This is the hash that contains the hash.    data: Option<RwLock<HashMap<String, AddrSet>>>,    // The resolver to be used for DNS queries.    resolver: Option<TokioAsyncResolver>,    // The RefreshLoop    refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,}////// Configuration#[derive(Clone, Debug, Default, PartialEq)]pub struct CachedResolverConfig {    /// Amount of time in secods that a resolved dns address is considered stale.    dns_max_ttl: u64,    /// Enabled or disabled? (this is so we can reload config)    enabled: bool,}impl CachedResolverConfig {    fn new(dns_max_ttl: u64, enabled: bool) -> Self {        CachedResolverConfig {            dns_max_ttl,            enabled,        }    }}impl From<crate::config::Config> for CachedResolverConfig {    fn from(config: crate::config::Config) -> Self {        CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)    }}impl CachedResolver {    ///    /// Returns a new Arc<CachedResolver> based on passed configuration.    /// It also starts the loop that will refresh cache entries.    ///    /// # Arguments:    ///    /// * `config` - The `CachedResolverConfig` to be used to create the resolver.    ///    /// # Example:    ///    /// ```    /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};    ///    /// # tokio_test::block_on(async {    /// let config = CachedResolverConfig::default();    /// let resolver = CachedResolver::new(config, None).await.unwrap();    /// # })    /// ```    ///    pub async fn new(        config: CachedResolverConfig,        data: Option<HashMap<String, AddrSet>>,    ) -> Result<Arc<Self>, io::Error> {        // Construct a new Resolver with default configuration options        let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);        let data = if let Some(hash) = data {            Some(RwLock::new(hash))        } else {            Some(RwLock::new(HashMap::new()))        };        let instance = Arc::new(Self {            config,            resolver,            data,            refresh_loop: RwLock::new(None),        });        if instance.enabled() {            info!("Scheduling DNS refresh loop");            let refresh_loop = tokio::task::spawn({                let instance = instance.clone();                async move {                    instance.refresh_dns_entries_loop().await;                }            });            *(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);        }        Ok(instance)    }    pub fn enabled(&self) -> bool {        self.config.enabled    }    // Schedules the refresher    async fn refresh_dns_entries_loop(&self) {        let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();        let interval = Duration::from_secs(self.config.dns_max_ttl);        loop {            debug!("Begin refreshing cached DNS addresses.");            // To minimize the time we hold the lock, we first create            // an array with keys.            let mut hostnames: Vec<String> = Vec::new();            {                if let Some(ref data) = self.data {                    for hostname in data.read().unwrap().keys() {                        hostnames.push(hostname.clone());                    }                }            }            for hostname in hostnames.iter() {                let addrset = self                    .fetch_from_cache(hostname.as_str())                    .expect("Could not obtain expected address from cache, this should not happen");                match resolver.lookup_ip(hostname).await {                    Ok(lookup_ip) => {                        let new_addrset = AddrSet::from(lookup_ip);                        debug!(                            "Obtained address for host ({}) -> ({:?})",                            hostname, new_addrset                        );                        if addrset != new_addrset {                            debug!(                                "Addr changed from {:?} to {:?} updating cache.",                                addrset, new_addrset                            );                            self.store_in_cache(hostname, new_addrset);                        }                    }                    Err(err) => {                        error!(                            "There was an error trying to resolv {}: ({}).",                            hostname, err                        );                    }                }            }            debug!("Finished refreshing cached DNS addresses.");            sleep(interval).await;        }    }    /// Returns a `AddrSet` given the specified hostname.    ///    /// This method first tries to fetch the value from the cache, if it misses    /// then it is resolved and stored in the cache. TTL from records is ignored.    ///    /// # Arguments    ///    /// * `host`      - A string slice referencing the hostname to be resolved.    ///    /// # Example:    ///    /// ```    /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};    ///    /// # tokio_test::block_on(async {    /// let config = CachedResolverConfig::default();    /// let resolver = CachedResolver::new(config, None).await.unwrap();    /// let response = resolver.lookup_ip("www.google.com.");    /// # })    /// ```    ///    pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {        debug!("Lookup up {} in cache", host);        match self.fetch_from_cache(host) {            Some(addr_set) => {                debug!("Cache hit!");                Ok(addr_set)            }            None => {                debug!("Not found, executing a dns query!");                if let Some(ref resolver) = self.resolver {                    let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);                    debug!("Obtained: {:?}", addr_set);                    self.store_in_cache(host, addr_set.clone());                    Ok(addr_set)                } else {                    Err(ResolveError::from("No resolver available"))                }            }        }    }    //    // Returns true if the stored host resolution differs from the AddrSet passed.    pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {        if let Some(fetched_addr_set) = self.fetch_from_cache(host) {            return fetched_addr_set != *addr_set;        }        false    }    // Fetches an AddrSet from the inner cache adquiring the read lock.    fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {        if let Some(ref hash) = self.data {            if let Some(addr_set) = hash.read().unwrap().get(key) {                return Some(addr_set.clone());            }        }        None    }    // Sets up the global CACHED_RESOLVER static variable so we can globally use DNS    // cache.    pub async fn from_config() -> Result<(), Error> {        let cached_resolver = CACHED_RESOLVER.load();        let desired_config = CachedResolverConfig::from(get_config());        if cached_resolver.config != desired_config {            if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {                warn!("Killing Dnscache refresh loop as its configuration is being reloaded");                refresh_loop.abort()            }            let new_resolver = if let Some(ref data) = cached_resolver.data {                let data = Some(data.read().unwrap().clone());                CachedResolver::new(desired_config, data).await            } else {                CachedResolver::new(desired_config, None).await            };            match new_resolver {                Ok(ok) => {                    CACHED_RESOLVER.store(ok);                    Ok(())                }                Err(err) => {                    let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);                    Err(Error::DNSCachedError(message))                }            }        } else {            Ok(())        }    }    // Stores the AddrSet in cache adquiring the write lock.    fn store_in_cache(&self, host: &str, addr_set: AddrSet) {        if let Some(ref data) = self.data {            data.write().unwrap().insert(host.to_string(), addr_set);        } else {            error!("Could not insert, Hash not initialized");        }    }}#[cfg(test)]mod tests {    use super::*;    use trust_dns_resolver::error::ResolveError;    #[tokio::test]    async fn new() {        let config = CachedResolverConfig {            dns_max_ttl: 10,            enabled: true,        };        let resolver = CachedResolver::new(config, None).await;        assert!(resolver.is_ok());    }    #[tokio::test]    async fn lookup_ip() {        let config = CachedResolverConfig {            dns_max_ttl: 10,            enabled: true,        };        let resolver = CachedResolver::new(config, None).await.unwrap();        let response = resolver.lookup_ip("www.google.com.").await;        assert!(response.is_ok());    }    #[tokio::test]    async fn has_changed() {        let config = CachedResolverConfig {            dns_max_ttl: 10,            enabled: true,        };        let resolver = CachedResolver::new(config, None).await.unwrap();        let hostname = "www.google.com.";        let response = resolver.lookup_ip(hostname).await;        let addr_set = response.unwrap();        assert!(!resolver.has_changed(hostname, &addr_set));    }    #[tokio::test]    async fn unknown_host() {        let config = CachedResolverConfig {            dns_max_ttl: 10,            enabled: true,        };        let resolver = CachedResolver::new(config, None).await.unwrap();        let hostname = "www.idontexists.";        let response = resolver.lookup_ip(hostname).await;        assert!(matches!(response, Err(ResolveError { .. })));    }    #[tokio::test]    async fn incorrect_address() {        let config = CachedResolverConfig {            dns_max_ttl: 10,            enabled: true,        };        let resolver = CachedResolver::new(config, None).await.unwrap();        let hostname = "w  ww.idontexists.";        let response = resolver.lookup_ip(hostname).await;        assert!(matches!(response, Err(ResolveError { .. })));        assert!(!resolver.has_changed(hostname, &AddrSet::new()));    }    #[tokio::test]    // Ok, this test is based on the fact that google does DNS RR    // and does not responds with every available ip everytime, so    // if I cache here, it will miss after one cache iteration or two.    async fn thread() {        let config = CachedResolverConfig {            dns_max_ttl: 10,            enabled: true,        };        let resolver = CachedResolver::new(config, None).await.unwrap();        let hostname = "www.google.com.";        let response = resolver.lookup_ip(hostname).await;        let addr_set = response.unwrap();        assert!(!resolver.has_changed(hostname, &addr_set));        let resolver_for_refresher = resolver.clone();        let _thread_handle = tokio::task::spawn(async move {            resolver_for_refresher.refresh_dns_entries_loop().await;        });        assert!(!resolver.has_changed(hostname, &addr_set));    }}
 |