Commit d981074c authored by Heikki Linnakangas's avatar Heikki Linnakangas

Misc SCRAM code cleanups.

* Move computation of SaltedPassword to a separate function from
  scram_ClientOrServerKey(). This saves a lot of cycles in libpq, by
  computing SaltedPassword only once per authentication. (Computing
  SaltedPassword is expensive by design.)

* Split scram_ClientOrServerKey() into two functions. Improves
  readability, by making the calling code less verbose.

* Rename "server proof" to "server signature", to better match the
  nomenclature used in RFC 5802.

* Rename SCRAM_SALT_LEN to SCRAM_DEFAULT_SALT_LEN, to make it more clear
  that the salt can be of any length, and the constant only specifies how
  long a salt we use when we generate a new verifier. Also rename
  SCRAM_ITERATIONS_DEFAULT to SCRAM_DEFAULT_ITERATIONS, for consistency.

These things caught my eye while working on other upcoming changes.
parent b9a3ef55
...@@ -396,7 +396,8 @@ scram_build_verifier(const char *username, const char *password, ...@@ -396,7 +396,8 @@ scram_build_verifier(const char *username, const char *password,
{ {
char *prep_password = NULL; char *prep_password = NULL;
pg_saslprep_rc rc; pg_saslprep_rc rc;
char saltbuf[SCRAM_SALT_LEN]; char saltbuf[SCRAM_DEFAULT_SALT_LEN];
uint8 salted_password[SCRAM_KEY_LEN];
uint8 keybuf[SCRAM_KEY_LEN]; uint8 keybuf[SCRAM_KEY_LEN];
char *encoded_salt; char *encoded_salt;
char *encoded_storedkey; char *encoded_storedkey;
...@@ -414,10 +415,10 @@ scram_build_verifier(const char *username, const char *password, ...@@ -414,10 +415,10 @@ scram_build_verifier(const char *username, const char *password,
password = (const char *) prep_password; password = (const char *) prep_password;
if (iterations <= 0) if (iterations <= 0)
iterations = SCRAM_ITERATIONS_DEFAULT; iterations = SCRAM_DEFAULT_ITERATIONS;
/* Generate salt, and encode it in base64 */ /* Generate salt, and encode it in base64 */
if (!pg_backend_random(saltbuf, SCRAM_SALT_LEN)) if (!pg_backend_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
{ {
ereport(LOG, ereport(LOG,
(errcode(ERRCODE_INTERNAL_ERROR), (errcode(ERRCODE_INTERNAL_ERROR),
...@@ -425,13 +426,14 @@ scram_build_verifier(const char *username, const char *password, ...@@ -425,13 +426,14 @@ scram_build_verifier(const char *username, const char *password,
return NULL; return NULL;
} }
encoded_salt = palloc(pg_b64_enc_len(SCRAM_SALT_LEN) + 1); encoded_salt = palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
encoded_len = pg_b64_encode(saltbuf, SCRAM_SALT_LEN, encoded_salt); encoded_len = pg_b64_encode(saltbuf, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
encoded_salt[encoded_len] = '\0'; encoded_salt[encoded_len] = '\0';
/* Calculate StoredKey, and encode it in base64 */ /* Calculate StoredKey, and encode it in base64 */
scram_ClientOrServerKey(password, saltbuf, SCRAM_SALT_LEN, scram_SaltedPassword(password, saltbuf, SCRAM_DEFAULT_SALT_LEN,
iterations, SCRAM_CLIENT_KEY_NAME, keybuf); iterations, salted_password);
scram_ClientKey(salted_password, keybuf);
scram_H(keybuf, SCRAM_KEY_LEN, keybuf); /* StoredKey */ scram_H(keybuf, SCRAM_KEY_LEN, keybuf); /* StoredKey */
encoded_storedkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1); encoded_storedkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
...@@ -440,8 +442,7 @@ scram_build_verifier(const char *username, const char *password, ...@@ -440,8 +442,7 @@ scram_build_verifier(const char *username, const char *password,
encoded_storedkey[encoded_len] = '\0'; encoded_storedkey[encoded_len] = '\0';
/* And same for ServerKey */ /* And same for ServerKey */
scram_ClientOrServerKey(password, saltbuf, SCRAM_SALT_LEN, iterations, scram_ServerKey(salted_password, keybuf);
SCRAM_SERVER_KEY_NAME, keybuf);
encoded_serverkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1); encoded_serverkey = palloc(pg_b64_enc_len(SCRAM_KEY_LEN) + 1);
encoded_len = pg_b64_encode((const char *) keybuf, SCRAM_KEY_LEN, encoded_len = pg_b64_encode((const char *) keybuf, SCRAM_KEY_LEN,
...@@ -473,6 +474,7 @@ scram_verify_plain_password(const char *username, const char *password, ...@@ -473,6 +474,7 @@ scram_verify_plain_password(const char *username, const char *password,
char *salt; char *salt;
int saltlen; int saltlen;
int iterations; int iterations;
uint8 salted_password[SCRAM_KEY_LEN];
uint8 stored_key[SCRAM_KEY_LEN]; uint8 stored_key[SCRAM_KEY_LEN];
uint8 server_key[SCRAM_KEY_LEN]; uint8 server_key[SCRAM_KEY_LEN];
uint8 computed_key[SCRAM_KEY_LEN]; uint8 computed_key[SCRAM_KEY_LEN];
...@@ -502,9 +504,9 @@ scram_verify_plain_password(const char *username, const char *password, ...@@ -502,9 +504,9 @@ scram_verify_plain_password(const char *username, const char *password,
if (rc == SASLPREP_SUCCESS) if (rc == SASLPREP_SUCCESS)
password = prep_password; password = prep_password;
/* Compute Server key based on the user-supplied plaintext password */ /* Compute Server Key based on the user-supplied plaintext password */
scram_ClientOrServerKey(password, salt, saltlen, iterations, scram_SaltedPassword(password, salt, saltlen, iterations, salted_password);
SCRAM_SERVER_KEY_NAME, computed_key); scram_ServerKey(salted_password, computed_key);
if (prep_password) if (prep_password)
pfree(prep_password); pfree(prep_password);
...@@ -630,12 +632,12 @@ mock_scram_verifier(const char *username, int *iterations, char **salt, ...@@ -630,12 +632,12 @@ mock_scram_verifier(const char *username, int *iterations, char **salt,
/* Generate deterministic salt */ /* Generate deterministic salt */
raw_salt = scram_MockSalt(username); raw_salt = scram_MockSalt(username);
encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_SALT_LEN) + 1); encoded_salt = (char *) palloc(pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN) + 1);
encoded_len = pg_b64_encode(raw_salt, SCRAM_SALT_LEN, encoded_salt); encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt);
encoded_salt[encoded_len] = '\0'; encoded_salt[encoded_len] = '\0';
*salt = encoded_salt; *salt = encoded_salt;
*iterations = SCRAM_ITERATIONS_DEFAULT; *iterations = SCRAM_DEFAULT_ITERATIONS;
/* StoredKey and ServerKey are not used in a doomed authentication */ /* StoredKey and ServerKey are not used in a doomed authentication */
memset(stored_key, 0, SCRAM_KEY_LEN); memset(stored_key, 0, SCRAM_KEY_LEN);
...@@ -1179,7 +1181,7 @@ build_server_final_message(scram_state *state) ...@@ -1179,7 +1181,7 @@ build_server_final_message(scram_state *state)
/* /*
* Determinisitcally generate salt for mock authentication, using a SHA256 * Determinisitcally generate salt for mock authentication, using a SHA256
* hash based on the username and a cluster-level secret key. Returns a * hash based on the username and a cluster-level secret key. Returns a
* pointer to a static buffer of size SCRAM_SALT_LEN. * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN.
*/ */
static char * static char *
scram_MockSalt(const char *username) scram_MockSalt(const char *username)
...@@ -1194,7 +1196,7 @@ scram_MockSalt(const char *username) ...@@ -1194,7 +1196,7 @@ scram_MockSalt(const char *username)
* not larger the SHA256 digest length. If the salt is smaller, the caller * not larger the SHA256 digest length. If the salt is smaller, the caller
* will just ignore the extra data)) * will just ignore the extra data))
*/ */
StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_SALT_LEN, StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
"salt length greater than SHA256 digest length"); "salt length greater than SHA256 digest length");
pg_sha256_init(&ctx); pg_sha256_init(&ctx);
......
...@@ -98,14 +98,16 @@ scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx) ...@@ -98,14 +98,16 @@ scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx)
} }
/* /*
* Iterate hash calculation of HMAC entry using given salt. * Calculate SaltedPassword.
* scram_Hi() is essentially PBKDF2 (see RFC2898) with HMAC() as the *
* pseudorandom function. * The password should already be normalized by SASLprep.
*/ */
static void void
scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *result) scram_SaltedPassword(const char *password,
const char *salt, int saltlen, int iterations,
uint8 *result)
{ {
int str_len = strlen(str); int password_len = strlen(password);
uint32 one = htonl(1); uint32 one = htonl(1);
int i, int i,
j; j;
...@@ -113,8 +115,14 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 * ...@@ -113,8 +115,14 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *
uint8 Ui_prev[SCRAM_KEY_LEN]; uint8 Ui_prev[SCRAM_KEY_LEN];
scram_HMAC_ctx hmac_ctx; scram_HMAC_ctx hmac_ctx;
/*
* Iterate hash calculation of HMAC entry using given salt. This is
* essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
* function.
*/
/* First iteration */ /* First iteration */
scram_HMAC_init(&hmac_ctx, (uint8 *) str, str_len); scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len);
scram_HMAC_update(&hmac_ctx, salt, saltlen); scram_HMAC_update(&hmac_ctx, salt, saltlen);
scram_HMAC_update(&hmac_ctx, (char *) &one, sizeof(uint32)); scram_HMAC_update(&hmac_ctx, (char *) &one, sizeof(uint32));
scram_HMAC_final(Ui_prev, &hmac_ctx); scram_HMAC_final(Ui_prev, &hmac_ctx);
...@@ -123,7 +131,7 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 * ...@@ -123,7 +131,7 @@ scram_Hi(const char *str, const char *salt, int saltlen, int iterations, uint8 *
/* Subsequent iterations */ /* Subsequent iterations */
for (i = 2; i <= iterations; i++) for (i = 2; i <= iterations; i++)
{ {
scram_HMAC_init(&hmac_ctx, (uint8 *) str, str_len); scram_HMAC_init(&hmac_ctx, (uint8 *) password, password_len);
scram_HMAC_update(&hmac_ctx, (const char *) Ui_prev, SCRAM_KEY_LEN); scram_HMAC_update(&hmac_ctx, (const char *) Ui_prev, SCRAM_KEY_LEN);
scram_HMAC_final(Ui, &hmac_ctx); scram_HMAC_final(Ui, &hmac_ctx);
for (j = 0; j < SCRAM_KEY_LEN; j++) for (j = 0; j < SCRAM_KEY_LEN; j++)
...@@ -148,20 +156,27 @@ scram_H(const uint8 *input, int len, uint8 *result) ...@@ -148,20 +156,27 @@ scram_H(const uint8 *input, int len, uint8 *result)
} }
/* /*
* Calculate ClientKey or ServerKey. * Calculate ClientKey.
*
* The password should already be normalized by SASLprep.
*/ */
void void
scram_ClientOrServerKey(const char *password, scram_ClientKey(const uint8 *salted_password, uint8 *result)
const char *salt, int saltlen, int iterations, {
const char *keystr, uint8 *result) scram_HMAC_ctx ctx;
scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN);
scram_HMAC_update(&ctx, "Client Key", strlen("Client Key"));
scram_HMAC_final(result, &ctx);
}
/*
* Calculate ServerKey.
*/
void
scram_ServerKey(const uint8 *salted_password, uint8 *result)
{ {
uint8 keybuf[SCRAM_KEY_LEN];
scram_HMAC_ctx ctx; scram_HMAC_ctx ctx;
scram_Hi(password, salt, saltlen, iterations, keybuf); scram_HMAC_init(&ctx, salted_password, SCRAM_KEY_LEN);
scram_HMAC_init(&ctx, keybuf, SCRAM_KEY_LEN); scram_HMAC_update(&ctx, "Server Key", strlen("Server Key"));
scram_HMAC_update(&ctx, keystr, strlen(keystr));
scram_HMAC_final(result, &ctx); scram_HMAC_final(result, &ctx);
} }
...@@ -29,14 +29,10 @@ ...@@ -29,14 +29,10 @@
#define SCRAM_RAW_NONCE_LEN 10 #define SCRAM_RAW_NONCE_LEN 10
/* length of salt when generating new verifiers */ /* length of salt when generating new verifiers */
#define SCRAM_SALT_LEN 10 #define SCRAM_DEFAULT_SALT_LEN 10
/* default number of iterations when generating verifier */ /* default number of iterations when generating verifier */
#define SCRAM_ITERATIONS_DEFAULT 4096 #define SCRAM_DEFAULT_ITERATIONS 4096
/* Base name of keys used for proof generation */
#define SCRAM_SERVER_KEY_NAME "Server Key"
#define SCRAM_CLIENT_KEY_NAME "Client Key"
/* /*
* Context data for HMAC used in SCRAM authentication. * Context data for HMAC used in SCRAM authentication.
...@@ -51,9 +47,10 @@ extern void scram_HMAC_init(scram_HMAC_ctx *ctx, const uint8 *key, int keylen); ...@@ -51,9 +47,10 @@ extern void scram_HMAC_init(scram_HMAC_ctx *ctx, const uint8 *key, int keylen);
extern void scram_HMAC_update(scram_HMAC_ctx *ctx, const char *str, int slen); extern void scram_HMAC_update(scram_HMAC_ctx *ctx, const char *str, int slen);
extern void scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx); extern void scram_HMAC_final(uint8 *result, scram_HMAC_ctx *ctx);
extern void scram_SaltedPassword(const char *password, const char *salt,
int saltlen, int iterations, uint8 *result);
extern void scram_H(const uint8 *str, int len, uint8 *result); extern void scram_H(const uint8 *str, int len, uint8 *result);
extern void scram_ClientOrServerKey(const char *password, const char *salt, extern void scram_ClientKey(const uint8 *salted_password, uint8 *result);
int saltlen, int iterations, extern void scram_ServerKey(const uint8 *salted_password, uint8 *result);
const char *keystr, uint8 *result);
#endif /* SCRAM_COMMON_H */ #endif /* SCRAM_COMMON_H */
...@@ -46,6 +46,7 @@ typedef struct ...@@ -46,6 +46,7 @@ typedef struct
char *password; char *password;
/* We construct these */ /* We construct these */
uint8 SaltedPassword[SCRAM_KEY_LEN];
char *client_nonce; char *client_nonce;
char *client_first_message_bare; char *client_first_message_bare;
char *client_final_message_without_proof; char *client_final_message_without_proof;
...@@ -59,7 +60,7 @@ typedef struct ...@@ -59,7 +60,7 @@ typedef struct
/* These come from the server-final message */ /* These come from the server-final message */
char *server_final_message; char *server_final_message;
char ServerProof[SCRAM_KEY_LEN]; char ServerSignature[SCRAM_KEY_LEN];
} fe_scram_state; } fe_scram_state;
static bool read_server_first_message(fe_scram_state *state, char *input, static bool read_server_first_message(fe_scram_state *state, char *input,
...@@ -70,7 +71,7 @@ static char *build_client_first_message(fe_scram_state *state, ...@@ -70,7 +71,7 @@ static char *build_client_first_message(fe_scram_state *state,
PQExpBuffer errormessage); PQExpBuffer errormessage);
static char *build_client_final_message(fe_scram_state *state, static char *build_client_final_message(fe_scram_state *state,
PQExpBuffer errormessage); PQExpBuffer errormessage);
static bool verify_server_proof(fe_scram_state *state); static bool verify_server_signature(fe_scram_state *state);
static void calculate_client_proof(fe_scram_state *state, static void calculate_client_proof(fe_scram_state *state,
const char *client_final_message_without_proof, const char *client_final_message_without_proof,
uint8 *result); uint8 *result);
...@@ -216,12 +217,12 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen, ...@@ -216,12 +217,12 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
goto error; goto error;
/* /*
* Verify server proof, to make sure we're talking to the genuine * Verify server signature, to make sure we're talking to the
* server. XXX: A fake server could simply not require * genuine server. XXX: A fake server could simply not require
* authentication, though. There is currently no option in libpq * authentication, though. There is currently no option in libpq
* to reject a connection, if SCRAM authentication did not happen. * to reject a connection, if SCRAM authentication did not happen.
*/ */
if (verify_server_proof(state)) if (verify_server_signature(state))
*success = true; *success = true;
else else
{ {
...@@ -486,12 +487,11 @@ read_server_first_message(fe_scram_state *state, char *input, ...@@ -486,12 +487,11 @@ read_server_first_message(fe_scram_state *state, char *input,
* Read the final exchange message coming from the server. * Read the final exchange message coming from the server.
*/ */
static bool static bool
read_server_final_message(fe_scram_state *state, read_server_final_message(fe_scram_state *state, char *input,
char *input,
PQExpBuffer errormessage) PQExpBuffer errormessage)
{ {
char *encoded_server_proof; char *encoded_server_signature;
int server_proof_len; int server_signature_len;
state->server_final_message = strdup(input); state->server_final_message = strdup(input);
if (!state->server_final_message) if (!state->server_final_message)
...@@ -513,8 +513,8 @@ read_server_final_message(fe_scram_state *state, ...@@ -513,8 +513,8 @@ read_server_final_message(fe_scram_state *state,
} }
/* Parse the message. */ /* Parse the message. */
encoded_server_proof = read_attr_value(&input, 'v', errormessage); encoded_server_signature = read_attr_value(&input, 'v', errormessage);
if (encoded_server_proof == NULL) if (encoded_server_signature == NULL)
{ {
/* read_attr_value() has generated an error message */ /* read_attr_value() has generated an error message */
return false; return false;
...@@ -524,13 +524,13 @@ read_server_final_message(fe_scram_state *state, ...@@ -524,13 +524,13 @@ read_server_final_message(fe_scram_state *state,
printfPQExpBuffer(errormessage, printfPQExpBuffer(errormessage,
libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n")); libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n"));
server_proof_len = pg_b64_decode(encoded_server_proof, server_signature_len = pg_b64_decode(encoded_server_signature,
strlen(encoded_server_proof), strlen(encoded_server_signature),
state->ServerProof); state->ServerSignature);
if (server_proof_len != SCRAM_KEY_LEN) if (server_signature_len != SCRAM_KEY_LEN)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(errormessage,
libpq_gettext("malformed SCRAM message (invalid server proof)\n")); libpq_gettext("malformed SCRAM message (invalid server signature)\n"));
return false; return false;
} }
...@@ -552,8 +552,14 @@ calculate_client_proof(fe_scram_state *state, ...@@ -552,8 +552,14 @@ calculate_client_proof(fe_scram_state *state,
int i; int i;
scram_HMAC_ctx ctx; scram_HMAC_ctx ctx;
scram_ClientOrServerKey(state->password, state->salt, state->saltlen, /*
state->iterations, SCRAM_CLIENT_KEY_NAME, ClientKey); * Calculate SaltedPassword, and store it in 'state' so that we can reuse
* it later in verify_server_signature.
*/
scram_SaltedPassword(state->password, state->salt, state->saltlen,
state->iterations, state->SaltedPassword);
scram_ClientKey(state->SaltedPassword, ClientKey);
scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey); scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey);
scram_HMAC_init(&ctx, StoredKey, SCRAM_KEY_LEN); scram_HMAC_init(&ctx, StoredKey, SCRAM_KEY_LEN);
...@@ -575,19 +581,17 @@ calculate_client_proof(fe_scram_state *state, ...@@ -575,19 +581,17 @@ calculate_client_proof(fe_scram_state *state,
} }
/* /*
* Validate the server proof, received as part of the final exchange message * Validate the server signature, received as part of the final exchange
* received from the server. * message received from the server.
*/ */
static bool static bool
verify_server_proof(fe_scram_state *state) verify_server_signature(fe_scram_state *state)
{ {
uint8 ServerSignature[SCRAM_KEY_LEN]; uint8 expected_ServerSignature[SCRAM_KEY_LEN];
uint8 ServerKey[SCRAM_KEY_LEN]; uint8 ServerKey[SCRAM_KEY_LEN];
scram_HMAC_ctx ctx; scram_HMAC_ctx ctx;
scram_ClientOrServerKey(state->password, state->salt, state->saltlen, scram_ServerKey(state->SaltedPassword, ServerKey);
state->iterations, SCRAM_SERVER_KEY_NAME,
ServerKey);
/* calculate ServerSignature */ /* calculate ServerSignature */
scram_HMAC_init(&ctx, ServerKey, SCRAM_KEY_LEN); scram_HMAC_init(&ctx, ServerKey, SCRAM_KEY_LEN);
...@@ -602,9 +606,9 @@ verify_server_proof(fe_scram_state *state) ...@@ -602,9 +606,9 @@ verify_server_proof(fe_scram_state *state)
scram_HMAC_update(&ctx, scram_HMAC_update(&ctx,
state->client_final_message_without_proof, state->client_final_message_without_proof,
strlen(state->client_final_message_without_proof)); strlen(state->client_final_message_without_proof));
scram_HMAC_final(ServerSignature, &ctx); scram_HMAC_final(expected_ServerSignature, &ctx);
if (memcmp(ServerSignature, state->ServerProof, SCRAM_KEY_LEN) != 0) if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0)
return false; return false;
return true; return true;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment