8000 Support `caching_sha2_password` authentication mode by KarboniteKream · Pull Request #358 · jasync-sql/jasync-sql · GitHub
[go: up one dir, main page]

Skip to content

Support caching_sha2_password authentication mode #358

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ postgresql-async/out/*
mysql-async/target/*
pool-async/target/*
postgis-jasync/target/*
r2dbc-mysql/target/*
.rvmrc
.ruby-version
.ruby-gemset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.github.jasync.sql.db.Configuration
import com.github.jasync.sql.db.exceptions.DatabaseException
import com.github.jasync.sql.db.general.MutableResultSet
import com.github.jasync.sql.db.mysql.binary.BinaryRowDecoder
import com.github.jasync.sql.db.mysql.encoder.auth.AuthenticationMethod
import com.github.jasync.sql.db.mysql.message.client.AuthenticationSwitchResponse
import com.github.jasync.sql.db.mysql.message.client.CapabilityRequestMessage
import com.github.jasync.sql.db.mysql.message.client.CloseStatementMessage
Expand All @@ -13,6 +14,7 @@ import com.github.jasync.sql.db.mysql.message.client.PreparedStatementPrepareMes
import com.github.jasync.sql.db.mysql.message.client.QueryMessage
import com.github.jasync.sql.db.mysql.message.client.QuitMessage
import com.github.jasync.sql.db.mysql.message.client.SendLongDataMessage
import com.github.jasync.sql.db.mysql.message.server.AuthMoreDataMessage
import com.github.jasync.sql.db.mysql.message.server.AuthenticationSwitchRequest
import com.github.jasync.sql.db.mysql.message.server.BinaryRowMessage
import com.github.jasync.sql.db.mysql.message.server.ColumnDefinitionMessage
Expand All @@ -23,6 +25,7 @@ import com.github.jasync.sql.db.mysql.message.server.OkMessage
import com.github.jasync.sql.db.mysql.message.server.PreparedStatementPrepareResponse
import com.github.jasync.sql.db.mysql.message.server.ResultSetRowMessage
import com.github.jasync.sql.db.mysql.message.server.ServerMessage
import com.github.jasync.sql.db.mysql.util.CapabilityFlag
import com.github.jasync.sql.db.mysql.util.CharsetMapper
import com.github.jasync.sql.db.util.ExecutorServiceUtils
import com.github.jasync.sql.db.util.FP
Expand Down Expand Up @@ -72,6 +75,7 @@ class MySQLConnectionHandler(
private val parsedStatements = HashMap<String, PreparedStatementHolder>()
private val binaryRowDecoder = BinaryRowDecoder()

private var sslEstablished: Boolean = false
private var currentPreparedStatementHolder: PreparedStatementHolder? = null
private var currentPreparedStatement: PreparedStatement? = null
private var currentQuery: MutableResultSet<ColumnDefinitionMessage>? = null
Expand Down Expand Up @@ -127,6 +131,20 @@ class MySQLConnectionHandler(
ServerMessage.EOF -> {
this.handleEOF(message)
}
ServerMessage.AuthMoreData -> {
val m = message as AuthMoreDataMessage

if (!m.isSuccess()) {
if (!sslEstablished) {
throw IllegalStateException(
"Full authentication mode for ${AuthenticationMethod.CachingSha2} requires SSL"
)
}

val request = AuthenticationSwitchRequest(AuthenticationMethod.CachingSha2, null)
handlerDelegate.switchAuthentication(request)
}
}
ServerMessage.ColumnDefinition -> {
val m = message as ColumnDefinitionMessage

Expand Down Expand Up @@ -278,6 +296,7 @@ class MySQLConnectionHandler(
fun write(message: CapabilityRequestMessage): ChannelFuture = writeAndHandleError(message)

fun write(message: HandshakeResponseMessage): ChannelFuture {
sslEstablished = message.header.flags.contains(CapabilityFlag.CLIENT_SSL)
decoder.hasDoneHandshake = true
return writeAndHandleError(message)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.github.jasync.sql.db.mysql.codec
import com.github.jasync.sql.db.exceptions.BufferNotFullyConsumedException
import com.github.jasync.sql.db.exceptions.NegativeMessageSizeException
import com.github.jasync.sql.db.exceptions.ParserNotAvailableException
import com.github.jasync.sql.db.mysql.decoder.AuthMoreDataDecoder
import com.github.jasync.sql.db.mysql.decoder.AuthenticationSwitchRequestDecoder
import com.github.jasync.sql.db.mysql.decoder.ColumnDefinitionDecoder
import com.github.jasync.sql.db.mysql.decoder.ColumnProcessingFinishedDecoder
Expand Down Expand Up @@ -39,6 +40,7 @@ class MySQLFrameDecoder(val charset: Charset, private val connectionId: String)
private val handshakeDecoder = HandshakeV10Decoder()
private val errorDecoder = ErrorDecoder(charset)
private val okDecoder = OkDecoder(charset)
private val authMoreDataDecoder = AuthMoreDataDecoder()
private val columnDecoder = ColumnDefinitionDecoder(charset, DecoderRegistry(charset))
private val authenticationSwitchRequestDecoder = AuthenticationSwitchRequestDecoder(charset)
private val rowDecoder = ResultSetRowDecoder()
Expand Down Expand Up @@ -89,7 +91,7 @@ class MySQLFrameDecoder(val charset: Charset, private val connectionId: String)
logger.trace {
"[connectionId:$connectionId] - Reading message type $messageType - " +
"(count=$messagesCount,hasDoneHandshake=$hasDoneHandshake,size=$size,isInQuery=$isInQuery,processingColumns=$processingColumns,processingParams=$processingParams,processedColumns=$processedColumns,processedParams=$processedParams)" +
"\n${BufferDumper.dumpAsHex(slice)}}"
"\n${BufferDumper.dumpAsHex(slice)}"
}

slice.markReaderIndex()
Expand Down Expand Up @@ -161,6 +163,13 @@ class MySQLFrameDecoder(val charset: Charset, private val connectionId: String)
}
}
}
ServerMessage.AuthMoreData -> {
if (!isInQuery) {
this.authMoreDataDecoder
} else {
null
}
}
else -> {
if (this.isInQuery) {
null
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.github.jasync.sql.db.mysql.decoder

import com.github.jasync.sql.db.mysql.message.server.AuthMoreDataMessage
import com.github.jasync.sql.db.mysql.message.server.ServerMessage
import io.netty.buffer.ByteBuf

class AuthMoreDataDecoder : MessageDecoder {
override fun decode(buffer: ByteBuf): ServerMessage {
return AuthMoreDataMessage(
data = buffer.readByte(),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ class AuthenticationSwitchRequestDecoder(val charset: Charset) : MessageDecoder
val method = buffer.readCString(charset)
val bytes: Int = buffer.readableBytes()
val terminal = 0.toByte()
val salt = if (bytes > 0 && buffer.getByte(buffer.writerIndex() - 1) == terminal) ByteBufUtil.getBytes(
val seed = if (bytes > 0 && buffer.getByte(buffer.writerIndex() - 1) == terminal) ByteBufUtil.getBytes(
buffer,
buffer.readerIndex(),
bytes - 1
) else ByteBufUtil.getBytes(buffer)
return AuthenticationSwitchRequest(method, salt)
buffer.skipBytes(bytes)
return AuthenticationSwitchRequest(method, seed)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@ import java.nio.charset.Charset

interface AuthenticationMethod {

fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray): ByteArray
fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray

companion object {
val Native = "mysql_native_password"
val Old = "mysql_old_password"
const val CachingSha2 = "caching_sha2_password"
const val Native = "mysql_native_password"
const val Old = "mysql_old_password"
const val Sha256 = "sha256_password"

val Availables = mapOf(
CachingSha2 to CachingSha2PasswordAuthentication,
Native to MySQLNativePasswordAuthentication,
Old to OldPasswordAuthentication
Old to OldPasswordAuthentication,
Sha256 to Sha256PasswordAuthentication,
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import com.github.jasync.sql.db.util.length
import java.nio.charset.Charset
import java.security.MessageDigest
import kotlin.experimental.xor

object AuthenticationScrambler {

fun scramble411(
algorithm: String,
password: String,
charset: Charset,
seed: ByteArray,
seedFirst: Boolean,
): ByteArray {
val messageDigest = MessageDigest.getInstance(algorithm)
val initialDigest = messageDigest.digest(password.toByteArray(charset))

messageDigest.reset()

val finalDigest = messageDigest.digest(initialDigest)

messageDigest.reset()

if (seedFirst) {
messageDigest.update(seed)
messageDigest.update(finalDigest)
} else {
messageDigest.update(finalDigest)
messageDigest.update(seed)
}

val result = messageDigest.digest()
var counter = 0

while (counter < result.length) {
result[counter] = (result[counter] xor initialDigest[counter])
counter += 1
}

return result
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import java.nio.charset.Charset

object CachingSha2PasswordAuthentication : AuthenticationMethod {

private val EmptyArray = ByteArray(0)

override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
return if (password != null) {
if (seed != null) {
// Fast authentication mode. Requires seed, but not SSL.
AuthenticationScrambler.scramble411("SHA-256", password, charset, seed, false)
} else {
// Full authentication mode.
// Since this sends the plaintext password, SSL is required.
// Without SSL, the server always rejects the password.
Sha256PasswordAuthentication.generateAuthentication(charset, password, null)
}
} else {
EmptyArray
}
}
}
Original file line number Diff line number Diff line change
@@ -1,45 +1,18 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import com.github.jasync.sql.db.util.length
import java.nio.charset.Charset
import java.security.MessageDigest
import kotlin.experimental.xor

object MySQLNativePasswordAuthentication : AuthenticationMethod {

val EmptyArray = ByteArray(0)
private val EmptyArray = ByteArray(0)

override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray): ByteArray {
override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
requireNotNull(seed) { "Seed should not be null" }

return if (password != null) {
scramble411(charset, password, seed)
AuthenticationScrambler.scramble411("SHA-1", password, charset, seed, true)
} else {
EmptyArray
}
}

private fun scramble411(charset: Charset, password: String, seed: ByteArray): ByteArray {

val messageDigest = MessageDigest.getInstance("SHA-1")
val initialDigest = messageDigest.digest(password.toByteArray(charset))

messageDigest.reset()

val finalDigest = messageDigest.digest(initialDigest)

messageDigest.reset()

messageDigest.update(seed)
messageDigest.update(finalDigest)

val result = messageDigest.digest()
var counter = 0

while (counter < result.length) {
result[counter] = (result[counter] xor initialDigest[counter])
counter += 1
}

return result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@ import kotlin.math.floor
@Suppress("RedundantExplicitType", "UNUSED_VALUE", "VARIABLE_WITH_REDUNDANT_INITIALIZER")
object OldPasswordAuthentication : AuthenticationMethod {

val EmptyArray = ByteArray(0)
private val EmptyArray = ByteArray(0)

override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
requireNotNull(seed) { "Seed should not be null" }

override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray): ByteArray {
return when {
password != null && password.isNotEmpty() -> {
newCrypt(charset, password, String(seed, charset))
!password.isNullOrEmpty() -> {
// The native authentication handshake will provide a 20-byte challenge.
// Use the first 8 bytes as the old password authentication challenge.
val challenge = if (seed.length == 20) {
seed.copyOf(8)
} else {
seed
}
Comment on lines +17 to +23
Copy link
Contributor Author
@KarboniteKream KarboniteKream Jan 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this, a "bad handshake" error is thrown. Reference

If the server announces Native Authentication in the Protocol::Handshake packet the client may use the first 8 bytes of its 20-byte auth_plugin_data as input.


newCrypt(charset, password, String(challenge, charset))
}
else -> EmptyArray
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import com.github.jasync.sql.db.util.length
import java.nio.charset.Charset

// TODO: Implement public key encryption.
object Sha256PasswordAuthentication : AuthenticationMethod {

private val EmptyArray = ByteArray(0)

override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
return if (password != null) {
val bytes = password.toByteArray(charset)
val result = ByteArray(bytes.length + 1)
bytes.copyInto(result)
result
} else {
EmptyArray
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.github.jasync.sql.db.mysql.message.server

data class AuthMoreDataMessage(
val data: Byte,
) : ServerMessage(AuthMoreData) {
fun isSuccess(): Boolean {
return data == 3.toByte()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ package com.github.jasync.sql.db.mysql.message.server

data class AuthenticationSwitchRequest(
val method: String,
val seed: ByteArray
val seed: ByteArray?,
) : ServerMessage(ServerMessage.EOF)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ abstract class ServerMessage(override val kind: Int) : KindedMessage {
const val ServerProtocolVersion = 10
const val Error = -1
const val Ok = 0
const val AuthMoreData = 1
const val EOF = -2

// these messages don't actually exist
Expand Down
Loading
0