8000 More accurate decoders module · postgres-haskell/postgres-wire@9f50853 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9f50853

Browse files
More accurate decoders module
1 parent d758bad commit 9f50853

File tree

1 file changed

+85
-72
lines changed

1 file changed

+85
-72
lines changed

src/Database/PostgreSQL/Protocol/Decoders.hs

Lines changed: 85 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
{-# language RecordWildCards #-}
2-
module Database.PostgreSQL.Protocol.Decoders where
3-
4-
import Data.Word
5-
import Data.Int
6-
import Data.Monoid
7-
import Data.Maybe (fromMaybe)
8-
import Data.Foldable
9-
import Data.Char (chr)
10-
import Control.Applicative
11-
import Control.Monad
12-
import Text.Read
2+
3+
module Database.PostgreSQL.Protocol.Decoders
4+
( decodeAuthResponse
5+
, decodeServerMessage
6+
-- * Helpers
7+
, parseServerVersion
8+
, parseIntegerDatetimes
9+
) where
10+
11+
import Control.Applicative
12+
import Control.Monad
13+
import Data.Monoid ((<>))
14+
import Data.Maybe (fromMaybe)
15+
import Data.Char (chr)
16+
import Text.Read (readMaybe)
1317
import qualified Data.Vector as V
1418
import qualified Data.ByteString as B
15-
import Data.ByteString.Char8 as BS(readInteger, readInt, unpack)
16-
import qualified Data.ByteString.Lazy as BL
19+
import Data.ByteString.Char8 as BS(readInteger, readInt, unpack, pack)
1720
import qualified Data.HashMap.Strict as HM
1821

1922
import Database.PostgreSQL.Protocol.Types
@@ -25,7 +28,8 @@ decodeAuthResponse = do
2528
len <- getInt32BE
2629
case chr $ fromIntegral c of
2730
'E' -> AuthErrorResponse <$>
28-
(getByteString (fromIntegral $ len - 4) >>= decodeErrorDesc)
31+
(getByteString (fromIntegral $ len - 4) >>=
32+
eitherToDecode .parseErrorDesc)
2933
'R' -> do
3034
rType <- getInt32BE
3135
case rType of
@@ -49,16 +53,18 @@ decodeServerMessage = do
4953
'2' -> pure BindComplete
5054
'3' -> pure CloseComplete
5155
'C' -> CommandComplete <$> (getByteString (fromIntegral $ len - 4)
52-
>>= decodeCommandResult)
56+
>>= eitherToDecode . parseCommandResult)
5357
'D' -> do
5458
columnCount <- fromIntegral <$> getInt16BE
5559
DataRow <$> V.replicateM columnCount decodeValue
5660
'I' -> pure EmptyQueryResponse
5761
'E' -> ErrorResponse <$>
58-
(getByteString (fromIntegral $ len - 4) >>= decodeErrorDesc)
62+
(getByteString (fromIntegral $ len - 4) >>=
63+
eitherToDecode . parseErrorDesc)
5964
'n' -> pure NoData
6065
'N' -> NoticeResponse <$>
61-
(getByteString (fromIntegral $ len - 4) >>= decodeNoticeDesc)
66+
(getByteString (fromIntegral $ len - 4) >>=
67+
eitherToDecode . parseNoticeDesc)
6268
'A' -> NotificationResponse <$> decodeNotification
6369
't' -> do
6470
paramCount <- fromIntegral <$> getInt16BE
@@ -75,10 +81,10 @@ decodeServerMessage = do
7581
-- | Decodes a single data value. Length `-1` indicates a NULL column value.
7682
-- No value bytes follow in the NULL case.
7783
decodeValue :: Decode (Maybe B.ByteString)
78-
decodeValue = fromIntegral <$> getInt32BE >>= \n ->
84+
decodeValue = getInt32BE >>= \n ->
7985
if n == -1
8086
then pure Nothing
81-
else Just <$> getByteString n
87+
else Just <$> getByteString (fromIntegral n)
8288

8389
decodeTransactionStatus :: Decode TransactionStatus
8490
decodeTransactionStatus = getWord8 >>= \t ->
@@ -111,29 +117,7 @@ decodeFormat = getInt16BE >>= \f ->
111117
1 -> pure Binary
112118
_ -> fail "Unknown field format"
113119

