123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- // SCRAM-SHA-256 authentication. Heavily inspired by
- // https://github.com/sfackler/rust-postgres/
- // SASL implementation.
- use base64::{engine::general_purpose, Engine as _};
- use bytes::BytesMut;
- use hmac::{Hmac, Mac};
- use rand::{self, Rng};
- use sha2::digest::FixedOutput;
- use sha2::{Digest, Sha256};
- use std::fmt::Write;
- use crate::constants::*;
- use crate::errors::Error;
- /// Normalize a password string. Postgres
- /// passwords don't have to be UTF-8.
- fn normalize(pass: &[u8]) -> Vec<u8> {
- let pass = match std::str::from_utf8(pass) {
- Ok(pass) => pass,
- Err(_) => return pass.to_vec(),
- };
- match stringprep::saslprep(pass) {
- Ok(pass) => pass.into_owned().into_bytes(),
- Err(_) => pass.as_bytes().to_vec(),
- }
- }
- /// Keep the SASL state through the exchange.
- /// It takes 3 messages to complete the authentication.
- pub struct ScramSha256 {
- password: String,
- salted_password: [u8; 32],
- auth_message: String,
- message: BytesMut,
- nonce: String,
- }
- impl ScramSha256 {
- /// Create the Scram state from a password. It'll automatically
- /// generate a nonce.
- pub fn new(password: &str) -> ScramSha256 {
- let mut rng = rand::thread_rng();
- let nonce = (0..NONCE_LENGTH)
- .map(|_| {
- let mut v = rng.gen_range(0x21u8..0x7e);
- if v == 0x2c {
- v = 0x7e
- }
- v as char
- })
- .collect::<String>();
- Self::from_nonce(password, &nonce)
- }
- /// Used for testing.
- pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 {
- let message = BytesMut::from(format!("{}n=,r={}", "n,,", nonce).as_bytes());
- ScramSha256 {
- password: password.to_string(),
- nonce: String::from(nonce),
- message,
- salted_password: [0u8; 32],
- auth_message: String::new(),
- }
- }
- /// Get the current state of the SASL authentication.
- pub fn message(&mut self) -> BytesMut {
- self.message.clone()
- }
- /// Update the state with message received from server.
- pub fn update(&mut self, message: &BytesMut) -> Result<BytesMut, Error> {
- let server_message = Message::parse(message)?;
- if !server_message.nonce.starts_with(&self.nonce) {
- return Err(Error::ProtocolSyncError(format!("SCRAM")));
- }
- let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
- Ok(salt) => salt,
- Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
- };
- let salted_password = Self::hi(
- &normalize(self.password.as_bytes()),
- &salt,
- server_message.iterations,
- );
- // Save for verification of final server message.
- self.salted_password = salted_password;
- let mut hmac = match Hmac::<Sha256>::new_from_slice(&salted_password) {
- Ok(hmac) => hmac,
- Err(_) => return Err(Error::ServerError),
- };
- hmac.update(b"Client Key");
- let client_key = hmac.finalize().into_bytes();
- let mut hash = Sha256::default();
- hash.update(client_key.as_slice());
- let stored_key = hash.finalize_fixed();
- let mut cbind_input = vec![];
- cbind_input.extend("n,,".as_bytes());
- let cbind_input = general_purpose::STANDARD.encode(&cbind_input);
- self.message.clear();
- // Start writing the client reply.
- match write!(
- &mut self.message,
- "c={},r={}",
- cbind_input, server_message.nonce
- ) {
- Ok(_) => (),
- Err(_) => return Err(Error::ServerError),
- };
- let auth_message = format!(
- "n=,r={},{},{}",
- self.nonce,
- String::from_utf8_lossy(&message[..]),
- String::from_utf8_lossy(&self.message[..])
- );
- let mut hmac = match Hmac::<Sha256>::new_from_slice(&stored_key) {
- Ok(hmac) => hmac,
- Err(_) => return Err(Error::ServerError),
- };
- hmac.update(auth_message.as_bytes());
- // Save the auth message for server final message verification.
- self.auth_message = auth_message;
- let client_signature = hmac.finalize().into_bytes();
- // Sign the client proof.
- let mut client_proof = client_key;
- for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
- *proof ^= signature;
- }
- match write!(
- &mut self.message,
- ",p={}",
- general_purpose::STANDARD.encode(&*client_proof)
- ) {
- Ok(_) => (),
- Err(_) => return Err(Error::ServerError),
- };
- Ok(self.message.clone())
- }
- /// Verify final server message.
- pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
- let final_message = FinalMessage::parse(message)?;
- let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
- Ok(verifier) => verifier,
- Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
- };
- let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
- Ok(hmac) => hmac,
- Err(_) => return Err(Error::ServerError),
- };
- hmac.update(b"Server Key");
- let server_key = hmac.finalize().into_bytes();
- let mut hmac = match Hmac::<Sha256>::new_from_slice(&server_key) {
- Ok(hmac) => hmac,
- Err(_) => return Err(Error::ServerError),
- };
- hmac.update(self.auth_message.as_bytes());
- match hmac.verify_slice(&verifier) {
- Ok(_) => Ok(()),
- Err(_) => Err(Error::ServerError),
- }
- }
- /// Hash the password with the salt i-times.
- fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
- let mut hmac =
- Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
- hmac.update(salt);
- hmac.update(&[0, 0, 0, 1]);
- let mut prev = hmac.finalize().into_bytes();
- let mut hi = prev;
- for _ in 1..i {
- let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
- hmac.update(&prev);
- prev = hmac.finalize().into_bytes();
- for (hi, prev) in hi.iter_mut().zip(prev) {
- *hi ^= prev;
- }
- }
- hi.into()
- }
- }
- /// Parse the server challenge.
- struct Message {
- nonce: String,
- salt: String,
- iterations: u32,
- }
- impl Message {
- /// Parse the server SASL challenge.
- fn parse(message: &BytesMut) -> Result<Message, Error> {
- let parts = String::from_utf8_lossy(&message[..])
- .split(',')
- .map(|s| s.to_string())
- .collect::<Vec<String>>();
- if parts.len() != 3 {
- return Err(Error::ProtocolSyncError(format!("SCRAM")));
- }
- let nonce = str::replace(&parts[0], "r=", "");
- let salt = str::replace(&parts[1], "s=", "");
- let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
- Ok(iterations) => iterations,
- Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
- };
- Ok(Message {
- nonce,
- salt,
- iterations,
- })
- }
- }
- /// Parse server final validation message.
- struct FinalMessage {
- value: String,
- }
- impl FinalMessage {
- /// Parse the server final validation message.
- pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
- if !message.starts_with(b"v=") || message.len() < 4 {
- return Err(Error::ProtocolSyncError(format!("SCRAM")));
- }
- Ok(FinalMessage {
- value: String::from_utf8_lossy(&message[2..]).to_string(),
- })
- }
- }
- #[cfg(test)]
- mod test {
- use super::*;
- #[test]
- fn parse_server_first_message() {
- let message = BytesMut::from(
- "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096".as_bytes(),
- );
- let message = Message::parse(&message).unwrap();
- assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
- assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
- assert_eq!(message.iterations, 4096);
- }
- #[test]
- fn parse_server_last_message() {
- let f = FinalMessage::parse(&BytesMut::from(
- "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".as_bytes(),
- ))
- .unwrap();
- assert_eq!(
- f.value,
- "U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".to_string()
- );
- }
- // recorded auth exchange from psql
- #[test]
- fn exchange() {
- let password = "foobar";
- let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
- let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
- let server_first =
- "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
- =4096";
- let client_final =
- "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
- 1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
- let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
- let mut scram = ScramSha256::from_nonce(password, nonce);
- let message = scram.message();
- assert_eq!(std::str::from_utf8(&message).unwrap(), client_first);
- let result = scram
- .update(&BytesMut::from(server_first.as_bytes()))
- .unwrap();
- assert_eq!(std::str::from_utf8(&result).unwrap(), client_final);
- scram
- .finish(&BytesMut::from(server_final.as_bytes()))
- .unwrap();
- }
- }
|