8000 Add SSL support.. · mirabout/postgresql-async@0f9a587 · GitHub
[go: up one dir, main page]

Skip to content 8000

Commit 0f9a587

Browse files
committed
Add SSL support..
SSL is disabled by default to avoid POLA violations. It is possible to enable and control SSL behavior via url parameters: - `sslmode=<mode>` enable ssl (prefer/require/verify-ca/verify-full [recommended]) - `sslrootcert=<path.pem>` specifies trusted certificates (JDK cacert if missing) Client certificate authentication is not implemented, due to lack of time and interest, but it should be easy to add.
1 parent c3747b5 commit 0f9a587

File tree

21 files changed

+364
-48
lines changed

21 files changed

+364
-48
lines changed

db-async-common/src/main/scala/com/github/mauricio/async/db/Configuration.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ object Configuration {
3737
* @param port database port, defaults to 5432
3838
* @param password password, defaults to no password
3939
* @param database database name, defaults to no database
40+
* @param ssl ssl configuration
4041
* @param charset charset for the connection, defaults to UTF-8, make sure you know what you are doing if you
4142
* change this
4243
* @param maximumMessageSize the maximum size a message from the server could possibly have, this limits possible
@@ -55,6 +56,7 @@ case class Configuration(username: String,
5556
port: Int = 5432,
5657
password: Option[String] = None,
5758
database: Option[String] = None,
59+
ssl: SSLConfiguration = SSLConfiguration(),
5860
charset: Charset = Configuration.DefaultCharset,
5961
maximumMessageSize: Int = 16777216,
6062
allocator: ByteBufAllocator = PooledByteBufAllocator.DEFAULT,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package com.github.mauricio.async.db
2+
3+
import java.io.File
4+
5+
import SSLConfiguration.Mode
6+
7+
/**
8+
*
9+
* Contains the SSL configuration necessary to connect to a database.
10+
*
11+
* @param mode whether and with what priority a SSL connection will be negotiated, default disabled
12+
* @param rootCert path to PEM encoded trusted root certificates, None to use internal JDK cacerts, defaults to None
13+
*
14+
*/
15+
case class SSLConfiguration(mode: Mode.Value = Mode.Disable, rootCert: Option[java.io.File] = None)
16+
17+
object SSLConfiguration {
18+
19+
object Mode extends Enumeration {
20+
val Disable = Value("disable") // only try a non-SSL connection
21+
val Prefer = Value("prefer") // first try an SSL connection; if that fails, try a non-SSL connection
22+
val Require = Value("require") // only try an SSL connection, but don't verify Certificate Authority
23+
val VerifyCA = Value("verify-ca") // only try an SSL connection, and verify that the server certificate is issued by a trusted certificate authority (CA)
24+
val VerifyFull = Value("verify-full") // only try an SSL connection, verify that the server certificate is issued by a trusted CA and that the server host name matches that in the certificate
25+
}
26+
27+
def apply(properties: Map[String, String]): SSLConfiguration = SSLConfiguration(
28+
mode = Mode.withName(properties.get("sslmode").getOrElse("disable")),
29+
rootCert = properties.get("sslrootcert").map(new File(_))
30+
)
31+
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageDecoder.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package com.github.mauricio.async.db.postgresql.codec
1818

1919
import com.github.mauricio.async.db.postgresql.exceptions.{MessageTooLongException}
20-
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
20+
import com.github.mauricio.async.db.postgresql.messages.backend.{ServerMessage, SSLResponseMessage}
2121
import com.github.mauricio.async.db.postgresql.parsers.{AuthenticationStartupParser, MessageParsersRegistry}
2222
import com.github.mauricio.async.db.util.{BufferDumper, Log}
2323
import java.nio.charset.Charset
@@ -31,15 +31,21 @@ object MessageDecoder {
3131
val DefaultMaximumSize = 16777216
3232
}
3333

34-
class MessageDecoder(charset: Charset, maximumMessageSize : Int = MessageDecoder.DefaultMaximumSize) extends ByteToMessageDecoder {
34+
class MessageDecoder(sslEnabled: Boolean, charset: Charset, maximumMessageSize : Int = MessageDecoder.DefaultMaximumSize) extends ByteToMessageDecoder {
3535

3636
import MessageDecoder.log
3737

3838
private val parser = new MessageParsersRegistry(charset)
3939

40+
private var sslChecked = false
41+
4042
override def decode(ctx: ChannelHandlerContext, b: ByteBuf, out: java.util.List[Object]): Unit = {
4143

42-
if (b.readableBytes() >= 5) {
44+
if (sslEnabled & !sslChecked) {
45+
val code = b.readByte()
46+
sslChecked = true
47+
out.add(new SSLResponseMessage(code == 'S'))
48+
} else if (b.readableBytes() >= 5) {
4349

4450
b.markReaderIndex()
4551

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageEncoder.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@ class MessageEncoder(charset: Charset, encoderRegistry: ColumnEncoderRegistry) e
4444
override def encode(ctx: ChannelHandlerContext, msg: AnyRef, out: java.util.List[Object]) = {
4545

4646
val buffer = msg match {
47+
case SSLRequestMessage => SSLMessageEncoder.encode()
48+
case message: StartupMessage => startupEncoder.encode(message)
4749
case message: ClientMessage => {
4850
val encoder = (message.kind: @switch) match {
4951
case ServerMessage.Close => CloseMessageEncoder
5052
case ServerMessage.Execute => this.executeEncoder
5153
case ServerMessage.Parse => this.openEncoder
52-
case ServerMessage.Startup => this.startupEncoder
5354
case ServerMessage.Query => this.queryEncoder
5455
case ServerMessage.PasswordMessage => this.credentialEncoder
5556
case _ => throw new EncoderNotAvailableException(message)

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/PostgreSQLConnectionHandler.scala

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.github.mauricio.async.db.postgresql.codec
1818

1919
import com.github.mauricio.async.db.Configuration
20+
import com.github.mauricio.async.db.SSLConfiguration.Mode
2021
import com.github.mauricio.async.db.column.{ColumnDecoderRegistry, ColumnEncoderRegistry}
2122
import com.github.mauricio.async.db.postgresql.exceptions._
2223
import com.github.mauricio.async.db.postgresql.messages.backend._
@@ -38,6 +39,12 @@ import com.github.mauricio.async.db.postgresql.messages.backend.RowDescriptionMe
3839
import com.github.mauricio.async.db.postgresql.messages.backend.ParameterStatusMessage
3940
import io.netty.channel.socket.nio.NioSocketChannel
4041
import io.netty.handler.codec.CodecException
42+
import io.netty.handler.ssl.{SslContextBuilder, SslHandler}
43+
import io.netty.handler.ssl.util.InsecureTrustManagerFactory
44+
import io.netty.util.concurrent.FutureListener
45+
import javax.net.ssl.{SSLParameters, TrustManagerFactory}
46+
import java.security.KeyStore
47+
import java.io.FileInputStream
4148

4249
object PostgreSQLConnectionHandler {
4350
final val log = Log.get[PostgreSQLConnectionHandler]
@@ -79,7 +86,7 @@ class PostgreSQLConnectionHandler
7986

8087
override def initChannel(ch: channel.Channel): Unit = {
8188
ch.pipeline.addLast(
82-
new MessageDecoder(configuration.charset, configuration.maximumMessageSize),
89+
new MessageDecoder(configuration.ssl.mode != Mode.Disable, configuration.charset, configuration.maximumMessageSize),
8390
new MessageEncoder(configuration.charset, encoderRegistry),
8491
PostgreSQLConnectionHandler.this)
8592
}
@@ -120,13 +127,61 @@ class PostgreSQLConnectionHandler
120127
}
121128

122129
override def channelActive(ctx: ChannelHandlerContext): Unit = {
123-
ctx.writeAndFlush(new StartupMessage(this.properties))
130+
if (configuration.ssl.mode == Mode.Disable)
131+
ctx.writeAndFlush(new StartupMessage(this.properties))
132+
else
133+
ctx.writeAndFlush(SSLRequestMessage)
124134
}
125135

126136
override def channelRead0(ctx: ChannelHandlerContext, msg: Object): Unit = {
127137

128138
msg match {
129139

140+
case SSLResponseMessage(supported) =>
141+
if (supported) {
142+
val ctxBuilder = SslContextBuilder.forClient()
143+
if (configuration.ssl.mode >= Mode.VerifyCA) {
144+
configuration.ssl.rootCert.fold {
145+
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
146+
val ks = KeyStore.getInstance(KeyStore.getDefaultType())
147+
val cacerts = new FileInputStream(System.getProperty("java.home") + "/lib/security/cacerts")
148+
try {
149+
ks.load(cacerts, "changeit".toCharArray)
150+
} finally {
151+
cacerts.close()
152+
}
153+
tmf.init(ks)
154+
ctxBuilder.trustManager(tmf)
155+
} { path =>
156+
ctxBuilder.trustManager(path)
157+
}
158+
} else {
159+
ctxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE)
160+
}
161+
val sslContext = ctxBuilder.build()
162+
val sslEngine = sslContext.newEngine(ctx.alloc(), configuration.host, configuration.port)
163+
if (configuration.ssl.mode >= Mode.VerifyFull) {
164+
val sslParams = sslEngine.getSSLParameters()
165+
sslParams.setEndpointIdentificationAlgorithm("HTTPS")
166+
sslEngine.setSSLParameters(sslParams)
167+
}
168+
val handler = new SslHandler(sslEngine)
169+
ctx.pipeline().addFirst(handler)
170+
handler.handshakeFuture.addListener(new FutureListener[channel.Channel]() {
171+
def operationComplete(future: io.netty.util.concurrent.Future[channel.Channel]) {
172+
if (future.isSuccess()) {
173+
ctx.writeAndFlush(new StartupMessage(properties))
174+
} else {
175+
connectionDelegate.onError(future.cause())
176+
}
177+
}
178+
})
179+
} else if (configuration.ssl.mode < Mode.Require) {
180+
ctx.writeAndFlush(new StartupMessage(properties))
181+
} else {
182+
connectionDelegate.onError(new IllegalArgumentException("SSL is not supported on server"))
183+
}
184+
130185
case m: ServerMessage => {
131186

132187
(m.kind : @switch) match {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.github.mauricio.async.db.postgresql.encoders
2+
3+
import io.netty.buffer.ByteBuf
4+
import io.netty.buffer.Unpooled
5+
6+
object SSLMessageEncoder {
7+
8+
def encode(): ByteBuf = {
9+
val buffer = Unpooled.buffer()
10+
buffer.writeInt(8)
11+
buffer.writeShort(1234)
12+
buffer.writeShort(5679)
13+
buffer
14+
}
15+
16+
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/StartupMessageEncoder.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@ import com.github.mauricio.async.db.util.ByteBufferUtils
2121
import java.nio.charset.Charset
2222
import io.netty.buffer.{Unpooled, ByteBuf}
2323

24-
class StartupMessageEncoder(charset: Charset) extends Encoder {
24+
class StartupMessageEncoder(charset: Charset) {
2525

2626
//private val log = Log.getByName("StartupMessageEncoder")
2727

28-
override def encode(message: ClientMessage): ByteBuf = {
29-
30-
val startup = message.asInstanceOf[StartupMessage]
28+
def encode(startup: StartupMessage): ByteBuf = {
3129

3230
val buffer = Unpooled.buffer()
3331
buffer.writeInt(0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package com.github.mauricio.async.db.postgresql.messages.backend
2+
3+
case class SSLResponseMessage(supported: Boolean)

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/messages/backend/ServerMessage.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ object ServerMessage {
4343
final val Query = 'Q'
4444
final val RowDescription = 'T'
4545
final val ReadyForQuery = 'Z'
46-
final val Startup = '0'
4746
final val Sync = 'S'
4847
}
4948

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package com.github.mauricio.async.db.postgresql.messages.frontend
2+
3+
trait InitialClientMessage
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package com.github.mauricio.async.db.postgresql.messages.frontend
2+
3+
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
4+
5+
object SSLRequestMessage extends InitialClientMessage

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/messages/frontend/StartupMessage.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,4 @@
1616

1717
package com.github.mauricio.async.db.postgresql.messages.frontend
1818

19-
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
20-
21-
class StartupMessage(val parameters: List[(String, Any)]) extends ClientMessage(ServerMessage.Startup)
19+
class StartupMessage(val parameters: List[(String, Any)]) extends InitialClientMessage

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/util/ParserURL.scala

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,37 @@ object ParserURL {
1616
val PGPORT = "port"
1717
val PGDBNAME = "database"
1818
val PGHOST = "host"
19-
val PGUSERNAME = "username"
19+
val PGUSERNAME = "user"
2020
val PGPASSWORD = "password"
2121

2222
val DEFAULT_PORT = "5432"
2323

24-
private val pgurl1 = """(jdbc:postgresql):(?://([^/:]*|\[.+\])(?::(\d+))?)?(?:/([^/?]*))?(?:\?user=(.*)&password=(.*))?""".r
25-
private val pgurl2 = """(postgres|postgresql)://(.*):(.*)@(.*):(\d+)/(.*)""".r
24+
private val pgurl1 = """(jdbc:postgresql):(?://([^/:]*|\[.+\])(?::(\d+))?)?(?:/([^/?]*))?(?:\?(.*))?""".r
25+
private val pgurl2 = """(postgres|postgresql)://(.*):(.*)@(.*):(\d+)/([^/?]*)(?:\?(.*))?""".r
2626

2727
def parse(connectionURL: String): Map[String, String] = {
2828
val properties: Map[String, String] = Map()
2929

30+
def parseOptions(optionsStr: String): Map[String, String] =
31+
optionsStr.split("&").map { o =>
32+
o.span(_ != '=') match {
33+
case (name, value) => name -> value.drop(1)
34+
}
35+
}.toMap
36+
3037
connectionURL match {
31-
case pgurl1(protocol, server, port, dbname, username, password) => {
38+
case pgurl1(protocol, server, port, dbname, params) => {
3239
var result = properties
3340
if (server != null) result += (PGHOST -> unwrapIpv6address(server))
3441
if (dbname != null && dbname.nonEmpty) result += (PGDBNAME -> dbname)
35-
if(port != null) result += (PGPORT -> port)
36-
if(username != null) result = (result + (PGUSERNAME -> username) + (PGPASSWORD -> password))
42+
if (port != null) result += (PGPORT -> port)
43+
if (params != null) result ++= parseOptions(params)
3744
result
3845
}
39-
case pgurl2(protocol, username, password, server, port, dbname) => {
40-
properties + (PGHOST -> unwrapIpv6address(server)) + (PGPORT -> port) + (PGDBNAME -> dbname) + (PGUSERNAME -> username) + (PGPASSWORD -> password)
46+
case pgurl2(protocol, username, password, server, port, dbname, params) => {
47+
var result = properties + (PGHOST -> unwrapIpv6address(server)) + (PGPORT -> port) + (PGDBNAME -> dbname) + (PGUSERNAME -> username) + (PGPASSWORD -> password)
48+
if (params != null) result ++= parseOptions(params)
49+
result
4150
}
4251
case _ => {
4352
logger.warn(s"Connection url '$connectionURL' could not be parsed.")

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/util/URLParser.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,11 @@
1616

1717
package com.github.mauricio.async.db.postgresql.util
1818

19-
import com.github.mauricio.async.db.Configuration
19+
import com.github.mauricio.async.db.{Configuration, SSLConfiguration}
2020
import java.nio.charset.Charset
2121

2222
object URLParser {
2323

24-
private val Username = "username"
25-
private val Password = "password"
26-
2724
import Configuration.Default
2825

2926
def parse(url: String,
@@ -35,11 +32,12 @@ object URLParser {
3532
val port = properties.get(ParserURL.PGPORT).getOrElse(ParserURL.DEFAULT_PORT).toInt
3633

3734
new Configuration(
38-
username = properties.get(Username).getOrElse(Default.username),
39-
password = properties.get(Password),
35+
username = properties.get(ParserURL.PGUSERNAME).getOrElse(Default.username),
36+
password = properties.get(ParserURL.PGPASSWORD),
4037
database = properties.get(ParserURL.PGDBNAME),
4138
host = properties.getOrElse(ParserURL.PGHOST, Default.host),
4239
port = port,
40+
ssl = SSLConfiguration(properties),
4341
charset = charset
4442
)
4543

postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/DatabaseTestHelper.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ package com.github.mauricio.async.db.postgresql
1818

1919
import com.github.mauricio.async.db.util.Log
2020
import com.github.mauricio.async.db.{Connection, Configuration}
21+
import java.io.File
2122
import java.util.concurrent.{TimeoutException, TimeUnit}
22-
import scala.Some
2323
import scala.concurrent.duration._
2424
import scala.concurrent.{Future, Await}
25+
import com.github.mauricio.async.db.SSLConfiguration
26+
import com.github.mauricio.async.db.SSLConfiguration.Mode
2527

2628
object DatabaseTestHelper {
2729
val log = Log.get[DatabaseTestHelper]
@@ -54,6 +56,16 @@ trait DatabaseTestHelper {
5456
withHandler(this.timeTestConfiguration, fn)
5557
}
5658

59+
def withSSLHandler[T](mode: SSLConfiguration.Mode.Value, host: String = "localhost", rootCert: Option[File] = Some(new File("script/server.crt")))(fn: (PostgreSQLConnection) => T): T = {
60+
val config = new Configuration(
61+
host = host,
62+
port = databasePort,
63+
username = "postgres",
64+
database = databaseName,
65+
ssl = SSLConfiguration(mode = mode, rootCert = rootCert))
66+
withHandler(config, fn)
67+
}
68+
5769
def withHandler[T](configuration: Configuration, fn: (PostgreSQLConnection) => T): T = {
5870

5971
val handler = new PostgreSQLConnection(configuration)

postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/MessageDecoderSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import java.util
2727

2828
class MessageDecoderSpec extends Specification {
2929

30-
val decoder = new MessageDecoder(CharsetUtil.UTF_8)
30+
val decoder = new MessageDecoder(false, CharsetUtil.UTF_8)
3131

3232
"message decoder" should {
3333

0 commit comments

Comments
 (0)
0