8000 Added connection errors · postgres-haskell/postgres-wire@f2a59ae · GitHub
[go: up one dir, main page]

Skip to content

Commit f2a59ae

Browse files
Added connection errors
1 parent eda8180 commit f2a59ae

File tree

2 files changed

+78
-52
lines changed

2 files changed

+78
-52
lines changed

src/Database/PostgreSQL/Driver/Connection.hs

Lines changed: 71 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
{-# LANGUAGE RecursiveDo #-}
21
module Database.PostgreSQL.Driver.Connection where
32

43

@@ -62,12 +61,13 @@ type NotificationHandler = Notification -> IO ()
6261
-- All possible at errors
6362
data Error
6463
= PostgresError ErrorDesc
65-
| ImpossibleError
64+
| AuthError AuthError
65+
| ImpossibleError B.ByteString
6666
deriving (Show)
6767

6868
data AuthError
69-
= AuthPostgresError ErrorDesc
70-
| AuthNotSupported B.ByteString
69+
= AuthNotSupported B.ByteString
70+
| AuthInvalidAddress
7171
deriving (Show)
7272

7373
data DataMessage = DataMessage [V.Vector B.ByteString]
@@ -87,32 +87,39 @@ defaultUnixPathDirectory = "/var/run/postgresql"
8787
unixPathFilename :: B.ByteString
8888
unixPathFilename = ".s.PGSQL."
8989

90-
createRawConnection :: ConnectionSettings -> IO RawConnection
90+
91+
createRawConnection :: ConnectionSettings -> IO (Either Error RawConnection)
9192
createRawConnection settings
9293
| host == "" = unixConnection defaultUnixPathDirectory
9394
| "/" `B.isPrefixOf` host = unixConnection host
9495
| otherwise = tcpConnection
9596
where
96-
host = settingsHost settings
9797
unixConnection dirPath = do
98-
-- 47 - `/`
99-
let dir = B.reverse . B.dropWhile (== 47) $ B.reverse dirPath
100-
path = dir <> "/" <> unixPathFilename
101-
<> BS.pack (show $ settingsPort settings)
102-
-- TODO check for Nothing
103-
address = fromJust $ socketAddressUnixPath path
104-
s <- socket :: IO (Socket Unix Stream Unix)
105-
Socket.connect s address
106-
pure $ constructRawConnection s
98+
let mAddress = socketAddressUnixPath $ makeUnixPath dirPath
99+
case mAddress of
100+
Nothing -> throwAuthErrorInIO AuthInvalidAddress
101+
Just address -> do
102+
s <- socket :: IO (Socket Unix Stream Unix)
103+
Socket.connect s address
104+
pure . Right $ constructRawConnection s
105+
107106
tcpConnection = do
108107
addressInfo <- getAddressInfo (Just host) Nothing aiV4Mapped
109108
:: IO [AddressInfo Inet Stream TCP]
110-
let address = (socketAddress $ head addressInfo)
111-
{ inetPort = fromIntegral $ settingsPort settings }
112-
-- TODO check for empty
113-
s <- socket :: IO (Socket Inet Stream TCP)
114-
Socket.connect s address
115-
pure $ constructRawConnection s
109+
case socketAddress <$> addressInfo of
110+
[] -> throwAuthErrorInIO AuthInvalidAddress
111+
(address:_) -> do
112+
s <- socket :: IO (Socket Inet Stream TCP)
113+
Socket.connect s address
114+
{ inetPort = fromIntegral $ settingsPort settings }
115+
pure . Right $ constructRawConnection s
116+
117+
host = settingsHost settings
118+
makeUnixPath dirPath =
119+
-- 47 - `/`, removing slash on the end of the path
120+
let dir = B.reverse . B.dropWhile (== 47) $ B.reverse dirPath
121+
in dir <> "/" <> unixPathFilename
122+
<> BS.pack (show $ settingsPort settings)
116123

117124
constructRawConnection :: Socket f t p -> RawConnection
118125
constructRawConnection s = RawConnection
@@ -123,40 +130,47 @@ constructRawConnection s = RawConnection
123130
}
124131

125132
-- | Public
126-
connect :: ConnectionSettings -> IO Connection
133+
connect :: ConnectionSettings -> IO (Either Error Connection)
127134
connect settings = connectWith settings defaultFilter
128135

129-
connectWith :: ConnectionSettings -> ServerMessageFilter -> IO Connection
130-
connectWith settings msgFilter = do
131-
rawConn <- createRawConnection settings
132-
when (settingsTls settings == RequiredTls) $ handshakeTls rawConn
133-
authResult <- authorize rawConn settings
134-
-- TODO should close connection on error
135-
connParams <- either (\e -> print e >> error "invalid connection")
136-
pure authResult
137-
136+
connectWith
137+
:: ConnectionSettings
138+
-> ServerMessageFilter
139+
-> IO (Either Error Connection)
140+
connectWith settings msgFilter =
141+
createRawConnection settings >>=
142+
either throwErrorInIO (\rawConn ->
143+
authorize rawConn settings >>=
144+
either throwErrorInIO (\params ->
145+
Right <$> buildConnection rawConn params msgFilter))
146+
147+
buildConnection
148+
:: RawConnection
149+
-> ConnectionParameters
150+
-> ServerMessageFilter
151+
-> IO Connection
152+
buildConnection rawConn connParams msgFilter = do
138153
(inDataChan, outDataChan) <- newChan
139154
(inAllChan, outAllChan) <- newChan
140155
storage <- newStatementStorage
141156
modeRef <- newIORef defaultConnectionMode
142157

