dns_cache.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. use crate::config::get_config;
  2. use crate::errors::Error;
  3. use arc_swap::ArcSwap;
  4. use log::{debug, error, info, warn};
  5. use once_cell::sync::Lazy;
  6. use std::collections::{HashMap, HashSet};
  7. use std::io;
  8. use std::net::IpAddr;
  9. use std::sync::Arc;
  10. use std::sync::RwLock;
  11. use tokio::time::{sleep, Duration};
  12. use trust_dns_resolver::error::{ResolveError, ResolveResult};
  13. use trust_dns_resolver::lookup_ip::LookupIp;
  14. use trust_dns_resolver::TokioAsyncResolver;
  15. /// Cached Resolver Globally available
  16. pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
  17. Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
  /// A set of IP addresses obtained from one DNS resolution.
  ///
  /// Addresses are kept in a `HashSet` so that two resolutions can be compared
  /// regardless of the order in which the addresses were inserted.
  #[derive(Clone, Debug, Default, PartialEq, Eq)]
  pub struct AddrSet {
      set: HashSet<IpAddr>,
  }

  impl AddrSet {
      /// Creates an empty address set (same value as `AddrSet::default()`).
      fn new() -> AddrSet {
          AddrSet::default()
      }
  }
  31. impl From<LookupIp> for AddrSet {
  32. fn from(lookup_ip: LookupIp) -> Self {
  33. let mut addr_set = AddrSet::new();
  34. for address in lookup_ip.iter() {
  35. addr_set.set.insert(address);
  36. }
  37. addr_set
  38. }
  39. }
  ///
  /// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
  ///
  /// The system works as follows:
  ///
  /// When a host is to be resolved, if we have not resolved it before, a new resolution is
  /// executed and stored in the internal cache. Concurrently, when the resolver is enabled,
  /// a background task re-resolves every cached host and then sleeps for `dns_max_ttl`
  /// seconds before the next pass.
  ///
  /// # Example:
  ///
  /// ```
  /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
  ///
  /// # tokio_test::block_on(async {
  /// let config = CachedResolverConfig::default();
  /// let resolver = CachedResolver::new(config, None).await.unwrap();
  /// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
  /// # })
  /// ```
  ///
  /// After the call above the resolution is stored in the local cache, and subsequent
  /// lookups of the same host are answered from that cache.
  ///
  /// To check whether an 'old' lookup differs from what is currently stored in the
  /// cache, use `has_changed`, e.g. `resolver.has_changed("www.example.com.", &addrset)`.
  #[derive(Default)]
  pub struct CachedResolver {
      // The configuration of the cached_resolver.
      config: CachedResolverConfig,
      // Cache mapping each hostname to the last `AddrSet` stored for it.
      // `new` always populates it; it is `None` only on a `Default`-constructed instance.
      data: Option<RwLock<HashMap<String, AddrSet>>>,
      // The resolver used by `lookup_ip` for DNS queries on a cache miss.
      // `new` always populates it; it is `None` only on a `Default`-constructed instance.
      resolver: Option<TokioAsyncResolver>,
      // Handle of the background refresh task. `new` sets it when the resolver is
      // enabled, and `from_config` aborts it when the configuration is reloaded.
      refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
  }
  /// Settings that drive a `CachedResolver`.
  #[derive(Clone, Debug, Default, PartialEq)]
  pub struct CachedResolverConfig {
      /// Amount of time, in seconds, after which a resolved DNS address is considered stale.
      dns_max_ttl: u64,
      /// Whether the DNS cache is enabled (kept here so a config reload can toggle it).
      enabled: bool,
  }

  impl CachedResolverConfig {
      /// Builds a configuration from its two settings.
      fn new(dns_max_ttl: u64, enabled: bool) -> Self {
          Self {
              dns_max_ttl,
              enabled,
          }
      }
  }
  96. impl From<crate::config::Config> for CachedResolverConfig {
  97. fn from(config: crate::config::Config) -> Self {
  98. CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
  99. }
  100. }
  101. impl CachedResolver {
  102. ///
  103. /// Returns a new Arc<CachedResolver> based on passed configuration.
  104. /// It also starts the loop that will refresh cache entries.
  105. ///
  106. /// # Arguments:
  107. ///
  108. /// * `config` - The `CachedResolverConfig` to be used to create the resolver.
  109. ///
  110. /// # Example:
  111. ///
  112. /// ```
  113. /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
  114. ///
  115. /// # tokio_test::block_on(async {
  116. /// let config = CachedResolverConfig::default();
  117. /// let resolver = CachedResolver::new(config, None).await.unwrap();
  118. /// # })
  119. /// ```
  120. ///
  121. pub async fn new(
  122. config: CachedResolverConfig,
  123. data: Option<HashMap<String, AddrSet>>,
  124. ) -> Result<Arc<Self>, io::Error> {
  125. // Construct a new Resolver with default configuration options
  126. let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
  127. let data = if let Some(hash) = data {
  128. Some(RwLock::new(hash))
  129. } else {
  130. Some(RwLock::new(HashMap::new()))
  131. };
  132. let instance = Arc::new(Self {
  133. config,
  134. resolver,
  135. data,
  136. refresh_loop: RwLock::new(None),
  137. });
  138. if instance.enabled() {
  139. info!("Scheduling DNS refresh loop");
  140. let refresh_loop = tokio::task::spawn({
  141. let instance = instance.clone();
  142. async move {
  143. instance.refresh_dns_entries_loop().await;
  144. }
  145. });
  146. *(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
  147. }
  148. Ok(instance)
  149. }
  150. pub fn enabled(&self) -> bool {
  151. self.config.enabled
  152. }
  153. // Schedules the refresher
  154. async fn refresh_dns_entries_loop(&self) {
  155. let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
  156. let interval = Duration::from_secs(self.config.dns_max_ttl);
  157. loop {
  158. debug!("Begin refreshing cached DNS addresses.");
  159. // To minimize the time we hold the lock, we first create
  160. // an array with keys.
  161. let mut hostnames: Vec<String> = Vec::new();
  162. {
  163. if let Some(ref data) = self.data {
  164. for hostname in data.read().unwrap().keys() {
  165. hostnames.push(hostname.clone());
  166. }
  167. }
  168. }
  169. for hostname in hostnames.iter() {
  170. let addrset = self
  171. .fetch_from_cache(hostname.as_str())
  172. .expect("Could not obtain expected address from cache, this should not happen");
  173. match resolver.lookup_ip(hostname).await {
  174. Ok(lookup_ip) => {
  175. let new_addrset = AddrSet::from(lookup_ip);
  176. debug!(
  177. "Obtained address for host ({}) -> ({:?})",
  178. hostname, new_addrset
  179. );
  180. if addrset != new_addrset {
  181. debug!(
  182. "Addr changed from {:?} to {:?} updating cache.",
  183. addrset, new_addrset
  184. );
  185. self.store_in_cache(hostname, new_addrset);
  186. }
  187. }
  188. Err(err) => {
  189. error!(
  190. "There was an error trying to resolv {}: ({}).",
  191. hostname, err
  192. );
  193. }
  194. }
  195. }
  196. debug!("Finished refreshing cached DNS addresses.");
  197. sleep(interval).await;
  198. }
  199. }
  200. /// Returns a `AddrSet` given the specified hostname.
  201. ///
  202. /// This method first tries to fetch the value from the cache, if it misses
  203. /// then it is resolved and stored in the cache. TTL from records is ignored.
  204. ///
  205. /// # Arguments
  206. ///
  207. /// * `host` - A string slice referencing the hostname to be resolved.
  208. ///
  209. /// # Example:
  210. ///
  211. /// ```
  212. /// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
  213. ///
  214. /// # tokio_test::block_on(async {
  215. /// let config = CachedResolverConfig::default();
  216. /// let resolver = CachedResolver::new(config, None).await.unwrap();
  217. /// let response = resolver.lookup_ip("www.google.com.");
  218. /// # })
  219. /// ```
  220. ///
  221. pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
  222. debug!("Lookup up {} in cache", host);
  223. match self.fetch_from_cache(host) {
  224. Some(addr_set) => {
  225. debug!("Cache hit!");
  226. Ok(addr_set)
  227. }
  228. None => {
  229. debug!("Not found, executing a dns query!");
  230. if let Some(ref resolver) = self.resolver {
  231. let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
  232. debug!("Obtained: {:?}", addr_set);
  233. self.store_in_cache(host, addr_set.clone());
  234. Ok(addr_set)
  235. } else {
  236. Err(ResolveError::from("No resolver available"))
  237. }
  238. }
  239. }
  240. }
  241. //
  242. // Returns true if the stored host resolution differs from the AddrSet passed.
  243. pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
  244. if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
  245. return fetched_addr_set != *addr_set;
  246. }
  247. false
  248. }
  249. // Fetches an AddrSet from the inner cache adquiring the read lock.
  250. fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
  251. if let Some(ref hash) = self.data {
  252. if let Some(addr_set) = hash.read().unwrap().get(key) {
  253. return Some(addr_set.clone());
  254. }
  255. }
  256. None
  257. }
  258. // Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
  259. // cache.
  260. pub async fn from_config() -> Result<(), Error> {
  261. let cached_resolver = CACHED_RESOLVER.load();
  262. let desired_config = CachedResolverConfig::from(get_config());
  263. if cached_resolver.config != desired_config {
  264. if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
  265. warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
  266. refresh_loop.abort()
  267. }
  268. let new_resolver = if let Some(ref data) = cached_resolver.data {
  269. let data = Some(data.read().unwrap().clone());
  270. CachedResolver::new(desired_config, data).await
  271. } else {
  272. CachedResolver::new(desired_config, None).await
  273. };
  274. match new_resolver {
  275. Ok(ok) => {
  276. CACHED_RESOLVER.store(ok);
  277. Ok(())
  278. }
  279. Err(err) => {
  280. let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
  281. Err(Error::DNSCachedError(message))
  282. }
  283. }
  284. } else {
  285. Ok(())
  286. }
  287. }
  288. // Stores the AddrSet in cache adquiring the write lock.
  289. fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
  290. if let Some(ref data) = self.data {
  291. data.write().unwrap().insert(host.to_string(), addr_set);
  292. } else {
  293. error!("Could not insert, Hash not initialized");
  294. }
  295. }
  296. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use trust_dns_resolver::error::ResolveError;

      // Configuration shared by every test: cache enabled, `dns_max_ttl` of 10.
      fn enabled_config() -> CachedResolverConfig {
          CachedResolverConfig {
              dns_max_ttl: 10,
              enabled: true,
          }
      }

      // Building a resolver from an enabled configuration succeeds.
      #[tokio::test]
      async fn new() {
          let outcome = CachedResolver::new(enabled_config(), None).await;
          assert!(outcome.is_ok());
      }

      // A well-known host resolves.
      #[tokio::test]
      async fn lookup_ip() {
          let resolver = CachedResolver::new(enabled_config(), None)
              .await
              .unwrap();
          let outcome = resolver.lookup_ip("www.google.com.").await;
          assert!(outcome.is_ok());
      }

      // Right after a lookup, the cached value equals the returned one.
      #[tokio::test]
      async fn has_changed() {
          let resolver = CachedResolver::new(enabled_config(), None)
              .await
              .unwrap();
          let hostname = "www.google.com.";
          let addr_set = resolver.lookup_ip(hostname).await.unwrap();
          assert!(!resolver.has_changed(hostname, &addr_set));
      }

      // Resolving a host that does not exist yields an error.
      #[tokio::test]
      async fn unknown_host() {
          let resolver = CachedResolver::new(enabled_config(), None)
              .await
              .unwrap();
          let hostname = "www.idontexists.";
          let outcome = resolver.lookup_ip(hostname).await;
          assert!(matches!(outcome, Err(ResolveError { .. })));
      }

      // A malformed hostname yields an error and is reported as unchanged.
      #[tokio::test]
      async fn incorrect_address() {
          let resolver = CachedResolver::new(enabled_config(), None)
              .await
              .unwrap();
          let hostname = "w ww.idontexists.";
          let outcome = resolver.lookup_ip(hostname).await;
          assert!(matches!(outcome, Err(ResolveError { .. })));
          assert!(!resolver.has_changed(hostname, &AddrSet::new()));
      }

      // Ok, this test is based on the fact that google does DNS RR
      // and does not respond with every available ip every time, so
      // if I cache here, it will miss after one cache iteration or two.
      #[tokio::test]
      async fn thread() {
          let resolver = CachedResolver::new(enabled_config(), None)
              .await
              .unwrap();
          let hostname = "www.google.com.";
          let addr_set = resolver.lookup_ip(hostname).await.unwrap();
          assert!(!resolver.has_changed(hostname, &addr_set));

          let background = resolver.clone();
          let _thread_handle = tokio::task::spawn(async move {
              background.refresh_dns_entries_loop().await;
          });

          assert!(!resolver.has_changed(hostname, &addr_set));
      }
  }