114-
decodeCommandResult :: B.ByteString -> Decode CommandResult
115-
decodeCommandResult s =
116-
let (command, rest) = B.break (== space) s
117-
in case command of
118-
-- format: `INSERT oid rows`
119-
"INSERT" ->
120-
maybe (fail "Invalid format in INSERT command result") pure $ do
121-
(oid, r) <- readInteger $ B.dropWhile (== space) rest
122-
(rows, _) <- readInteger $ B.dropWhile (== space) r
123-
pure $ InsertCompleted (Oid $ fromInteger oid)
124-
(RowsCount $ fromInteger rows)
125-
"DELETE" -> DeleteCompleted <$> readRows rest
126-
"UPDATE" -> UpdateCompleted <$> readRows rest
127-
"SELECT" -> SelectCompleted <$> readRows rest
128-
"MOVE" -> MoveCompleted <$> readRows rest
129-
"FETCH" -> FetchCompleted <$> readRows rest
130-
"COPY" -> CopyCompleted <$> readRows rest
131-
_ -> pure CommandOk
132-
where
133-
space = 32
134-
readRows = maybe (fail "Invalid rows format in command result")
135-
(pure . RowsCount . fromInteger . fst)
136-
. readInteger . B.dropWhile (== space)
120+
-- Parser that just work with B.ByteString, not Decode type
137121

138122
-- Helper to parse, not used by decoder itself
139123
parseServerVersion :: B.ByteString -> Maybe ServerVersion
@@ -154,28 +138,54 @@ parseIntegerDatetimes :: B.ByteString -> Bool
154138
parseIntegerDatetimes bs | bs == "on" || bs == "yes" || bs == "1" = True
155139
| otherwise = False
156140

157-
decodeErrorNoticeFields :: B.ByteString -> HM.HashMap Char B.ByteString
158-
decodeErrorNoticeFields = HM.fromList
141+
parseCommandResult :: B.ByteString -> Either B.ByteString CommandResult
142+
parseCommandResult s =
143+
let (command, rest) = B.break (== space) s
144+
in case command of
145+
-- format: `INSERT oid rows`
146+
"INSERT" ->
147+
maybe (Left "Invalid format in INSERT command result") Right $ do
148+
(oid, r) <- readInteger $ B.dropWhile (== space) rest
149+
(rows, _) <- readInteger $ B.dropWhile (== space) r
150+
Just $ InsertCompleted (Oid $ fromInteger oid)
151+
(RowsCount $ fromInteger rows)
152+
"DELETE" -> DeleteCompleted <$> readRows rest
153+
"UPDATE" -> UpdateCompleted <$> readRows rest
154+
"SELECT" -> SelectCompleted <$> readRows rest
155+
"MOVE" -> MoveCompleted <$> readRows rest
156+
"FETCH" -> FetchCompleted <$> readRows rest
157+
"COPY" -> CopyCompleted <$> readRows rest
158+
_ -> Right CommandOk
159+
where
160+
space = 32
161+
readRows = maybe (Left "Invalid rows format in command result")
162+
(pure . RowsCount . fromInteger . fst)
163+
. readInteger . B.dropWhile (== space)
164+
165+
parseErrorNoticeFields :: B.ByteString -> HM.HashMap Char B.ByteString
166+
parseErrorNoticeFields = HM.fromList
159167
. fmap (\s -> (chr . fromIntegral $ B.head s, B.tail s))
160168
. filter (not . B.null) . B.split 0
161169