143158
tid <- forkIO $
144159
receiverThread msgFilter rawConn inDataChan inAllChan modeRef
145-
rec conn <- pure Connection
146-
{ connRawConnection = rawConn
147-
, connReceiverThread = tid
148-
, connOutDataChan = outDataChan
149-
, connOutAllChan = outAllChan
150-
, connStatementStorage = storage
151-
, connParameters = connParams
152-
, connMode = modeRef
153-
}
154-
pure conn
160+
pure Connection
161+
{ connRawConnection = rawConn
162+
, connReceiverThread = tid
163+
, connOutDataChan = outDataChan
164+
, connOutAllChan = outAllChan
165+
, connStatementStorage = storage
166+
, connParameters = connParams
167+
, connMode = modeRef
168+
}
155169

156170
authorize
157171
:: RawConnection
158172
-> ConnectionSettings
159-
-> IO (Either AuthError ConnectionParameters)
173+
-> IO (Either Error ConnectionParameters)
160174
authorize rawConn settings = do
161175
sendStartMessage rawConn $ consStartupMessage settings
162176
-- 4096 should be enough for the whole response from a server at
@@ -173,15 +187,17 @@ authorize rawConn settings = do
173187
let pass = "md5" <> md5Hash (md5Hash (settingsPassword settings
174188
<> settingsUser settings) <> salt)
175189
in performPasswordAuth $ PasswordMD5 pass
176-
AuthenticationGSS -> pure $ Left $ AuthNotSupported "GSS"
177-
AuthenticationSSPI -> pure $ Left $ AuthNotSupported "SSPI"
178-
AuthenticationGSSContinue _ -> pure $ Left $ AuthNotSupported "GSS"
179-
AuthErrorResponse desc -> pure $ Left $ AuthPostgresError desc
190+
AuthenticationGSS ->
191+
throwAuthErrorInIO $ AuthNotSupported "GSS"
192+
AuthenticationSSPI ->
193+
throwAuthErrorInIO $ AuthNotSupported "SSPI"
194+
AuthenticationGSSContinue _ ->
195+
throwAuthErrorInIO $ AuthNotSupported "GSS"
196+
AuthErrorResponse desc ->
197+
throwErrorInIO $ PostgresError desc
180198
-- TODO handle this case
181199
f -> error "athorize"
182200
where
183-
performPasswordAuth
184-
:: PasswordText -> IO (Either AuthError ConnectionParameters)
185201
performPasswordAuth password = do
186202
sendMessage rawConn $ PasswordMessage password
187203
r <- rReceive rawConn 4096
@@ -190,7 +206,7 @@ authorize rawConn settings = do
190206
AuthenticationOk ->
191207
pure $ Right $ parseParameters rest
192208
AuthErrorResponse desc ->
193-
pure $ Left $ AuthPostgresError desc
209+
throwErrorInIO $ PostgresError desc
194210
_ -> error "Impossible happened"
195211
-- TODO handle this case
196212
f -> error "authorize"
@@ -321,7 +337,6 @@ sendMessage rawConn msg = void $ do
321337
let smsg = toStrict . toLazyByteString $ encodeClientMessage msg
322338
rSend rawConn smsg
323339

324-
325340
-- Public
326341
data Query = Query
327342
{ qStatement :: B.ByteString
@@ -422,3 +437,9 @@ describeStatement conn stmt = do
422437
xs -> maybe (error "Impossible happened") (Left . PostgresError )
423438
$ findFirstError xs
424439

440+
throwErrorInIO :: Error -> IO (Either Error a)
441+
throwErrorInIO = pure . Left
442+
443+
throwAuthErrorInIO :: AuthError -> IO (Either Error a)
444+
throwAuthErrorInIO = pure . Left . AuthError
445+

tests/Connection.hs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ import Database.PostgreSQL.Driver.Settings
66

77
-- | Creates connection with default filter.
88
withConnection :: (Connection -> IO a) -> IO a
9-
withConnection = bracket (connect defaultSettings) close
9+
withConnection = bracket (getConnection <$> connect defaultSettings) close
1010

1111
-- | Creates connection than collects all server messages in chan.
1212
withConnectionAll :: (Connection -> IO a) -> IO a
13-
withConnectionAll = bracket (connectWith defaultSettings filterAllowedAll) close
13+
withConnectionAll = bracket
14+
(getConnection <$> connectWith defaultSettings filterAllowedAll) close
1415

1516
defaultSettings = defaultConnectionSettings
1617
{ settingsHost = "localhost"
@@ -19,3 +20,7 @@ defaultSettings = defaultConnectionSettings
1920
, settingsPassword = ""
2021
}
2122

23+
getConnection :: Either Error Connection -> Connection
24+
getConnection (Left e) = error $ "Connection error " ++ show e
25+
getConnection (Right c) = c
26+

0 commit comments

Comments
 (0)
0