8000 Handle the case where the encoder produces a null value (as it would … · mauricio/postgresql-async@fbcd302 · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Dec 3, 2019. It is now read-only.

Commit fbcd302

Browse files
committed
Handle the case where the encoder produces a null value (as it would for Some(null)) - possible fix to #99
1 parent 245359c commit fbcd302

File tree

5 files changed

+65
-14
lines changed

5 files changed

+65
-14
lines changed

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/column/PostgreSQLColumnEncoderRegistry.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package com.github.mauricio.async.db.postgresql.column
1818

1919
import com.github.mauricio.async.db.column._
2020
import org.joda.time._
21-
import scala.Some
21+
2222
import scala.collection.JavaConversions._
2323

2424
object PostgreSQLColumnEncoderRegistry {
@@ -127,7 +127,7 @@ class PostgreSQLColumnEncoderRegistry extends ColumnEncoderRegistry {
127127
val result = collection.map {
128128
item =>
129129

130-
if (item == null) {
130+
if (item == null || item == None) {
131131
"NULL"
132132
} else {
133133
if (this.shouldQuote(item)) {
@@ -177,4 +177,5 @@ class PostgreSQLColumnEncoderRegistry extends ColumnEncoderRegistry {
177177
}
178178
}
179179
}
180+
180181
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/ExecutePreparedStatementEncoder.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616

1717
package com.github.mauricio.async.db.postgresql.encoders
1818

19+
import java.nio.charset.Charset
20+
1921
import com.github.mauricio.async.db.column.ColumnEncoderRegistry
20-
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
2122
import com.github.mauricio.async.db.postgresql.messages.frontend.{ClientMessage, PreparedStatementExecuteMessage}
22-
import com.github.mauricio.async.db.util.ByteBufferUtils
23-
import java.nio.charset.Charset
2423
import io.netty.buffer.ByteBuf
2524

2625
class ExecutePreparedStatementEncoder(

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementEncoderHelper.scala

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,17 @@ trait PreparedStatementEncoderHelper {
3333

3434
def writeExecutePortal(
3535
statementIdBytes: Array[Byte],
36-
query : String,
36+
query: String,
3737
values: Seq[Any],
3838
encoder: ColumnEncoderRegistry,
3939
charset: Charset,
4040
writeDescribe: Boolean = false
4141
): ByteBuf = {
4242

43+
if (log.isDebugEnabled) {
44+
log.debug(s"Preparing execute portal to statement ($query) - values (${values.mkString(", ")}) - ${charset}")
45+
}
46+
4347
val bindBuffer = Unpooled.buffer(1024)
4448

4549
bindBuffer.writeByte(ServerMessage.Bind)
@@ -54,14 +58,14 @@ trait PreparedStatementEncoderHelper {
5458

5559
bindBuffer.writeShort(values.length)
5660

57-
val decodedValues = if ( log.isDebugEnabled ) {
61+
val decodedValues = if (log.isDebugEnabled) {
5862
new ArrayBuffer[String](values.size)
5963
} else {
6064
null
6165
}
6266

6367
for (value <- values) {
64-
if (value == null || value == None) {
68+
if (isNull(value)) {
6569
bindBuffer.writeInt(-1)
6670

6771
if (log.isDebugEnabled) {
@@ -70,25 +74,30 @@ trait PreparedStatementEncoderHelper {
7074
} else {
7175
val encodedValue = encoder.encode(value)
7276

73-
if ( log.isDebugEnabled ) {
77+
if (log.isDebugEnabled) {
7478
decodedValues += encodedValue
7579
}
7680

77-
val content = encodedValue.getBytes(charset)
78-
bindBuffer.writeInt(content.length)
79-
bindBuffer.writeBytes( content )
81+
if (isNull(encodedValue)) {
82+
bindBuffer.writeInt(-1)
83+
} else {
84+
val content = encodedValue.getBytes(charset)
85+
bindBuffer.writeInt(content.length)
86+
bindBuffer.writeBytes(content)
87+
}
88+
8089
}
8190
}
8291

8392
if (log.isDebugEnabled) 9E88 {
84-
log.debug(s"Executing query - statement id (${statementIdBytes.mkString("-")}) - statement ($query) - encoded values (${decodedValues.mkString(", ")}) - original values (${values.mkString(", ")})")
93+
log.debug(s"Executing portal - statement id (${statementIdBytes.mkString("-")}) - statement ($query) - encoded values (${decodedValues.mkString(", ")}) - original values (${values.mkString(", ")})")
8594
}
8695

8796
bindBuffer.writeShort(0)
8897

8998
ByteBufferUtils.writeLength(bindBuffer)
9099

91-
if ( writeDescribe ) {
100+
if (writeDescribe) {
92101
val describeLength = 1 + 4 + 1 + statementIdBytes.length + 1
93102
val describeBuffer = bindBuffer
94103
describeBuffer.writeByte(ServerMessage.Describe)
@@ -122,4 +131,6 @@ trait PreparedStatementEncoderHelper {
122131

123132
}
124133

134+
def isNull(value: Any): Boolean = value == null || value == None
135+
125136
}

postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/PostgreSQLColumnEncoderRegistrySpec.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ class PostgreSQLColumnEncoderRegistrySpec extends Specification {
4646
actual mustEqual expected
4747
}
4848

49+
"encodes Some(null) as null" in {
50+
val actual = encoder.encode(Some(null))
51+
actual mustEqual null
52+
}
53+
54+
"encodes null as null" in {
55+
val actual = encoder.encode(null)
56+
actual mustEqual null
57+
}
58+
4959
}
5060

5161
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package com.github.mauricio.async.db.postgresql.encoders
2+
3+
import com.github.mauricio.async.db.postgresql.column.PostgreSQLColumnEncoderRegistry
4+
import com.github.mauricio.async.db.postgresql.messages.frontend.PreparedStatementExecuteMessage
5+
import io.netty.util.CharsetUtil
6+
import org.specs2.mutable.Specification
7+
8+
class ExecutePreparedStatementEncoderSpec extends Specification {
9+
10+
val registry = new PostgreSQLColumnEncoderRegistry()
11+
val encoder = new ExecutePreparedStatementEncoder(CharsetUtil.UTF_8, registry)
12+
val sampleMessage = Array[Byte](66,0,0,0,18,49,0,49,0,0,0,0,1,-1,-1,-1,-1,0,0,69,0,0,0,10,49,0,0,0,0,0,83,0,0,0,4,67,0,0,0,7,80,49,0)
13+
14+
"encoder" should {
15+
16+
"correctly handle the case where an encoder returns null" in {
17+
18+
val message = new PreparedStatementExecuteMessage(1, "select * from users", List(Some(null)), registry)
19+
20+
val result = encoder.encode(message)
21+
22+
val bytes = new Array[Byte](result.readableBytes())
23+
result.readBytes(bytes)
24+
25+
bytes === sampleMessage
26+
}
27+
28+
}
29+
30+
}

0 commit comments

Comments
 (0)
0