8000 Sasl improvements by hardmet · Pull Request #1 · mpain/postgresql-async · GitHub
[go: up one dir, main page]

Skip to content

Sasl improvements #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sasl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
< 8000 td id="diff-89b91dd55bd4f9262fa3c21e4464cc5590153ef2213b0ddef9c04662283d4db0R25" data-line-number="25" class="blob-num blob-num-context js-linkable-line-number js-blob-rnum">
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import com.github.mauricio.async.db.pool.TimeoutScheduler
import com.github.mauricio.async.db.postgresql.codec.{PostgreSQLConnectionDelegate, PostgreSQLConnectionHandler}
import com.github.mauricio.async.db.postgresql.column.{PostgreSQLColumnDecoderRegistry, PostgreSQLColumnEncoderRegistry}
import com.github.mauricio.async.db.postgresql.exceptions._
import com.github.mauricio.async.db.postgresql.sasl.SaslEngine
import com.github.mauricio.async.db.postgresql.sasl.{InvalidFinalServerMessageProof, MissingAuthParamException, SaslEngine}
import com.github.mauricio.async.db.postgresql.sasl.SaslEngine.SASLContext
import com.github.mauricio.async.db.util._
import com.github.mauricio.async.db.{Configuration, Connection}
Expand Down Expand Up @@ -229,7 +229,7 @@ class PostgreSQLConnection
override def onAuthenticationResponse(message: AuthenticationMessage) {

message match {
case m: AuthenticationOkMessage => {
case _: AuthenticationOkMessage => {
log.debug("Successfully logged in to database")
this.authenticated = true
saslCtx = None
Expand All @@ -246,7 +246,7 @@ class PostgreSQLConnection
saslCtx = Option(ctx)
write(message)
} else {
throw new DatabaseException(s"Missing username: ${configuration.username}")
throw new MissingAuthParamException(isUsernameEmpty = true, ctx=Option.empty[SASLContext], password = Option.empty[String])
}
case m: AuthenticationSASLContinueMessage =>
saslCtx.flatMap { ctx =>
Expand All @@ -256,16 +256,16 @@ class PostgreSQLConnection
write(message)
}
} getOrElse {
throw new DatabaseException(s"Missing ctx: $saslCtx or password: ${configuration.password}")
throw new MissingAuthParamException(isUsernameEmpty = false, ctx = saslCtx, password = configuration.password)
}
case m: AuthenticationSASLFinalMessage =>
saslCtx.map { ctx =>
if (!SaslEngine.validateContext(ctx, m)) {
if (!SaslEngine.validateFinalMessageProof(ctx, m)) {
saslCtx = None
throw new DatabaseException(s"Bas server proof: $ctx, ${m.msg.toScramMessage}")
throw new InvalidFinalServerMessageProof(m.msg)
} else ()
} getOrElse {
throw new DatabaseException(s"Missing ctx: $saslCtx or password: ${configuration.password}")
throw new MissingAuthParamException(isUsernameEmpty = false, ctx = saslCtx, password = configuration.password)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.github.mauricio.async.db.postgresql.sasl

import com.github.mauricio.async.db.exceptions.DatabaseException
import com.github.mauricio.async.db.postgresql.sasl.ScramMessages.ServerFinalMessage

class InvalidFinalServerMessageProof(serverFinalMessage: ServerFinalMessage)
extends DatabaseException(
"Invalid server proof, authentication failed. Server final message: %s".format(serverFinalMessage.toScramMessage)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.github.mauricio.async.db.postgresql.sasl

import com.github.mauricio.async.db.exceptions.DatabaseException
import com.github.mauricio.async.db.postgresql.sasl.SaslEngine.SASLContext

class MissingAuthParamException(
val isUsernameEmpty: Boolean,
val ctx: Option[SASLContext],
val password: Option[String]
) extends DatabaseException(
s"Missing ${if (isUsernameEmpty) "username" else "password or SASL-context password=[%s] ctx=[%s]"}".format(
password,
ctx
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,19 @@ import javax.crypto.{Mac, SecretKeyFactory}
import scala.annotation.tailrec
import scala.util.Random

/**
* Engine contains definition of calculation SASL mechanism which is used for SCRAM
* @see [[https://datatracker.ietf.org/doc/html/rfc5802 RFC-5802]]
* [[https://www.improving.com/thoughts/making-sense-of-scram-sha-256-authentication-in-mongodb Article about SCRAM]]
* for more info about SCRAM
*/
object SaslEngine {
val SaslNonceLength = 20

val SaslHeader = "n,,"
val SaslMechanism = "SCRAM-SHA-256"
private val SaslNonceLength = 20
private val SaslMechanism = "SCRAM-SHA-256"

case class SASLContext(firstMsg: ClientFirstMessage, ourServerProof: Option[String])
case class SASLContext(firstMsg: ClientFirstMessage, serverProof: Option[String])

private val HashAlg = "SHA-256"
private val MacAlg = "HmacSHA256"
Expand All @@ -34,25 +41,25 @@ object SaslEngine {

private val log = Log.getByName(this.getClass.getName)

private def hi(password: String, salt: Array[Byte], iterations: Int): Array[Byte] = {
private[sasl] def hi(password: String, salt: Array[Byte], iterations: Int): Array[Byte] = {
val spec = new PBEKeySpec(password.toCharArray, salt, iterations, Pbkdf2KeyLength * 8)
val skf = SecretKeyFactory.getInstance(Pbkdf2Alg)
val key = skf.generateSecret(spec)
key.getEncoded
}

private def hmac(key: Array[Byte], str: Array[Byte]): Array[Byte] = {
private[sasl] def hmac(key: Array[Byte], str: Array[Byte]): Array[Byte] = {
val mac = Mac.getInstance(MacAlg)
mac.init(new SecretKeySpec(key, MacAlg))
mac.doFinal(str)
}

private def hash(bytes: Array[Byte]): Array[Byte] = MessageDigest.getInstance(HashAlg).digest(bytes)
private[sasl] def hash(bytes: Array[Byte]): Array[Byte] = MessageDigest.getInstance(HashAlg).digest(bytes)

private def xor(right: Array[Byte], left: Array[Byte]): Array[Byte] =
private[sasl] def xor(right: Array[Byte], left: Array[Byte]): Array[Byte] =
right.zip(left).map(t => (t._1 ^ t._2).toByte)

private def random(length: Int): String = {
private[sasl] def random(length: Int): String = {
val random = new Random()
val result = new Array[Byte](length)
(0 until length).foreach { i =>
Expand All @@ -62,7 +69,7 @@ object SaslEngine {
new String(result)
}

private def toHex(bytes: Array[Byte]): String = bytes.map(_.formatted("%02x")).mkString
private[sasl] def toHex(bytes: Array[Byte]): String = bytes.map(_.formatted("%02x")).mkString

private def debug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg)

Expand Down Expand Up @@ -126,13 +133,13 @@ object SaslEngine {
val serverProof = Base64.getEncoder.encodeToString(hmac(serverKey, authMessage.getBytes))
debug(s"ServerProof: $serverProof")

(ctx.copy(ourServerProof = Option(serverProof)), SASLResponse(clientFinalMsg))
(ctx.copy(serverProof = Option(serverProof)), SASLResponse(clientFinalMsg))
}

def validateContext(ctx: SASLContext, msg: AuthenticationSASLFinalMessage): Boolean =
ctx.ourServerProof.exists { proof =>
val theirProof = Base64.getEncoder.encodeToString(msg.msg.serverProof)
debug(s"ourProof: $proof, their: $theirProof")
msg.msg.serverProof.nonEmpty && proof == theirProof
def validateFinalMessageProof(ctx: SASLContext, finalMessage: AuthenticationSASLFinalMessage): Boolean =
ctx.serverProof.exists { ctxServerProof =>
val serverFinalMessageProof = Base64.getEncoder.encodeToString(finalMessage.msg.serverProof)
debug(s"ServerProof on client side: $ctxServerProof, ServerProof from server side in final message: $serverFinalMessageProof")
finalMessage.msg.serverProof.nonEmpty && ctxServerProof == serverFinalMessageProof
}
}
Original file line number Diff line number Diff line change
@@ -1,27 +1,16 @@
package com.github.mauricio.async.db.sasl

import com.github.mauricio.async.db.postgresql.sasl.ScramMessages.{
ClientFinalMessage,
ClientFirstMessage,
ParseScramMessageOps,
ServerFinalMessage,
ServerFirstMessage
}
import com.github.mauricio.async.db.sasl.SaslBaseFlowSpec.{PBKDF2_ALGORITHM, PBKDF2_KEY_LENGTH}
import org.specs2.matcher.Matchers.beFalse
package com.github.mauricio.async.db.postgresql.sasl

import com.github.mauricio.async.db.postgresql.sasl.SaslEngine._
import com.github.mauricio.async.db.postgresql.sasl.ScramMessages.{ClientFinalMessage, ClientFirstMessage, ParseScramMessageOps, ServerFinalMessage, ServerFirstMessage}
import SaslBaseFlowSpec.PBKDF2_KEY_LENGTH
import org.specs2.mutable.Specification

import java.security.MessageDigest
import java.util.Base64
import javax.crypto.spec.{PBEKeySpec, SecretKeySpec}
import javax.crypto.{Mac, SecretKeyFactory}
import scala.annotation.tailrec
import scala.util.Random

object SaslBaseFlowSpec {
val PBKDF2_ALGORITHM: String = "PBKDF2WithHmacSHA256"
val PBKDF2_KEY_LENGTH: Int = MessageDigest.getInstance("SHA-256").getDigestLength

}

class SaslBaseFlowSpec extends Specification {
Expand All @@ -36,13 +25,6 @@ class SaslBaseFlowSpec extends Specification {
S: v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=
*/

def hi(password: String, salt: Array[Byte], iterations: Int): Array[Byte] = {
val spec = new PBEKeySpec(password.toCharArray, salt, iterations, PBKDF2_KEY_LENGTH * 8)
val skf = SecretKeyFactory.getInstance(PBKDF2_ALGORITHM)
val key = skf.generateSecret(spec)
key.getEncoded
}

def hiAlt(data: String, salt: Array[Byte], iterations: Int): Array[Byte] = {
val bytes = data.getBytes

Expand All @@ -61,29 +43,6 @@ class SaslBaseFlowSpec extends Specification {
step(new Array[Byte](PBKDF2_KEY_LENGTH), new Array[Byte](PBKDF2_KEY_LENGTH), iterations)
}

def hmac(key: Array[Byte], str: Array[Byte]): Array[Byte] = {
val mac = Mac.getInstance("HmacSHA256")
mac.init(new SecretKeySpec(key, "HmacSHA256"))
mac.doFinal(str)
}

def hash(bytes: Array[Byte]): Array[Byte] = MessageDigest.getInstance("SHA-256").digest(bytes)

def xor(right: Array[Byte], left: Array[Byte]): Array[Byte] =
right.zip(left).map(t => (t._1 ^ t._2).toByte)

def toHex(bytes: Array[Byte]): String = bytes.map(_.formatted("%02x")).mkString

def random(length: Int): String = {
val random = new Random()
val result = new Array[Byte](length)
(0 until length).foreach { i =>
val data = (random.nextInt(127 - 33) + 33).toByte
result.update(i, if (data == ','.toByte) 126.toByte else data)
}
new String(result)
}

"An engine" should {
"should generate random nonces" in {
(0 until 1000).foreach { i =>
Expand Down
0