8000 Changed binary to bytestring-based decoder · postgres-haskell/postgres-wire@93415e0 · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit 93415e0

Browse files
Changed binary to bytestring-based decoder
1 parent 213bb53 commit 93415e0

File tree

4 files changed

+138
-58
lines changed

4 files changed

+138
-58
lines changed

postgres-wire.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ library
2727
, Database.PostgreSQL.Protocol.Encoders
2828
, Database.PostgreSQL.Protocol.Decoders
2929
, Database.PostgreSQL.Protocol.Store.Encode
30+
, Database.PostgreSQL.Protocol.Store.Decode
3031
build-depends: base >= 4.7 && < 5
3132
, bytestring
3233
, socket

src/Database/PostgreSQL/Driver/Connection.hs

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import Database.PostgreSQL.Protocol.Encoders
2323
import Database.PostgreSQL.Protocol.Decoders
2424
import Database.PostgreSQL.Protocol.Types
2525
import Database.PostgreSQL.Protocol.Store.Encode (runEncode)
26+
import Database.PostgreSQL.Protocol.Store.Decode (runDecode)
2627

2728
import Database.PostgreSQL.Driver.Settings
2829
import Database.PostgreSQL.Driver.StatementStorage
@@ -125,8 +126,8 @@ authorize rawConn settings = do
125126
-- 4096 should be enough for the whole response from a server at
126127
-- the startup phase.
127128
r <- rReceive rawConn 4096
128-
case pushChunk (runGetIncremental decodeAuthResponse) r of
129-
BG.Done rest _ r -> case r of
129+
case runDecode decodeAuthResponse r of
130+
Right (rest, r) -> case r of
130131
AuthenticationOk ->
131132
pure $ parseParameters rest
132133
AuthenticationCleartextPassword ->
@@ -141,10 +142,7 @@ authorize rawConn settings = do
141142
throwAuthErrorInIO $ AuthNotSupported "GSS"
142143
AuthErrorResponse desc ->
143144
throwErrorInIO $ PostgresError desc
144-
-- this case is near impossible and ignored
145-
BG.Partial _ -> throwErrorInIO $
146-
DecodeError "partial auth response"
147-
BG.Fail _ _ reason -> throwErrorInIO . DecodeError $ BS.pack reason
145+
Left reason -> throwErrorInIO . DecodeError $ BS.pack reason
148146

149147
performPasswordAuth password = do
150148
sendMessage rawConn $ PasswordMessage password
@@ -174,16 +172,13 @@ parseParameters str = do
174172
. HM.lookup key
175173
parseBool bs | bs == "on" || bs == "yes" || bs == "1" = True
176174
| otherwise = False
177-
decoder = runGetIncremental decodeServerMessage
178175
go str dict | B.null str = Right dict
179-
| otherwise = case pushChunk decoder str of
180-
BG.Done rest _ v -> case v of
176+
| otherwise = case runDecode decodeServerMessage str of
177+
Right (rest, v) -> case v of
181178
ParameterStatus name value -> go rest $ HM.insert name value dict
182179
-- messages like `BackendData` not handled
183180
_ -> go rest dict
184-
-- this case is near impossible and ignored
185-
BG.Partial _ -> Left $ DecodeError "partial auth response"
186-
BG.Fail _ _ reason -> Left . DecodeError $ BS.pack reason
181+
Left reason -> Left . DecodeError $ BS.pack reason
187182

188183
parseServerVersion :: B.ByteString -> Either Error ServerVersion
189184
parseServerVersion bs =
@@ -222,19 +217,16 @@ receiverThread msgFilter rawConn dataChan allChan modeRef = receiveLoop []
222217
-- print r
223218
go r acc >>= receiveLoop
224219

