diff --git a/coderd/userpassword/hashing_bench_test.go b/coderd/userpassword/hashing_bench_test.go new file mode 100644 index 0000000000000..d2f1c8ae3cebe --- /dev/null +++ b/coderd/userpassword/hashing_bench_test.go @@ -0,0 +1,70 @@ +package userpassword_test + +import ( + "crypto/sha256" + "testing" + + "github.com/coder/coder/cryptorand" + "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/pbkdf2" +) + +var ( + salt = []byte(must(cryptorand.String(16))) + secret = []byte(must(cryptorand.String(24))) + + resBcrypt []byte + resPbkdf2 []byte +) + +func BenchmarkBcryptMinCost(b *testing.B) { + var r []byte + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + r, _ = bcrypt.GenerateFromPassword(secret, bcrypt.MinCost) + } + + resBcrypt = r +} + +func BenchmarkPbkdf2MinCost(b *testing.B) { + var r []byte + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + r = pbkdf2.Key(secret, salt, 1024, 64, sha256.New) + } + + resPbkdf2 = r +} + +func BenchmarkBcryptDefaultCost(b *testing.B) { + var r []byte + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + r, _ = bcrypt.GenerateFromPassword(secret, bcrypt.DefaultCost) + } + + resBcrypt = r +} + +func BenchmarkPbkdf2(b *testing.B) { + var r []byte + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + r = pbkdf2.Key(secret, salt, 65536, 64, sha256.New) + } + + resPbkdf2 = r +} + +func must(s string, err error) string { + if err != nil { + panic(err) + } + + return s +} diff --git a/coderd/userpassword/userpassword.go b/coderd/userpassword/userpassword.go index 44ec3b16be994..e2c9d41a7763d 100644 --- a/coderd/userpassword/userpassword.go +++ b/coderd/userpassword/userpassword.go @@ -6,25 +6,67 @@ import ( "crypto/subtle" "encoding/base64" "fmt" + "os" "strconv" "strings" "golang.org/x/crypto/pbkdf2" + "golang.org/x/exp/slices" "golang.org/x/xerrors" ) -const ( - // This is the length of our output hash. - // bcrypt has a hash size of 59, so we rounded up to a power of 8. +var ( + // The base64 encoder used when producing the string representation of + // hashes. + base64Encoding = base64.RawStdEncoding + + // The number of iterations to use when generating the hash. This was chosen + // to make it about as fast as bcrypt hashes. Increasing this causes hashes + // to take longer to compute. + defaultHashIter = 65535 + + // This is the length of our output hash. bcrypt has a hash size of up to + // 60, so we rounded up to a power of 8. hashLength = 64 + // The scheme to include in our hashed password. hashScheme = "pbkdf2-sha256" + + // A salt size of 16 is the default in passlib. A minimum of 8 can be safely + // used. + defaultSaltSize = 16 + + // The simulated hash is used when trying to simulate password checks for + // users that don't exist. + simulatedHash, _ = Hash("hunter2") ) -// Compare checks the equality of passwords from a hashed pbkdf2 string. -// This uses pbkdf2 to ensure FIPS 140-2 compliance. See: +// Make password hashing much faster in tests. +func init() { + args := os.Args[1:] + + // Ensure this can never be enabled if running in server mode. + if slices.Contains(args, "server") { + return + } + + for _, flag := range args { + if strings.HasPrefix(flag, "-test.") { + defaultHashIter = 1 + return + } + } +} + +// Compare checks the equality of passwords from a hashed pbkdf2 string. This +// uses pbkdf2 to ensure FIPS 140-2 compliance. See: // https://csrc.nist.gov/csrc/media/templates/cryptographic-module-validation-program/documents/security-policies/140sp2261.pdf func Compare(hashed string, password string) (bool, error) { + // If the hased password provided is empty, simulate comparing a real hash. + if hashed == "" { + hashed = simulatedHash + } + if len(hashed) < hashLength { return false, xerrors.Errorf("hash too short: %d", len(hashed)) } @@ -42,7 +84,7 @@ func Compare(hashed string, password string) (bool, error) { if err != nil { return false, xerrors.Errorf("parse iter from hash: %w", err) } - salt, err := base64.RawStdEncoding.DecodeString(parts[3]) + salt, err := base64Encoding.DecodeString(parts[3]) if err != nil { return false, xerrors.Errorf("decode salt: %w", err) } @@ -50,29 +92,32 @@ func Compare(hashed string, password string) (bool, error) { if subtle.ConstantTimeCompare([]byte(hashWithSaltAndIter(password, salt, iter)), []byte(hashed)) != 1 { return false, nil } + return true, nil } // Hash generates a hash using pbkdf2. // See the Compare() comment for rationale. func Hash(password string) (string, error) { - // bcrypt uses a salt size of 16 bytes. - salt := make([]byte, 16) + salt := make([]byte, defaultSaltSize) _, err := rand.Read(salt) if err != nil { return "", xerrors.Errorf("read random bytes for salt: %w", err) } - // The default hash iteration is 1024 for speed. - // As this is increased, the password is hashed more. - return hashWithSaltAndIter(password, salt, 1024), nil + + return hashWithSaltAndIter(password, salt, defaultHashIter), nil } // Produces a string representation of the hash. func hashWithSaltAndIter(password string, salt []byte, iter int) string { - hash := pbkdf2.Key([]byte(password), salt, iter, hashLength, sha256.New) - hash = []byte(base64.RawStdEncoding.EncodeToString(hash)) - salt = []byte(base64.RawStdEncoding.EncodeToString(salt)) - // This format is similar to bcrypt. See: - // https://en.wikipedia.org/wiki/Bcrypt#Description - return fmt.Sprintf("$%s$%d$%s$%s", hashScheme, iter, salt, hash) + var ( + hash = pbkdf2.Key([]byte(password), salt, iter, hashLength, sha256.New) + encHash = make([]byte, base64Encoding.EncodedLen(len(hash))) + encSalt = make([]byte, base64Encoding.EncodedLen(len(salt))) + ) + + base64Encoding.Encode(encHash, hash) + base64Encoding.Encode(encSalt, salt) + + return fmt.Sprintf("$%s$%d$%s$%s", hashScheme, iter, encSalt, encHash) } diff --git a/coderd/users.go b/coderd/users.go index a38af8ba12f63..ca8517aa8a6cc 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -419,21 +419,19 @@ func (api *api) postLogin(rw http.ResponseWriter, r *http.Request) { if !httpapi.Read(rw, r, &loginWithPassword) { return } + user, err := api.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ Email: loginWithPassword.Email, }) - if errors.Is(err, sql.ErrNoRows) { - httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ - Message: "invalid email or password", - }) - return - } - if err != nil { + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: fmt.Sprintf("get user: %s", err.Error()), }) return } + + // If the user doesn't exist, it will be a default struct. + equal, err := userpassword.Compare(string(user.HashedPassword), loginWithPassword.Password) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{