162-
decodeErrorSeverity :: B.ByteString -> ErrorSeverity
163-
decodeErrorSeverity "ERROR" = SeverityError
164-
decodeErrorSeverity "FATAL" = SeverityFatal
165-
decodeErrorSeverity "PANIC" = SeverityPanic
166-
decodeErrorSeverity _ = UnknownErrorSeverity
167-
168-
decodeNoticeSeverity :: B.ByteString -> NoticeSeverity
169-
decodeNoti 10000 ceSeverity "WARNING" = SeverityWarning
170-
decodeNoticeSeverity "NOTICE" = SeverityNotice
171-
decodeNoticeSeverity "DEBUG" = SeverityDebug
172-
decodeNoticeSeverity "INFO" = SeverityInfo
173-
decodeNoticeSeverity "LOG" = SeverityLog
174-
decodeNoticeSeverity _ = UnknownNoticeSeverity
175-
176-
decodeErrorDesc :: B.ByteString -> Decode ErrorDesc
177-
decodeErrorDesc s = do
178-
let hm = decodeErrorNoticeFields s
170+
parseErrorSeverity :: B.ByteString -> ErrorSeverity
171+
parseErrorSeverity bs = case bs of
172+
"ERROR" -> SeverityError
173+
"FATAL" -> SeverityFatal
174+
"PANIC" -> SeverityPanic
175+
_ -> UnknownErrorSeverity
176+
177+
parseNoticeSeverity :: B.ByteString -> NoticeSeverity
178+
parseNoticeSeverity bs = case bs of
179+
"WARNING" -> SeverityWarning
180+
"NOTICE" -> SeverityNotice
181+
"DEBUG" -> SeverityDebug
182+
"INFO" -> SeverityInfo
183+
"LOG" -> SeverityLog
184+
_ -> UnknownNoticeSeverity
185+
186+
parseErrorDesc :: B.ByteString -> Either B.ByteString ErrorDesc
187+
parseErrorDesc s = do
188+
let hm = parseErrorNoticeFields s
179189
errorSeverityOld <- lookupKey 'S' hm
180190
errorCode <- lookupKey 'C' hm
181191
errorMessage <- lookupKey 'M' hm
@@ -184,7 +194,7 @@ decodeErrorDesc s = do
184194
-- never localized. This is present only in messages generated by
185195
-- PostgreSQL versions 9.6 and later.
186196
errorSeverityNew = HM.lookup 'V' hm
187-
errorSeverity = decodeErrorSeverity $
197+
errorSeverity = parseErrorSeverity $
188198
fromMaybe errorSeverityOld errorSeverityNew
189199
errorDetail = HM.lookup 'D' hm
190200
errorHint = HM.lookup 'H' hm
@@ -200,15 +210,15 @@ decodeErrorDesc s = do
200210
errorSourceFilename = HM.lookup 'F' hm
201211
errorSourceLine = HM.lookup 'L' hm >>= fmap fst . readInt
202212
errorSourceRoutine = HM.lookup 'R' hm
203-
pure ErrorDesc{..}
213+
Right ErrorDesc{..}
204214
where
205-
lookupKey c = maybe (fail $ "Neccessary key " ++ show c ++
215+
lookupKey c = maybe (Left $ "Neccessary key " <> BS.pack (show c) <>
206216
"is not presented in ErrorResponse message")
207-
pure . HM.lookup c
217+
Right . HM.lookup c
208218

209-
decodeNoticeDesc :: B.ByteString -> Decode NoticeDesc
210-
decodeNoticeDesc s = do
211-
let hm = decodeErrorNoticeFields s
219+
parseNoticeDesc :: B.ByteString -> Either B.ByteString NoticeDesc
220+
parseNoticeDesc s = do
221+
let hm = parseErrorNoticeFields s
212222
noticeSeverityOld <- lookupKey 'S' hm
213223
noticeCode <- lookupKey 'C' hm
214224
noticeMessage <- lookupKey 'M' hm
@@ -217,7 +227,7 @@ decodeNoticeDesc s = do
217227
-- never localized. This is present only in messages generated by
218228
-- PostgreSQL versions 9.6 and later.
219229
noticeSeverityNew = HM.lookup 'V' hm
220-
noticeSeverity = decodeNoticeSeverity $
230+
noticeSeverity = parseNoticeSeverity $
221231
fromMaybe noticeSeverityOld noticeSeverityNew
222232
noticeDetail = HM.lookup 'D' hm
223233
noticeHint = HM.lookup 'H' hm
@@ -233,9 +243,12 @@ decodeNoticeDesc s = do
233243
noticeSourceFilename = HM.lookup 'F' hm
234244
noticeSourceLine = HM.lookup 'L' hm >>= fmap fst . readInt
235245
noticeSourceRoutine = HM.lookup 'R' hm
236-
pure NoticeDesc{..}
246+
Right NoticeDesc{..}
237247
where
238-
lookupKey c = maybe (fail $ "Neccessary key " ++ show c ++
248+
lookupKey c = maybe (Left $ "Neccessary key " <> BS.pack (show c) <>
239249
"is not presented in NoticeResponse message")
240-
pure . HM.lookup c
250+
Right . HM.lookup c
251+
252+
eitherToDecode :: Either B.ByteString a -> Decode a
253+
eitherToDecode = either (fail . BS.unpack) pure
241254

0 commit comments

Comments
 (0)
0