scram.rs 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. // SCRAM-SHA-256 authentication. Heavily inspired by
  2. // https://github.com/sfackler/rust-postgres/
  3. // SASL implementation.
  4. use base64::{engine::general_purpose, Engine as _};
  5. use bytes::BytesMut;
  6. use hmac::{Hmac, Mac};
  7. use rand::{self, Rng};
  8. use sha2::digest::FixedOutput;
  9. use sha2::{Digest, Sha256};
  10. use std::fmt::Write;
  11. use crate::constants::*;
  12. use crate::errors::Error;
  13. /// Normalize a password string. Postgres
  14. /// passwords don't have to be UTF-8.
  15. fn normalize(pass: &[u8]) -> Vec<u8> {
  16. let pass = match std::str::from_utf8(pass) {
  17. Ok(pass) => pass,
  18. Err(_) => return pass.to_vec(),
  19. };
  20. match stringprep::saslprep(pass) {
  21. Ok(pass) => pass.into_owned().into_bytes(),
  22. Err(_) => pass.as_bytes().to_vec(),
  23. }
  24. }
  25. /// Keep the SASL state through the exchange.
  26. /// It takes 3 messages to complete the authentication.
  27. pub struct ScramSha256 {
  28. password: String,
  29. salted_password: [u8; 32],
  30. auth_message: String,
  31. message: BytesMut,
  32. nonce: String,
  33. }
  34. impl ScramSha256 {
  35. /// Create the Scram state from a password. It'll automatically
  36. /// generate a nonce.
  37. pub fn new(password: &str) -> ScramSha256 {
  38. let mut rng = rand::thread_rng();
  39. let nonce = (0..NONCE_LENGTH)
  40. .map(|_| {
  41. let mut v = rng.gen_range(0x21u8..0x7e);
  42. if v == 0x2c {
  43. v = 0x7e
  44. }
  45. v as char
  46. })
  47. .collect::<String>();
  48. Self::from_nonce(password, &nonce)
  49. }
  50. /// Used for testing.
  51. pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 {
  52. let message = BytesMut::from(format!("{}n=,r={}", "n,,", nonce).as_bytes());
  53. ScramSha256 {
  54. password: password.to_string(),
  55. nonce: String::from(nonce),
  56. message,
  57. salted_password: [0u8; 32],
  58. auth_message: String::new(),
  59. }
  60. }
  61. /// Get the current state of the SASL authentication.
  62. pub fn message(&mut self) -> BytesMut {
  63. self.message.clone()
  64. }
  65. /// Update the state with message received from server.
  66. pub fn update(&mut self, message: &BytesMut) -> Result<BytesMut, Error> {
  67. let server_message = Message::parse(message)?;
  68. if !server_message.nonce.starts_with(&self.nonce) {
  69. return Err(Error::ProtocolSyncError(format!("SCRAM")));
  70. }
  71. let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
  72. Ok(salt) => salt,
  73. Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
  74. };
  75. let salted_password = Self::hi(
  76. &normalize(self.password.as_bytes()),
  77. &salt,
  78. server_message.iterations,
  79. );
  80. // Save for verification of final server message.
  81. self.salted_password = salted_password;
  82. let mut hmac = match Hmac::<Sha256>::new_from_slice(&salted_password) {
  83. Ok(hmac) => hmac,
  84. Err(_) => return Err(Error::ServerError),
  85. };
  86. hmac.update(b"Client Key");
  87. let client_key = hmac.finalize().into_bytes();
  88. let mut hash = Sha256::default();
  89. hash.update(client_key.as_slice());
  90. let stored_key = hash.finalize_fixed();
  91. let mut cbind_input = vec![];
  92. cbind_input.extend("n,,".as_bytes());
  93. let cbind_input = general_purpose::STANDARD.encode(&cbind_input);
  94. self.message.clear();
  95. // Start writing the client reply.
  96. match write!(
  97. &mut self.message,
  98. "c={},r={}",
  99. cbind_input, server_message.nonce
  100. ) {
  101. Ok(_) => (),
  102. Err(_) => return Err(Error::ServerError),
  103. };
  104. let auth_message = format!(
  105. "n=,r={},{},{}",
  106. self.nonce,
  107. String::from_utf8_lossy(&message[..]),
  108. String::from_utf8_lossy(&self.message[..])
  109. );
  110. let mut hmac = match Hmac::<Sha256>::new_from_slice(&stored_key) {
  111. Ok(hmac) => hmac,
  112. Err(_) => return Err(Error::ServerError),
  113. };
  114. hmac.update(auth_message.as_bytes());
  115. // Save the auth message for server final message verification.
  116. self.auth_message = auth_message;
  117. let client_signature = hmac.finalize().into_bytes();
  118. // Sign the client proof.
  119. let mut client_proof = client_key;
  120. for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
  121. *proof ^= signature;
  122. }
  123. match write!(
  124. &mut self.message,
  125. ",p={}",
  126. general_purpose::STANDARD.encode(&*client_proof)
  127. ) {
  128. Ok(_) => (),
  129. Err(_) => return Err(Error::ServerError),
  130. };
  131. Ok(self.message.clone())
  132. }
  133. /// Verify final server message.
  134. pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
  135. let final_message = FinalMessage::parse(message)?;
  136. let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
  137. Ok(verifier) => verifier,
  138. Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
  139. };
  140. let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
  141. Ok(hmac) => hmac,
  142. Err(_) => return Err(Error::ServerError),
  143. };
  144. hmac.update(b"Server Key");
  145. let server_key = hmac.finalize().into_bytes();
  146. let mut hmac = match Hmac::<Sha256>::new_from_slice(&server_key) {
  147. Ok(hmac) => hmac,
  148. Err(_) => return Err(Error::ServerError),
  149. };
  150. hmac.update(self.auth_message.as_bytes());
  151. match hmac.verify_slice(&verifier) {
  152. Ok(_) => Ok(()),
  153. Err(_) => Err(Error::ServerError),
  154. }
  155. }
  156. /// Hash the password with the salt i-times.
  157. fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
  158. let mut hmac =
  159. Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
  160. hmac.update(salt);
  161. hmac.update(&[0, 0, 0, 1]);
  162. let mut prev = hmac.finalize().into_bytes();
  163. let mut hi = prev;
  164. for _ in 1..i {
  165. let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
  166. hmac.update(&prev);
  167. prev = hmac.finalize().into_bytes();
  168. for (hi, prev) in hi.iter_mut().zip(prev) {
  169. *hi ^= prev;
  170. }
  171. }
  172. hi.into()
  173. }
  174. }
  175. /// Parse the server challenge.
  176. struct Message {
  177. nonce: String,
  178. salt: String,
  179. iterations: u32,
  180. }
  181. impl Message {
  182. /// Parse the server SASL challenge.
  183. fn parse(message: &BytesMut) -> Result<Message, Error> {
  184. let parts = String::from_utf8_lossy(&message[..])
  185. .split(',')
  186. .map(|s| s.to_string())
  187. .collect::<Vec<String>>();
  188. if parts.len() != 3 {
  189. return Err(Error::ProtocolSyncError(format!("SCRAM")));
  190. }
  191. let nonce = str::replace(&parts[0], "r=", "");
  192. let salt = str::replace(&parts[1], "s=", "");
  193. let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
  194. Ok(iterations) => iterations,
  195. Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
  196. };
  197. Ok(Message {
  198. nonce,
  199. salt,
  200. iterations,
  201. })
  202. }
  203. }
  204. /// Parse server final validation message.
  205. struct FinalMessage {
  206. value: String,
  207. }
  208. impl FinalMessage {
  209. /// Parse the server final validation message.
  210. pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
  211. if !message.starts_with(b"v=") || message.len() < 4 {
  212. return Err(Error::ProtocolSyncError(format!("SCRAM")));
  213. }
  214. Ok(FinalMessage {
  215. value: String::from_utf8_lossy(&message[2..]).to_string(),
  216. })
  217. }
  218. }
  219. #[cfg(test)]
  220. mod test {
  221. use super::*;
  222. #[test]
  223. fn parse_server_first_message() {
  224. let message = BytesMut::from(
  225. "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096".as_bytes(),
  226. );
  227. let message = Message::parse(&message).unwrap();
  228. assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
  229. assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
  230. assert_eq!(message.iterations, 4096);
  231. }
  232. #[test]
  233. fn parse_server_last_message() {
  234. let f = FinalMessage::parse(&BytesMut::from(
  235. "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".as_bytes(),
  236. ))
  237. .unwrap();
  238. assert_eq!(
  239. f.value,
  240. "U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".to_string()
  241. );
  242. }
  243. // recorded auth exchange from psql
  244. #[test]
  245. fn exchange() {
  246. let password = "foobar";
  247. let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
  248. let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
  249. let server_first =
  250. "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
  251. =4096";
  252. let client_final =
  253. "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
  254. 1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
  255. let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
  256. let mut scram = ScramSha256::from_nonce(password, nonce);
  257. let message = scram.message();
  258. assert_eq!(std::str::from_utf8(&message).unwrap(), client_first);
  259. let result = scram
  260. .update(&BytesMut::from(server_first.as_bytes()))
  261. .unwrap();
  262. assert_eq!(std::str::from_utf8(&result).unwrap(), client_final);
  263. scram
  264. .finish(&BytesMut::from(server_final.as_bytes()))
  265. .unwrap();
  266. }
  267. }