225-
decoder = runGetIncremental decodeServerMessage
226220
go :: B.ByteString -> [V.Vector (Maybe B.ByteString)] -> IO [V.Vector (Maybe B.ByteString)]
227-
go str acc = case pushChunk decoder str of
228-
BG.Done rest _ v -> do
221+
go str acc = case runDecode decodeServerMessage str of
222+
Right (rest, v) -> do
229223
when (msgFilter v) $ writeChan allChan v
230224
mode <- readIORef modeRef
231225
newAcc <- dispatch mode dataChan v acc
232226
if B.null rest
233227
then pure newAcc
234228
else go rest newAcc
235-
-- TODO right parsing
236-
BG.Partial _ -> error "Partial"
237-
BG.Fail _ _ reason -> error reason
229+
Left reason -> error reason
238230

239231
dispatch :: ConnectionMode -> Dispatcher
240232
dispatch SimpleQueryMode = dispatchSimple

src/Database/PostgreSQL/Protocol/Decoders.hs

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@ import qualified Data.ByteString as B
1414
import Data.ByteString.Char8 (readInteger, readInt)
1515
import qualified Data.ByteString.Lazy as BL
1616
import qualified Data.HashMap.Strict as HM
17-
import Data.Binary.Get
1817

1918
import Database.PostgreSQL.Protocol.Types
19+
import Database.PostgreSQL.Protocol.Store.Decode
2020

21-
decodeAuthResponse :: Get AuthResponse
21+
decodeAuthResponse :: Decode AuthResponse
2222
decodeAuthResponse = do
2323
c <- getWord8
24-
len <- getInt32be
24+
len <- getInt32BE
2525
case chr $ fromIntegral c of
2626
'E' -> AuthErrorResponse <$>
2727
(getByteString (fromIntegral $ len - 4) >>= decodeErrorDesc)
2828
'R' -> do
29-
rType <- getInt32be
29+
rType <- getInt32BE
3030
case rType of
3131
0 -> pure AuthenticationOk
3232
3 -> pure AuthenticationCleartextPassword
@@ -38,19 +38,19 @@ decodeAuthResponse = do
3838
_ -> fail "Unknown authentication response"
3939
_ -> fail "Invalid auth response"
4040

41-
decodeServerMessage :: Get ServerMessage
41+
decodeServerMessage :: Decode ServerMessage
4242
decodeServerMessage = do
4343
c <- getWord8
44-
len <- getInt32be
44+
len <- getInt32BE
4545
case chr $ fromIntegral c of
46-
'K' -> BackendKeyData <$> (ServerProcessId <$> getInt32be)
47-
<*> (ServerSecretKey <$> getInt32be)
46+
'K' -> BackendKeyData <$> (ServerProcessId <$> getInt32BE)
47+
<*> (ServerSecretKey <$> getInt32BE)
4848
'2' -> pure BindComplete
4949
'3' -> pure CloseComplete
5050
'C' -> CommandComplete <$> (getByteString (fromIntegral $ len - 4)
5151
>>= decodeCommandResult)
5252
'D' -> do
53-
columnCount <- fromIntegral <$> getInt16be
53+
columnCount <- fromIntegral <$> getInt16BE
5454
DataRow <$> V.replicateM columnCount decodeValue
5555
'I' -> pure EmptyQueryResponse
5656
'E' -> ErrorResponse <$>
@@ -60,55 +60,57 @@ decodeServerMessage = do
6060
(getByteString (fromIntegral $ len - 4) >>= decodeNoticeDesc)
6161
'A' -> NotificationResponse <$> decodeNotification
6262
't' -> do
63-
paramCount <- fromIntegral <$> getInt16be
63+
paramCount <- fromIntegral <$> getInt16BE
6464
ParameterDescription <$> V.replicateM paramCount
65-
(Oid <$> getInt32be)
66-
'S' -> ParameterStatus <$> decodePgString <*> decodePgString
65+
(Oid <$> getInt32BE)
66+
'S' -> ParameterStatus <$> getByteStringNull <*> getByteStringNull
6767
'1' -> pure ParseComplete
6868
's' -> pure PortalSuspended
6969
'Z' -> ReadForQuery <$> decodeTransactionStatus
7070
'T' -> do
71-
rowsCount <- fromIntegral <$> getInt16be
71+
rowsCount <- fromIntegral <$> getInt16BE
7272
RowDescription <$> V.replicateM rowsCount decodeFieldDescription
7373

7474
-- | Decodes a single data value. Length `-1` indicates a NULL column value.
7575
-- No value bytes follow in the NULL case.
76-
decodeValue :: Get (Maybe B.ByteString)
77-
decodeValue = fromIntegral <$> getInt32be >>= \n ->
78-
if n == -1 then pure Nothing else Just <$> getByteString n
76+
decodeValue :: Decode (Maybe B.ByteString)
77+
decodeValue = fromIntegral <$> getInt32BE >>= \n ->
78+
if n == -1
79+
then pure Nothing
80+
else Just <$> getByteString n
7981

80-
decodeTransactionStatus :: Get TransactionStatus
82+
decodeTransactionStatus :: Decode TransactionStatus
8183
decodeTransactionStatus = getWord8 >>= \t ->
8284
case chr $ fromIntegral t of
8385
'I' -> pure TransactionIdle
8486
'T' -> pure TransactionInBlock
8587
'E' -> pure TransactionFailed
8688
_ -> fail "unknown transaction status"
8789

88-
decodeFieldDescription :: Get FieldDescription
90+
decodeFieldDescription :: Decode FieldDescription
8991
decodeFieldDescription = FieldDescription
90-
<$> decodePgString
91-
<*> (Oid <$> getInt32be)
92-
<*> getInt16be
93-
<*> (Oid <$> getInt32be)
94-
<*> getInt16be
95-
<*> getInt32be
92+
<$> getByteStringNull
93+
<*> (Oid <$> getInt32BE)
94+
<*> getInt16BE
95+
<*> (Oid <$> getInt32BE)
96+
<*> getInt16BE
97+
<*> getInt32BE
9698
<*> decodeFormat
9799

98-
decodeNotification :: Get Notification
100+
decodeNotification :: Decode Notification
99101
decodeNotification = Notification
100-
<$> (ServerProcessId <$> getInt32be)
101-
<*> (ChannelName <$> decodePgString)
102-
<*> decodePgString
102+
<$> (ServerProcessId <$> getInt32BE)
103+
<*> (ChannelName <$> getByteStringNull)
104+
<*> getByteStringNull
103105

104-
decodeFormat :: Get Format
105-
decodeFormat = getInt16be >>= \f ->
106+
decodeFormat :: Decode Format
107+ B41A
decodeFormat = getInt16BE >>= \f ->
106108
case f of
107109
0 -> pure Text
108110
1 -> pure Binary
109111
_ -> fail "Unknown field format"
110112

111-
decodeCommandResult :: B.ByteString -> Get CommandResult
113+
decodeCommandResult :: B.ByteString -> Decode CommandResult
112114
decodeCommandResult s =
113115
let (command, rest) = B.break (== space) s
114116
in case command of
@@ -151,7 +153,7 @@ decodeNoticeSeverity "INFO" = SeverityInfo
151153
decodeNoticeSeverity "LOG" = SeverityLog
152154
decodeNoticeSeverity _ = UnknownNoticeSeverity
153155

154-
decodeErrorDesc :: B.ByteString -> Get ErrorDesc
156+
decodeErrorDesc :: B.ByteString -> Decode ErrorDesc
155157
decodeErrorDesc s = do
156158
let hm = decodeErrorNoticeFields s
157159
errorSeverityOld <- lookupKey 'S' hm
@@ -184,7 +186,7 @@ decodeErrorDesc s = do
184186
"is not presented in ErrorResponse message")
185187
pure . HM.lookup c
186188

187-
decodeNoticeDesc :: B.ByteString -> Get NoticeDesc
189+
decodeNoticeDesc :: B.ByteString -> Decode NoticeDesc
188190
decodeNoticeDesc s = do
189191
let hm = decodeErrorNoticeFields s
190192
noticeSeverityOld <- lookupKey 'S' hm
@@ -217,10 +219,3 @@ decodeNoticeDesc s = do
217219
"is not presented in NoticeResponse message")
218220
pure . HM.lookup c
219221

220-
------
221-
-- Utils
222-
------
223-
224-
decodePgString :: Get B.ByteString
225-
decodePgString = BL.toStrict <$> getLazyByteStringNul
226-
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
module Database.PostgreSQL.Protocol.Store.Decode where
2+
3+
import Prelude hiding (takeWhile)
4+
import qualified Data.ByteString as B
5+
import Data.Word
6+
import Data.Int
7+
import Data.Tuple
8+
9+
import Control.Monad
10+
import Control.Applicative
11+
12+
-- Change to Ptr-based parser later
13+
data Decode a = Decode
14+
{ runDecode :: B.ByteString -> Either String (B.ByteString, a)}
15+
16+
instance Functor Decode where
17+
fmap f p = Decode $ fmap (fmap f) . runDecode p
18+
19+
instance Applicative Decode where
20+
pure x = Decode $ \bs -> Right (bs, x)
21+
22+
p1 <*> p2 = Decode $ \bs -> do
23+
(bs2, f) <- runDecode p1 bs
24+
(bs3, x) <- runDecode p2 bs2
25+
pure (bs3, f x)
26+
27+
instance Monad Decode where
28+
return = pure
29+
30+
p >>= f = Decode $ \bs -> do
31+
(bs2, x) <- runDecode p bs
32+
runDecode (f x) bs2
33+
34+
fail = Decode . const . Left
35+
36+
checkLen :: B.ByteString -> Int -> Either String ()
37+
checkLen bs len | len > B.length bs = Left "too many bytes to read"
38+
| otherwise = Right ()
39+
40+
41+
takeWhile :: (Word8 -> Bool) -> Decode B.ByteString
42+
takeWhile f = Decode $ \bs -> Right . swap $ B.span f bs
43+
44+
getByte :: Decode Word8
45+
getByte = Decode $ \bs -> do
46+
checkLen bs 1
47+
Right (B.drop 1 bs, B.index bs 0)
48+
49+
getTwoBytes :: Decode (Word8, Word8)
50+
getTwoBytes = Decode $ \bs -> do
51+
checkLen bs 2
52+
Right (B.drop 2 bs, (B.index bs 0, B.index bs 1))
53+
54+
getFourBytes :: Decode (Word8, Word8, Word8, Word8)
55+
getFourBytes = Decode $ \bs -> do
56+
checkLen bs 4
57+
Right (B.drop 4 bs, (B.index bs 0, B.index bs 1, B.index bs 2, B.index bs 3))
58+
59+
-----------
60+
-- Public
61+
62+
getByteString :: Int -> Decode B.ByteString
63+
getByteString len = Decode $ \bs -> do
64+
checkLen bs len
65+
Right . swap $ B.splitAt len bs
66+
67+
getByteStringNull :: Decode B.ByteString
68+
getByteStringNull = takeWhile (/= 0) <* getWord8
69+
70+
getWord8 :: Decode Word8
71+
getWord8 = getByte
72+
73+
getWord16BE :: Decode Word16
74+
getWord16BE = do
75+
(w1, w2) <- getTwoBytes
76+
pure $ fromIntegral w1 * 256 +
77+
fromIntegral w2
78+
79+
getWord32BE :: Decode Word32
80+
getWord32BE = do
81+
(w1, w2, w3, w4) <- getFourBytes
82+
pure $ fromIntegral w1 * 256 *256 *256 +
83+
fromIntegral w2 * 256 *256 +
84+
fromIntegral w3 * 256 +
85+
fromIntegral w4
86+
87+
getInt16BE :: Decode Int16
88+
getInt16BE = fromIntegral <$> getWord16BE
89+
90+
getInt32BE :: Decode Int32
91+
getInt32BE = fromIntegral <$> getWord32BE
92+

0 commit comments

Comments
 (0)
0