8000 Better error handling, remove spurious checks, match postgres coding … · gurjeet/postgres@226839c · GitHub
[go: up one dir, main page]

Skip to content

Commit 226839c

Browse files
committed
Better error handling, remove spurious checks, match postgres coding conventions, etc.
1 parent 4ba609d commit 226839c

File tree

10 files changed

+80
-57
lines changed

10 files changed

+80
-57
lines changed

src/backend/commands/user.c

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,13 @@ validate_and_get_salt(char *rolename, char **salt, const char **logdetail)
158158
*salt = NULL; /* No existing passwords, allow one to be generated */
159159
return true;
160160
}
161-
for (i = 0; i < num; i++) {
161+
for (i = 0; i < num; i++)
162+
{
162163
passtype = get_password_type(current_secrets[i]);
163164
if (passtype == PASSWORD_TYPE_MD5 || passtype == PASSWORD_TYPE_PLAINTEXT)
164165
continue; /* md5 uses rolename as salt so it is always the same and plaintext has no salt */
165-
else if (passtype == PASSWORD_TYPE_SCRAM_SHA_256) {
166+
else if (passtype == PASSWORD_TYPE_SCRAM_SHA_256)
167+
{
166168
int iterations;
167169
int key_length = 0;
168170
pg_cryptohash_type hash_type;
@@ -172,8 +174,10 @@ validate_and_get_salt(char *rolename, char **salt, const char **logdetail)
172174
parse_scram_secret(current_secrets[i], &iterations, &hash_type, &key_length,
173175
&salt1, stored_key, server_key);
174176

175-
if (salt2 != NULL) {
176-
if (strcmp(salt1, salt2)) {
177+
if (salt2 != NULL)
178+
{
179+
if (strcmp(salt1, salt2))
180+
{
177181
*logdetail = psprintf(_("inconsistent salts, clearing password"));
178182
*salt = NULL;
179183
return false;
@@ -1074,7 +1078,8 @@ AlterRole(ParseState *pstate, AlterRoleStmt *stmt)
10741078
*/
10751079

10761080
passExpiresIn_datum = expires_in_datum(dpassExpiresIn);
1077-
if (passExpiresIn_datum != PointerGetDatum(NULL)) {
1081+
if (passExpiresIn_datum != PointerGetDatum(NULL))
1082+
{
10781083
new_password_record[Anum_pg_auth_password_expiration - 1] = passExpiresIn_datum;
10791084
new_password_record_repl[Anum_pg_auth_password_expiration - 1] = true;
10801085
}
@@ -1214,7 +1219,8 @@ AlterRole(ParseState *pstate, AlterRoleStmt *stmt)
12141219

12151220
if (new_password_record_nulls[Anum_pg_auth_password_password - 1] == true) /* delete existing password */
12161221
{
1217-
if (HeapTupleIsValid(password_tuple)) {
1222+
if (HeapTupleIsValid(password_tuple))
1223+
{
12181224
CatalogTupleDelete(pg_auth_password_rel, &password_tuple->t_self);
12191225
ReleaseSysCache(password_tuple);
12201226
}
@@ -1779,7 +1785,8 @@ RenameRole(const char *oldname, const char *newname)
17791785

17801786
/* MD5 uses the username as salt, so just clear it on a rename */
17811787

1782-
if (HeapTupleIsValid(passtuple)) {
1788+
if (HeapTupleIsValid(passtuple))
1789+
{
17831790
CatalogTupleDelete(pg_auth_password_rel, &passtuple->t_self);
17841791
ereport(NOTICE,
17851792
(errmsg("MD5 password cleared because of role rename")));

src/backend/commands/variable.c

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,17 +1249,17 @@ check_password_duration(char **newval, void **extra, GucSource source)
12491249
char *val;
12501250
Interval *interval;
12511251

1252-
if (newval == NULL || *newval == NULL) {
1253-
extra = NULL;
1252+
if (*newval == NULL)
12541253
return true;
1255-
}
12561254

1257-
elog(NOTICE,"Setting password duration to \"%s\"",
1258-
*newval);
1255+
/* TODO: Change to DEBUG2 */
1256+
elog(NOTICE,"Setting password duration to \"%s\"", *newval);
12591257

12601258
while (isspace((unsigned char) *valueptr))
12611259
valueptr++;
1262-
if (*valueptr != '\'') {
1260+
1261+
if (*valueptr != '\'')
1262+
{
12631263
val = pstrdup(valueptr);
12641264
}
12651265
else
@@ -1289,24 +1289,20 @@ check_password_duration(char **newval, void **extra, GucSource source)
12891289

12901290
pfree(val);
12911291

1292-
if (!interval) {
1292+
if (!interval)
12931293
return false;
1294-
}
12951294

12961295
new_interval = guc_malloc(LOG, sizeof(Interval));
12971296
memcpy(new_interval, interval, sizeof(Interval));
12981297
pfree(interval);
12991298

1300-
/*
1301-
* Pass back data for assign_password_validity to use
1302-
*/
13031299
*extra = (void*) new_interval;
13041300

13051301
return true;
13061302
}
13071303

13081304
/*
1309-
* assign_password_validity: GUC assign_hook for timezone
1305+
* assign_password_validity: GUC assign_hook for password_duration
13101306
*/
13111307
void
13121308
assign_password_duration(const char *newval, void *extra)
@@ -1318,15 +1314,16 @@ assign_password_duration(const char *newval, void *extra)
13181314
}
13191315

13201316
/*
1321-
* show_password_validity: GUC show_hook for timezone
1317+
* show_password_validity: GUC show_hook for password_duration
13221318
*/
13231319
const char *
13241320
show_password_duration(void)
13251321
{
13261322
const char *intervalout;
1327-
if (default_password_duration == NULL) {
1323+
1324+
if (default_password_duration == NULL)
13281325
return "";
1329-
}
1326+
13301327
intervalout = DatumGetCString(DirectFunctionCall1(interval_out,
13311328
PointerGetDatum(default_password_duration)));
13321329

src/backend/libpq/auth-sasl.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
* should just pass NULL.
5050
*/
5151
int
52-
CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, const char **passwords, int num,
52+
CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, const char **passwords, int num_passwords,
5353
const char **logdetail)
5454
{
5555
StringInfoData sasl_mechs;
@@ -136,7 +136,7 @@ CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, const char **passwords, i
136136
* This is because we don't want to reveal to an attacker what
137137
* usernames are valid, nor which users have a valid password.
138138
*/
139-
opaq = mech->init(port, selected_mech, passwords, num);
139+
opaq = mech->init(port, selected_mech, passwords, num_passwords);
140140

141141
inputlen = pq_getmsgint(&buf, 4);
142142
if (inputlen == -1)

src/backend/libpq/auth-scram.c

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ typedef struct
138138
uint8 ServerKey[SCRAM_MAX_KEY_LEN];
139139
} scram_secret;
140140

141-
142141
typedef struct
143142
{
144143
scram_state_enum state;
@@ -302,12 +301,16 @@ scram_init(Port *port, const char *selected_mech, const char **secrets, const in
302301
state->secrets[i].StoredKey,
303302
state->secrets[i].ServerKey))
304303
{
305-
if (salt) {
304+
if (salt)
305+
{
306306
/* The stored iterations and salt must match or we cannot proceed, allow failure via mock */
307-
if (strcmp(salt, state->salt) || iterations != state->iterations) {
307+
if (strcmp(salt, state->salt) || iterations != state->iterations)
308+
{
308309
ereport(WARNING, (errmsg("inconsistent salt or iterations for user \"%s\"",
309-
state->port->user_name)));
310+
state->port->user_name)));
310311
got_secret = false; /* fail and allow mock creditials to be created */
312+
pfree(state->secrets);
313+
state->num_secrets = 0;
311314
break;
312315
}
313316
}
@@ -341,6 +344,7 @@ scram_init(Port *port, const char *selected_mech, const char **secrets, const in
341344
if (!got_secret)
342345
{
343346
state->secrets = palloc0(sizeof(scram_secret));
347+
state->num_secrets = 1;
344348

345349
mock_scram_secret(state->port->user_name, &state->hash_type,
346350
&state->iterations, &state->key_length,
@@ -510,13 +514,20 @@ pg_be_scram_build_secret(const char *password, const char *salt)
510514
if (rc == SASLPREP_SUCCESS)
511515
password = (const char *) prep_password;
512516

513-
/* Use passed in salt or generate random salt */
517+
/* Use passed-in salt, or generate random salt */
514518
if (!salt && !pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
519+
{
515520
ereport(ERROR,
516521
(errcode(ERRCODE_INTERNAL_ERROR),
517522
errmsg("could not generate random salt")));
523+
}
518524
else if (salt)
519-
pg_b64_decode(salt, strlen(salt), saltbuf, SCRAM_DEFAULT_SALT_LEN);
525+
{
526+
if (pg_b64_decode(salt, strlen(salt), saltbuf, SCRAM_DEFAULT_SALT_LEN) == -1)
527+
ereport(ERROR,
528+
(errcode(ERRCODE_INTERNAL_ERROR),
529+
errmsg("could not decode SCRAM salt")));
530+
}
520531

521532
result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
522533
saltbuf, SCRAM_DEFAULT_SALT_LEN,
@@ -1174,7 +1185,7 @@ verify_client_proof(scram_state *state)
11741185
for (j = 0; j < state->num_secrets; j++)
11751186
{
11761187
ctx = pg_hmac_create(state->hash_type);
1177-
elog(LOG, "Trying to verify password %d", j);
1188+
elog(LOG, "Trying to verify password %d", j); // TODO: Convert to DEBUG2
11781189

11791190
if (pg_hmac_init(ctx, state->secrets[j].StoredKey, state->key_length) < 0 ||
11801191
pg_hmac_update(ctx,
@@ -1190,12 +1201,14 @@ verify_client_proof(scram_state *state)
11901201
strlen(state->client_final_message_without_proof)) < 0 ||
11911202
pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
11921203
{
1193-
elog(LOG, "could not calculate client signature");
1204+
// TODO: Convert to DEBUG2
1205+
elog(LOG, "could not calculate client signature for secret %d", j);
11941206
pg_hmac_free(ctx);
11951207
continue;
11961208
}
11971209

1198-
elog(LOG, "success on %d", j);
1210+
// TODO: Convert to DEBUG2
1211+
elog(LOG, "succeeded on %d password", j);
11991212

12001213
pg_hmac_free(ctx);
12011214

@@ -1209,6 +1222,7 @@ verify_client_proof(scram_state *state)
12091222
elog(ERROR, "could not hash stored key: %s", errstr);
12101223

12111224
if (memcmp(client_StoredKey, state->secrets[j].StoredKey, state->key_length) == 0) {
1225+
// TODO: Convert to DEBUG2
12121226
elog(LOG, "Moving forward with Password %d", j);
12131227
state->chosen_secret = j;
12141228
return true;

src/backend/libpq/auth.c

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ static void set_authn_id(Port *port, const char *id);
5757
static int CheckPasswordAuth(Port *port, const char **logdetail);
5858
static int CheckPWChallengeAuth(Port *port, const char **logdetail);
5959

60-
static int CheckMD5Auth(Port *port, const char **passwords, int num, const char **logdetail);
60+
static int CheckMD5Auth(Port *port, const char **passwords, int num_passwords, const char **logdetail);
6161

6262

6363
/*----------------------------------------------------------------
@@ -789,7 +789,7 @@ CheckPasswordAuth(Port *port, const char **logdetail)
789789
{
790790
char *passwd;
791791
int result = STATUS_ERROR;
792-
int i, num;
792+
int i, num_passwords;
793793
char **passwords;
794794

795795
sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
@@ -798,16 +798,16 @@ CheckPasswordAuth(Port *port, const char **logdetail)
798798
if (passwd == NULL)
799799
return STATUS_EOF; /* client wouldn't send password */
800800

801-
passwords = get_role_passwords(port->user_name, logdetail, &num);
801+
passwords = get_role_passwords(port->user_name, logdetail, &num_passwords);
802802
if (passwords != NULL) {
803-
for (i = 0; i < num; i++)
803+
for (i = 0; i < num_passwords; i++)
804804
{
805805
result = plain_crypt_verify(port->user_name, passwords[i], passwd,
806806
logdetail);
807807
if (result == STATUS_OK)
808808
break; /* Found a matching password, no need to try any others */
809809
}
810-
for (i = 0; i < num; i++)
810+
for (i = 0; i < num_passwords; i++)
811811
pfree(passwords[i]);
812812

813813
pfree(passwords);
@@ -829,14 +829,14 @@ CheckPWChallengeAuth(Port *port, const char **logdetail)
829829
{
830830
bool scram_pw_avail = false;
831831
int auth_result = STATUS_ERROR;
832-
int i, num;
832+
int i, num_passwords;
833833
char **passwords;
834834

835835
Assert(port->hba->auth_method == uaSCRAM ||
836836
port->hba->auth_method == uaMD5);
837837

838838
/* First look up the user's passwords. */
839-
passwords = get_role_passwords(port->user_name, logdetail, &num);
839+
passwords = get_role_passwords(port->user_name, logdetail, &num_passwords);
840840

841841
/*
842842
* If 'md5' authentication is allowed, decide whether to perform 'md5' or
@@ -848,18 +848,19 @@ CheckPWChallengeAuth(Port *port, const char **logdetail)
848848
* had an MD5 password, CheckSASLAuth() with the SCRAM mechanism will
849849
* fail.
850850
*/
851-
if (passwords != NULL) {
852-
for (i = 0; i < num; i++)
851+
if (passwords != NULL)
852+
{
853+
for (i = 0; i < num_passwords; i++)
853854
if (get_password_type(passwords[i]) == PASSWORD_TYPE_SCRAM_SHA_256)
854855
scram_pw_avail = true;
855856

856857
if (port->hba->auth_method == uaMD5 && !scram_pw_avail)
857-
auth_result = CheckMD5Auth(port, (const char **) passwords, num, logdetail);
858+
auth_result = CheckMD5Auth(port, (const char **) passwords, num_passwords, logdetail);
858859
else
859-
auth_result = CheckSASLAuth(&pg_be_scram_mech, port, (const char **) passwords, num,
860+
auth_result = CheckSASLAuth(&pg_be_scram_mech, port, (const char **) passwords, num_passwords,
860861
logdetail);
861862

862-
for (i = 0; i < num; i++) {
863+
for (i = 0; i < num_passwords; i++) {
863864

864865
if (passwords[i] != NULL)
865866
pfree(passwords[i]);
@@ -877,7 +878,7 @@ CheckPWChallengeAuth(Port *port, const char **logdetail)
877878
}
878879

879880
static int
880-
CheckMD5Auth(Port *port, const char **passwords, int num, const char **logdetail)
881+
CheckMD5Auth(Port *port, const char **passwords, int num_passwords, const char **logdetail)
881882
{
882883
char md5Salt[4]; /* Password salt */
883884
char *passwd;
@@ -898,7 +899,7 @@ CheckMD5Auth(Port *port, const char **passwords, int num, const char **logdetail
898899
if (passwd == NULL)
899900
return STATUS_EOF; /* client wouldn't send password */
900901

901-
for (i = 0; i < num; i++)
902+
for (i = 0; i < num_passwords; i++)
902903
{
903904
result = md5_crypt_verify(port->user_name, passwords[i], passwd,
904905
md5Salt, 4, logdetail);

src/backend/libpq/crypt.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ get_role_passwords(const char *role, const char **logdetail, int *num)
5353
roleTup = SearchSysCache1(AUTHNAME, PointerGetDatum(role));
5454
if (!HeapTupleIsValid(roleTup))
5555
{
56-
*logdetail = psprintf(_("Role \"%s\" does not exist."),
57-
role);
56+
*logdetail = psprintf(_("Role \"%s\" does not exist."), role);
5857
return NULL; /* no such user */
5958
}
6059

@@ -95,6 +94,11 @@ get_role_passwords(const char *role, const char **logdetail, int *num)
9594
return NULL; /* user has no password */
9695
}
9796

97+
/*
98+
* TODO: Merge this and the following loop into one; allocate
99+
* valid_passwords array as long as n_members, then use repalloc() to shrink
100+
* the valid_passwords array.
101+
*/
98102
for (i = 0; i < passlist->n_members; i++)
99103
{
100104
HeapTuple tup = &passlist->members[i]->tuple;
@@ -189,7 +193,7 @@ encrypt_password(PasswordType target_type, const char *salt,
189193

190194
if (!pg_md5_encrypt(password, salt, strlen(salt),
191195
encrypted_password, &errstr))
192-
elog(ERROR, "password encryption failed %s", errstr);
196+
elog(ERROR, "password encryption failed: %s", errstr);
193197
return encrypted_password;
194198

195199
case PASSWORD_TYPE_SCRAM_SHA_256:

src/backend/utils/cache/relcache.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4014,12 +4014,12 @@ RelationCacheInitializePhase2(void)
40144014
Natts_pg_authid, Desc_pg_authid);
40154015
formrdesc("pg_auth_members", AuthMemRelation_Rowtype_Id, true,
40164016
Natts_pg_auth_members, Desc_pg_auth_members);
4017+
formrdesc("pg_auth_password", AuthPasswordRelation_Rowtype_Id, true,
4018+
Natts_pg_auth_password, Desc_pg_auth_password);
40174019
formrdesc("pg_shseclabel", SharedSecLabelRelation_Rowtype_Id, true,
40184020
Natts_pg_shseclabel, Desc_pg_shseclabel);
40194021
formrdesc("pg_subscription", SubscriptionRelation_Rowtype_Id, true,
40204022
Natts_pg_subscription, Desc_pg_subscription);
4021-
formrdesc("pg_auth_password", AuthPasswordRelation_Rowtype_Id, true,
4022-
Natts_pg_auth_password, Desc_pg_auth_password);
40234023

40244024
#define NUM_CRITICAL_SHARED_RELS 6 /* fix if you change list above */
40254025
}
@@ -4158,10 +4158,10 @@ RelationCacheInitializePhase3(void)
41584158
AuthIdRelationId);
41594159
load_critical_index(AuthMemMemRoleIndexId,
41604160
AuthMemRelationId);
4161-
load_critical_index(SharedSecLabelObjectIndexId,
4162-
SharedSecLabelRelationId);
41634161
load_critical_index(AuthPasswordRoleOidIndexId,
41644162
AuthPasswordRelationId);
4163+
load_critical_index(SharedSecLabelObjectIndexId,
4164+
SharedSecLabelRelationId);
41654165

41664166
#define NUM_CRITICAL_SHARED_INDEXES 7 /* fix if you change list above */
41674167

0 commit comments

Comments
 (0)
0