Commit f3049a60 authored by Peter Eisentraut's avatar Peter Eisentraut

Refactor channel binding code to fetch cbind_data only when necessary

As things stand now, channel binding data is fetched from OpenSSL and
saved into the SCRAM exchange context for any SSL connection attempted
for a SCRAM authentication, resulting in data fetched but not used if no
channel binding is used or if a different channel binding type is used
than what the data is here for.

Refactor the code in such a way that binding data is fetched from the
SSL stack only when a specific channel binding is used for both the
frontend and the backend.  In order to achieve that, save the libpq
connection context directly in the SCRAM exchange state, and add a
dependency to SSL in the low-level SCRAM routines.

This makes the interface in charge of initializing the SCRAM context
cleaner as all its data comes from either PGconn* (for frontend) or
Port* (for the backend).

Author: Michael Paquier <michael.paquier@gmail.com>
parent 3ad2afc2
...@@ -110,10 +110,8 @@ typedef struct ...@@ -110,10 +110,8 @@ typedef struct
const char *username; /* username from startup packet */ const char *username; /* username from startup packet */
Port *port;
char cbind_flag; char cbind_flag;
bool ssl_in_use;
const char *tls_finished_message;
size_t tls_finished_len;
char *channel_binding_type; char *channel_binding_type;
int iterations; int iterations;
...@@ -172,21 +170,15 @@ static char *scram_mock_salt(const char *username); ...@@ -172,21 +170,15 @@ static char *scram_mock_salt(const char *username);
* it will fail, as if an incorrect password was given. * it will fail, as if an incorrect password was given.
*/ */
void * void *
pg_be_scram_init(const char *username, pg_be_scram_init(Port *port,
const char *shadow_pass, const char *shadow_pass)
bool ssl_in_use,
const char *tls_finished_message,
size_t tls_finished_len)
{ {
scram_state *state; scram_state *state;
bool got_verifier; bool got_verifier;
state = (scram_state *) palloc0(sizeof(scram_state)); state = (scram_state *) palloc0(sizeof(scram_state));
state->port = port;
state->state = SCRAM_AUTH_INIT; state->state = SCRAM_AUTH_INIT;
state->username = username;
state->ssl_in_use = ssl_in_use;
state->tls_finished_message = tls_finished_message;
state->tls_finished_len = tls_finished_len;
state->channel_binding_type = NULL; state->channel_binding_type = NULL;
/* /*
...@@ -209,7 +201,7 @@ pg_be_scram_init(const char *username, ...@@ -209,7 +201,7 @@ pg_be_scram_init(const char *username,
*/ */
ereport(LOG, ereport(LOG,
(errmsg("invalid SCRAM verifier for user \"%s\"", (errmsg("invalid SCRAM verifier for user \"%s\"",
username))); state->port->user_name)));
got_verifier = false; got_verifier = false;
} }
} }
...@@ -220,7 +212,7 @@ pg_be_scram_init(const char *username, ...@@ -220,7 +212,7 @@ pg_be_scram_init(const char *username,
* authentication with an MD5 hash.) * authentication with an MD5 hash.)
*/ */
state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM verifier."), state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM verifier."),
state->username); state->port->user_name);
got_verifier = false; got_verifier = false;
} }
} }
...@@ -242,8 +234,8 @@ pg_be_scram_init(const char *username, ...@@ -242,8 +234,8 @@ pg_be_scram_init(const char *username,
*/ */
if (!got_verifier) if (!got_verifier)
{ {
mock_scram_verifier(username, &state->iterations, &state->salt, mock_scram_verifier(state->port->user_name, &state->iterations,
state->StoredKey, state->ServerKey); &state->salt, state->StoredKey, state->ServerKey);
state->doomed = true; state->doomed = true;
} }
...@@ -815,7 +807,7 @@ read_client_first_message(scram_state *state, char *input) ...@@ -815,7 +807,7 @@ read_client_first_message(scram_state *state, char *input)
* it supports channel binding, which in this implementation is * it supports channel binding, which in this implementation is
* the case if a connection is using SSL. * the case if a connection is using SSL.
*/ */
if (state->ssl_in_use) if (state->port->ssl_in_use)
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
errmsg("SCRAM channel binding negotiation error"), errmsg("SCRAM channel binding negotiation error"),
...@@ -839,7 +831,7 @@ read_client_first_message(scram_state *state, char *input) ...@@ -839,7 +831,7 @@ read_client_first_message(scram_state *state, char *input)
{ {
char *channel_binding_type; char *channel_binding_type;
if (!state->ssl_in_use) if (!state->port->ssl_in_use)
{ {
/* /*
* Without SSL, we don't support channel binding. * Without SSL, we don't support channel binding.
...@@ -1120,8 +1112,9 @@ read_client_final_message(scram_state *state, char *input) ...@@ -1120,8 +1112,9 @@ read_client_final_message(scram_state *state, char *input)
*/ */
if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0) if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
{ {
cbind_data = state->tls_finished_message; #ifdef USE_SSL
cbind_data_len = state->tls_finished_len; cbind_data = be_tls_get_peer_finished(state->port, &cbind_data_len);
#endif
} }
else else
{ {
......
...@@ -873,8 +873,6 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) ...@@ -873,8 +873,6 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
int inputlen; int inputlen;
int result; int result;
bool initial; bool initial;
char *tls_finished = NULL;
size_t tls_finished_len = 0;
/* /*
* SASL auth is not supported for protocol versions before 3, because it * SASL auth is not supported for protocol versions before 3, because it
...@@ -915,17 +913,6 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) ...@@ -915,17 +913,6 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs, p - sasl_mechs + 1); sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs, p - sasl_mechs + 1);
pfree(sasl_mechs); pfree(sasl_mechs);
#ifdef USE_SSL
/*
* Get data for channel binding.
*/
if (port->ssl_in_use)
{
tls_finished = be_tls_get_peer_finished(port, &tls_finished_len);
}
#endif
/* /*
* Initialize the status tracker for message exchanges. * Initialize the status tracker for message exchanges.
* *
...@@ -937,11 +924,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) ...@@ -937,11 +924,7 @@ CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
* This is because we don't want to reveal to an attacker what usernames * This is because we don't want to reveal to an attacker what usernames
* are valid, nor which users have a valid password. * are valid, nor which users have a valid password.
*/ */
scram_opaq = pg_be_scram_init(port->user_name, scram_opaq = pg_be_scram_init(port, shadow_pass);
shadow_pass,
port->ssl_in_use,
tls_finished,
tls_finished_len);
/* /*
* Loop through SASL message exchange. This exchange can consist of * Loop through SASL message exchange. This exchange can consist of
......
...@@ -13,15 +13,15 @@ ...@@ -13,15 +13,15 @@
#ifndef PG_SCRAM_H #ifndef PG_SCRAM_H
#define PG_SCRAM_H #define PG_SCRAM_H
#include "libpq/libpq-be.h"
/* Status codes for message exchange */ /* Status codes for message exchange */
#define SASL_EXCHANGE_CONTINUE 0 #define SASL_EXCHANGE_CONTINUE 0
#define SASL_EXCHANGE_SUCCESS 1 #define SASL_EXCHANGE_SUCCESS 1
#define SASL_EXCHANGE_FAILURE 2 #define SASL_EXCHANGE_FAILURE 2
/* Routines dedicated to authentication */ /* Routines dedicated to authentication */
extern void *pg_be_scram_init(const char *username, const char *shadow_pass, extern void *pg_be_scram_init(Port *port, const char *shadow_pass);
bool ssl_in_use, const char *tls_finished_message,
size_t tls_finished_len);
extern int pg_be_scram_exchange(void *opaq, char *input, int inputlen, extern int pg_be_scram_exchange(void *opaq, char *input, int inputlen,
char **output, int *outputlen, char **logdetail); char **output, int *outputlen, char **logdetail);
......
...@@ -42,13 +42,9 @@ typedef struct ...@@ -42,13 +42,9 @@ typedef struct
fe_scram_state_enum state; fe_scram_state_enum state;
/* These are supplied by the user */ /* These are supplied by the user */
const char *username; PGconn *conn;
char *password; char *password;
bool ssl_in_use;
char *tls_finished_message;
size_t tls_finished_len;
char *sasl_mechanism; char *sasl_mechanism;
const char *channel_binding_type;
/* We construct these */ /* We construct these */
uint8 SaltedPassword[SCRAM_KEY_LEN]; uint8 SaltedPassword[SCRAM_KEY_LEN];
...@@ -68,14 +64,10 @@ typedef struct ...@@ -68,14 +64,10 @@ typedef struct
char ServerSignature[SCRAM_KEY_LEN]; char ServerSignature[SCRAM_KEY_LEN];
} fe_scram_state; } fe_scram_state;
static bool read_server_first_message(fe_scram_state *state, char *input, static bool read_server_first_message(fe_scram_state *state, char *input);
PQExpBuffer errormessage); static bool read_server_final_message(fe_scram_state *state, char *input);
static bool read_server_final_message(fe_scram_state *state, char *input, static char *build_client_first_message(fe_scram_state *state);
PQExpBuffer errormessage); static char *build_client_final_message(fe_scram_state *state);
static char *build_client_first_message(fe_scram_state *state,
PQExpBuffer errormessage);
static char *build_client_final_message(fe_scram_state *state,
PQExpBuffer errormessage);
static bool verify_server_signature(fe_scram_state *state); static bool verify_server_signature(fe_scram_state *state);
static void calculate_client_proof(fe_scram_state *state, static void calculate_client_proof(fe_scram_state *state,
const char *client_final_message_without_proof, const char *client_final_message_without_proof,
...@@ -84,18 +76,11 @@ static bool pg_frontend_random(char *dst, int len); ...@@ -84,18 +76,11 @@ static bool pg_frontend_random(char *dst, int len);
/* /*
* Initialize SCRAM exchange status. * Initialize SCRAM exchange status.
*
* The non-const char* arguments should be passed in malloc'ed. They will be
* freed by pg_fe_scram_free().
*/ */
void * void *
pg_fe_scram_init(const char *username, pg_fe_scram_init(PGconn *conn,
const char *password, const char *password,
bool ssl_in_use, const char *sasl_mechanism)
const char *sasl_mechanism,
const char *channel_binding_type,
char *tls_finished_message,
size_t tls_finished_len)
{ {
fe_scram_state *state; fe_scram_state *state;
char *prep_password; char *prep_password;
...@@ -107,13 +92,9 @@ pg_fe_scram_init(const char *username, ...@@ -107,13 +92,9 @@ pg_fe_scram_init(const char *username,
if (!state) if (!state)
return NULL; return NULL;
memset(state, 0, sizeof(fe_scram_state)); memset(state, 0, sizeof(fe_scram_state));
state->conn = conn;
state->state = FE_SCRAM_INIT; state->state = FE_SCRAM_INIT;
state->username = username;
state->ssl_in_use = ssl_in_use;
state->tls_finished_message = tls_finished_message;
state->tls_finished_len = tls_finished_len;
state->sasl_mechanism = strdup(sasl_mechanism); state->sasl_mechanism = strdup(sasl_mechanism);
state->channel_binding_type = channel_binding_type;
if (!state->sasl_mechanism) if (!state->sasl_mechanism)
{ {
...@@ -154,8 +135,6 @@ pg_fe_scram_free(void *opaq) ...@@ -154,8 +135,6 @@ pg_fe_scram_free(void *opaq)
if (state->password) if (state->password)
free(state->password); free(state->password);
if (state->tls_finished_message)
free(state->tls_finished_message);
if (state->sasl_mechanism) if (state->sasl_mechanism)
free(state->sasl_mechanism); free(state->sasl_mechanism);
...@@ -188,9 +167,10 @@ pg_fe_scram_free(void *opaq) ...@@ -188,9 +167,10 @@ pg_fe_scram_free(void *opaq)
void void
pg_fe_scram_exchange(void *opaq, char *input, int inputlen, pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
char **output, int *outputlen, char **output, int *outputlen,
bool *done, bool *success, PQExpBuffer errorMessage) bool *done, bool *success)
{ {
fe_scram_state *state = (fe_scram_state *) opaq; fe_scram_state *state = (fe_scram_state *) opaq;
PGconn *conn = state->conn;
*done = false; *done = false;
*success = false; *success = false;
...@@ -205,13 +185,13 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen, ...@@ -205,13 +185,13 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
{ {
if (inputlen == 0) if (inputlen == 0)
{ {
printfPQExpBuffer(errorMessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (empty message)\n")); libpq_gettext("malformed SCRAM message (empty message)\n"));
goto error; goto error;
} }
if (inputlen != strlen(input)) if (inputlen != strlen(input))
{ {
printfPQExpBuffer(errorMessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (length mismatch)\n")); libpq_gettext("malformed SCRAM message (length mismatch)\n"));
goto error; goto error;
} }
...@@ -221,7 +201,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen, ...@@ -221,7 +201,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
{ {
case FE_SCRAM_INIT: case FE_SCRAM_INIT:
/* Begin the SCRAM handshake, by sending client nonce */ /* Begin the SCRAM handshake, by sending client nonce */
*output = build_client_first_message(state, errorMessage); *output = build_client_first_message(state);
if (*output == NULL) if (*output == NULL)
goto error; goto error;
...@@ -232,10 +212,10 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen, ...@@ -232,10 +212,10 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
case FE_SCRAM_NONCE_SENT: case FE_SCRAM_NONCE_SENT:
/* Receive salt and server nonce, send response. */ /* Receive salt and server nonce, send response. */
if (!read_server_first_message(state, input, errorMessage)) if (!read_server_first_message(state, input))
goto error; goto error;
*output = build_client_final_message(state, errorMessage); *output = build_client_final_message(state);
if (*output == NULL) if (*output == NULL)
goto error; goto error;
...@@ -246,7 +226,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen, ...@@ -246,7 +226,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
case FE_SCRAM_PROOF_SENT: case FE_SCRAM_PROOF_SENT:
/* Receive server signature */ /* Receive server signature */
if (!read_server_final_message(state, input, errorMessage)) if (!read_server_final_message(state, input))
goto error; goto error;
/* /*
...@@ -260,7 +240,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen, ...@@ -260,7 +240,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
else else
{ {
*success = false; *success = false;
printfPQExpBuffer(errorMessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("incorrect server signature\n")); libpq_gettext("incorrect server signature\n"));
} }
*done = true; *done = true;
...@@ -269,7 +249,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen, ...@@ -269,7 +249,7 @@ pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
default: default:
/* shouldn't happen */ /* shouldn't happen */
printfPQExpBuffer(errorMessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("invalid SCRAM exchange state\n")); libpq_gettext("invalid SCRAM exchange state\n"));
goto error; goto error;
} }
...@@ -327,8 +307,9 @@ read_attr_value(char **input, char attr, PQExpBuffer errorMessage) ...@@ -327,8 +307,9 @@ read_attr_value(char **input, char attr, PQExpBuffer errorMessage)
* Build the first exchange message sent by the client. * Build the first exchange message sent by the client.
*/ */
static char * static char *
build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage) build_client_first_message(fe_scram_state *state)
{ {
PGconn *conn = state->conn;
char raw_nonce[SCRAM_RAW_NONCE_LEN + 1]; char raw_nonce[SCRAM_RAW_NONCE_LEN + 1];
char *result; char *result;
int channel_info_len; int channel_info_len;
...@@ -341,7 +322,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage) ...@@ -341,7 +322,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
*/ */
if (!pg_frontend_random(raw_nonce, SCRAM_RAW_NONCE_LEN)) if (!pg_frontend_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("could not generate nonce\n")); libpq_gettext("could not generate nonce\n"));
return NULL; return NULL;
} }
...@@ -349,7 +330,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage) ...@@ -349,7 +330,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1); state->client_nonce = malloc(pg_b64_enc_len(SCRAM_RAW_NONCE_LEN) + 1);
if (state->client_nonce == NULL) if (state->client_nonce == NULL)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n")); libpq_gettext("out of memory\n"));
return NULL; return NULL;
} }
...@@ -370,11 +351,11 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage) ...@@ -370,11 +351,11 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
*/ */
if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0) if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
{ {
Assert(state->ssl_in_use); Assert(conn->ssl_in_use);
appendPQExpBuffer(&buf, "p=%s", state->channel_binding_type); appendPQExpBuffer(&buf, "p=%s", conn->scram_channel_binding);
} }
else if (state->channel_binding_type == NULL || else if (conn->scram_channel_binding == NULL ||
strlen(state->channel_binding_type) == 0) strlen(conn->scram_channel_binding) == 0)
{ {
/* /*
* Client has chosen to not show to server that it supports channel * Client has chosen to not show to server that it supports channel
...@@ -382,7 +363,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage) ...@@ -382,7 +363,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
*/ */
appendPQExpBuffer(&buf, "n"); appendPQExpBuffer(&buf, "n");
} }
else if (state->ssl_in_use) else if (conn->ssl_in_use)
{ {
/* /*
* Client supports channel binding, but thinks the server does not. * Client supports channel binding, but thinks the server does not.
...@@ -423,7 +404,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage) ...@@ -423,7 +404,7 @@ build_client_first_message(fe_scram_state *state, PQExpBuffer errormessage)
oom_error: oom_error:
termPQExpBuffer(&buf); termPQExpBuffer(&buf);
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n")); libpq_gettext("out of memory\n"));
return NULL; return NULL;
} }
...@@ -432,9 +413,10 @@ oom_error: ...@@ -432,9 +413,10 @@ oom_error:
* Build the final exchange message sent from the client. * Build the final exchange message sent from the client.
*/ */
static char * static char *
build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage) build_client_final_message(fe_scram_state *state)
{ {
PQExpBufferData buf; PQExpBufferData buf;
PGconn *conn = state->conn;
uint8 client_proof[SCRAM_KEY_LEN]; uint8 client_proof[SCRAM_KEY_LEN];
char *result; char *result;
...@@ -450,22 +432,25 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage) ...@@ -450,22 +432,25 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
*/ */
if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0) if (strcmp(state->sasl_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
{ {
char *cbind_data; char *cbind_data = NULL;
size_t cbind_data_len; size_t cbind_data_len = 0;
size_t cbind_header_len; size_t cbind_header_len;
char *cbind_input; char *cbind_input;
size_t cbind_input_len; size_t cbind_input_len;
if (strcmp(state->channel_binding_type, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0) if (strcmp(conn->scram_channel_binding, SCRAM_CHANNEL_BINDING_TLS_UNIQUE) == 0)
{ {
cbind_data = state->tls_finished_message; #ifdef USE_SSL
cbind_data_len = state->tls_finished_len; cbind_data = pgtls_get_finished(state->conn, &cbind_data_len);
if (cbind_data == NULL)
goto oom_error;
#endif
} }
else else
{ {
/* should not happen */ /* should not happen */
termPQExpBuffer(&buf); termPQExpBuffer(&buf);
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("invalid channel binding type\n")); libpq_gettext("invalid channel binding type\n"));
return NULL; return NULL;
} }
...@@ -473,37 +458,46 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage) ...@@ -473,37 +458,46 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
/* should not happen */ /* should not happen */
if (cbind_data == NULL || cbind_data_len == 0) if (cbind_data == NULL || cbind_data_len == 0)
{ {
if (cbind_data != NULL)
free(cbind_data);
termPQExpBuffer(&buf); termPQExpBuffer(&buf);
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("empty channel binding data for channel binding type \"%s\"\n"), libpq_gettext("empty channel binding data for channel binding type \"%s\"\n"),
state->channel_binding_type); conn->scram_channel_binding);
return NULL; return NULL;
} }
appendPQExpBuffer(&buf, "c="); appendPQExpBuffer(&buf, "c=");
cbind_header_len = 4 + strlen(state->channel_binding_type); /* p=type,, */ /* p=type,, */
cbind_header_len = 4 + strlen(conn->scram_channel_binding);
cbind_input_len = cbind_header_len + cbind_data_len; cbind_input_len = cbind_header_len + cbind_data_len;
cbind_input = malloc(cbind_input_len); cbind_input = malloc(cbind_input_len);
if (!cbind_input) if (!cbind_input)
{
free(cbind_data);
goto oom_error; goto oom_error;
snprintf(cbind_input, cbind_input_len, "p=%s,,", state->channel_binding_type); }
snprintf(cbind_input, cbind_input_len, "p=%s,,",
conn->scram_channel_binding);
memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len); memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len))) if (!enlargePQExpBuffer(&buf, pg_b64_enc_len(cbind_input_len)))
{ {
free(cbind_data);
free(cbind_input); free(cbind_input);
goto oom_error; goto oom_error;
} }
buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len); buf.len += pg_b64_encode(cbind_input, cbind_input_len, buf.data + buf.len);
buf.data[buf.len] = '\0'; buf.data[buf.len] = '\0';
free(cbind_data);
free(cbind_input); free(cbind_input);
} }
else if (state->channel_binding_type == NULL || else if (conn->scram_channel_binding == NULL ||
strlen(state->channel_binding_type) == 0) strlen(conn->scram_channel_binding) == 0)
appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */ appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */
else if (state->ssl_in_use) else if (conn->ssl_in_use)
appendPQExpBuffer(&buf, "c=eSws"); /* base64 of "y,," */ appendPQExpBuffer(&buf, "c=eSws"); /* base64 of "y,," */
else else
appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */ appendPQExpBuffer(&buf, "c=biws"); /* base64 of "n,," */
...@@ -541,7 +535,7 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage) ...@@ -541,7 +535,7 @@ build_client_final_message(fe_scram_state *state, PQExpBuffer errormessage)
oom_error: oom_error:
termPQExpBuffer(&buf); termPQExpBuffer(&buf);
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n")); libpq_gettext("out of memory\n"));
return NULL; return NULL;
} }
...@@ -550,9 +544,9 @@ oom_error: ...@@ -550,9 +544,9 @@ oom_error:
* Read the first exchange message coming from the server. * Read the first exchange message coming from the server.
*/ */
static bool static bool
read_server_first_message(fe_scram_state *state, char *input, read_server_first_message(fe_scram_state *state, char *input)
PQExpBuffer errormessage)
{ {
PGconn *conn = state->conn;
char *iterations_str; char *iterations_str;
char *endptr; char *endptr;
char *encoded_salt; char *encoded_salt;
...@@ -561,13 +555,14 @@ read_server_first_message(fe_scram_state *state, char *input, ...@@ -561,13 +555,14 @@ read_server_first_message(fe_scram_state *state, char *input,
state->server_first_message = strdup(input); state->server_first_message = strdup(input);
if (state->server_first_message == NULL) if (state->server_first_message == NULL)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n")); libpq_gettext("out of memory\n"));
return false; return false;
} }
/* parse the message */ /* parse the message */
nonce = read_attr_value(&input, 'r', errormessage); nonce = read_attr_value(&input, 'r',
&conn->errorMessage);
if (nonce == NULL) if (nonce == NULL)
{ {
/* read_attr_value() has generated an error string */ /* read_attr_value() has generated an error string */
...@@ -578,7 +573,7 @@ read_server_first_message(fe_scram_state *state, char *input, ...@@ -578,7 +573,7 @@ read_server_first_message(fe_scram_state *state, char *input,
if (strlen(nonce) < strlen(state->client_nonce) || if (strlen(nonce) < strlen(state->client_nonce) ||
memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0) memcmp(nonce, state->client_nonce, strlen(state->client_nonce)) != 0)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("invalid SCRAM response (nonce mismatch)\n")); libpq_gettext("invalid SCRAM response (nonce mismatch)\n"));
return false; return false;
} }
...@@ -586,12 +581,12 @@ read_server_first_message(fe_scram_state *state, char *input, ...@@ -586,12 +581,12 @@ read_server_first_message(fe_scram_state *state, char *input,
state->nonce = strdup(nonce); state->nonce = strdup(nonce);
if (state->nonce == NULL) if (state->nonce == NULL)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n")); libpq_gettext("out of memory\n"));
return false; return false;
} }
encoded_salt = read_attr_value(&input, 's', errormessage); encoded_salt = read_attr_value(&input, 's', &conn->errorMessage);
if (encoded_salt == NULL) if (encoded_salt == NULL)
{ {
/* read_attr_value() has generated an error string */ /* read_attr_value() has generated an error string */
...@@ -600,7 +595,7 @@ read_server_first_message(fe_scram_state *state, char *input, ...@@ -600,7 +595,7 @@ read_server_first_message(fe_scram_state *state, char *input,
state->salt = malloc(pg_b64_dec_len(strlen(encoded_salt))); state->salt = malloc(pg_b64_dec_len(strlen(encoded_salt)));
if (state->salt == NULL) if (state->salt == NULL)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n")); libpq_gettext("out of memory\n"));
return false; return false;
} }
...@@ -608,7 +603,7 @@ read_server_first_message(fe_scram_state *state, char *input, ...@@ -608,7 +603,7 @@ read_server_first_message(fe_scram_state *state, char *input,
strlen(encoded_salt), strlen(encoded_salt),
state->salt); state->salt);
iterations_str = read_attr_value(&input, 'i', errormessage); iterations_str = read_attr_value(&input, 'i', &conn->errorMessage);
if (iterations_str == NULL) if (iterations_str == NULL)
{ {
/* read_attr_value() has generated an error string */ /* read_attr_value() has generated an error string */
...@@ -617,13 +612,13 @@ read_server_first_message(fe_scram_state *state, char *input, ...@@ -617,13 +612,13 @@ read_server_first_message(fe_scram_state *state, char *input,
state->iterations = strtol(iterations_str, &endptr, 10); state->iterations = strtol(iterations_str, &endptr, 10);
if (*endptr != '\0' || state->iterations < 1) if (*endptr != '\0' || state->iterations < 1)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (invalid iteration count)\n")); libpq_gettext("malformed SCRAM message (invalid iteration count)\n"));
return false; return false;
} }
if (*input != '\0') if (*input != '\0')
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (garbage at end of server-first-message)\n")); libpq_gettext("malformed SCRAM message (garbage at end of server-first-message)\n"));
return true; return true;
...@@ -633,16 +628,16 @@ read_server_first_message(fe_scram_state *state, char *input, ...@@ -633,16 +628,16 @@ read_server_first_message(fe_scram_state *state, char *input,
* Read the final exchange message coming from the server. * Read the final exchange message coming from the server.
*/ */
static bool static bool
read_server_final_message(fe_scram_state *state, char *input, read_server_final_message(fe_scram_state *state, char *input)
PQExpBuffer errormessage)
{ {
PGconn *conn = state->conn;
char *encoded_server_signature; char *encoded_server_signature;
int server_signature_len; int server_signature_len;
state->server_final_message = strdup(input); state->server_final_message = strdup(input);
if (!state->server_final_message) if (!state->server_final_message)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("out of memory\n")); libpq_gettext("out of memory\n"));
return false; return false;
} }
...@@ -650,16 +645,18 @@ read_server_final_message(fe_scram_state *state, char *input, ...@@ -650,16 +645,18 @@ read_server_final_message(fe_scram_state *state, char *input,
/* Check for error result. */ /* Check for error result. */
if (*input == 'e') if (*input == 'e')
{ {
char *errmsg = read_attr_value(&input, 'e', errormessage); char *errmsg = read_attr_value(&input, 'e',
&conn->errorMessage);
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("error received from server in SCRAM exchange: %s\n"), libpq_gettext("error received from server in SCRAM exchange: %s\n"),
errmsg); errmsg);
return false; return false;
} }
/* Parse the message. */ /* Parse the message. */
encoded_server_signature = read_attr_value(&input, 'v', errormessage); encoded_server_signature = read_attr_value(&input, 'v',
&conn->errorMessage);
if (encoded_server_signature == NULL) if (encoded_server_signature == NULL)
{ {
/* read_attr_value() has generated an error message */ /* read_attr_value() has generated an error message */
...@@ -667,7 +664,7 @@ read_server_final_message(fe_scram_state *state, char *input, ...@@ -667,7 +664,7 @@ read_server_final_message(fe_scram_state *state, char *input,
} }
if (*input != '\0') if (*input != '\0')
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n")); libpq_gettext("malformed SCRAM message (garbage at end of server-final-message)\n"));
server_signature_len = pg_b64_decode(encoded_server_signature, server_signature_len = pg_b64_decode(encoded_server_signature,
...@@ -675,7 +672,7 @@ read_server_final_message(fe_scram_state *state, char *input, ...@@ -675,7 +672,7 @@ read_server_final_message(fe_scram_state *state, char *input,
state->ServerSignature); state->ServerSignature);
if (server_signature_len != SCRAM_KEY_LEN) if (server_signature_len != SCRAM_KEY_LEN)
{ {
printfPQExpBuffer(errormessage, printfPQExpBuffer(&conn->errorMessage,
libpq_gettext("malformed SCRAM message (invalid server signature)\n")); libpq_gettext("malformed SCRAM message (invalid server signature)\n"));
return false; return false;
} }
......
...@@ -491,8 +491,6 @@ pg_SASL_init(PGconn *conn, int payloadlen) ...@@ -491,8 +491,6 @@ pg_SASL_init(PGconn *conn, int payloadlen)
bool success; bool success;
const char *selected_mechanism; const char *selected_mechanism;
PQExpBufferData mechanism_buf; PQExpBufferData mechanism_buf;
char *tls_finished = NULL;
size_t tls_finished_len = 0;
char *password; char *password;
initPQExpBuffer(&mechanism_buf); initPQExpBuffer(&mechanism_buf);
...@@ -570,32 +568,15 @@ pg_SASL_init(PGconn *conn, int payloadlen) ...@@ -570,32 +568,15 @@ pg_SASL_init(PGconn *conn, int payloadlen)
goto error; goto error;
} }
#ifdef USE_SSL
/*
* Get data for channel binding.
*/
if (strcmp(selected_mechanism, SCRAM_SHA256_PLUS_NAME) == 0)
{
tls_finished = pgtls_get_finished(conn, &tls_finished_len);
if (tls_finished == NULL)
goto oom_error;
}
#endif
/* /*
* Initialize the SASL state information with all the information gathered * Initialize the SASL state information with all the information gathered
* during the initial exchange. * during the initial exchange.
* *
* Note: Only tls-unique is supported for the moment. * Note: Only tls-unique is supported for the moment.
*/ */
conn->sasl_state = pg_fe_scram_init(conn->pguser, conn->sasl_state = pg_fe_scram_init(conn,
password, password,
conn->ssl_in_use, selected_mechanism);
selected_mechanism,
conn->scram_channel_binding,
tls_finished,
tls_finished_len);
if (!conn->sasl_state) if (!conn->sasl_state)
goto oom_error; goto oom_error;
...@@ -603,7 +584,7 @@ pg_SASL_init(PGconn *conn, int payloadlen) ...@@ -603,7 +584,7 @@ pg_SASL_init(PGconn *conn, int payloadlen)
pg_fe_scram_exchange(conn->sasl_state, pg_fe_scram_exchange(conn->sasl_state,
NULL, -1, NULL, -1,
&initialresponse, &initialresponselen, &initialresponse, &initialresponselen,
&done, &success, &conn->errorMessage); &done, &success);
if (done && !success) if (done && !success)
goto error; goto error;
...@@ -684,7 +665,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final) ...@@ -684,7 +665,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
pg_fe_scram_exchange(conn->sasl_state, pg_fe_scram_exchange(conn->sasl_state,
challenge, payloadlen, challenge, payloadlen,
&output, &outputlen, &output, &outputlen,
&done, &success, &conn->errorMessage); &done, &success);
free(challenge); /* don't need the input anymore */ free(challenge); /* don't need the input anymore */
if (final && !done) if (final && !done)
......
...@@ -23,17 +23,13 @@ extern int pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn); ...@@ -23,17 +23,13 @@ extern int pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn);
extern char *pg_fe_getauthname(PQExpBuffer errorMessage); extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
/* Prototypes for functions in fe-auth-scram.c */ /* Prototypes for functions in fe-auth-scram.c */
extern void *pg_fe_scram_init(const char *username, extern void *pg_fe_scram_init(PGconn *conn,
const char *password, const char *password,
bool ssl_in_use, const char *sasl_mechanism);
const char *sasl_mechanism,
const char *channel_binding_type,
char *tls_finished_message,
size_t tls_finished_len);
extern void pg_fe_scram_free(void *opaq); extern void pg_fe_scram_free(void *opaq);
extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen, extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
char **output, int *outputlen, char **output, int *outputlen,
bool *done, bool *success, PQExpBuffer errorMessage); bool *done, bool *success);
extern char *pg_fe_scram_build_verifier(const char *password); extern char *pg_fe_scram_build_verifier(const char *password);
#endif /* FE_AUTH_H */ #endif /* FE_AUTH_H */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment