@@ -8,6 +8,7 @@ import Control.Monad
8
8
import Data.Traversable
9
9
import Data.Foldable
10
10
import Control.Applicative
11
+ import Control.Exception
11
12
import Data.IORef
12
13
import Data.Monoid
13
14
import Control.Concurrent (forkIO , killThread , ThreadId , threadDelay )
@@ -88,11 +89,15 @@ connectWith
88
89
-> ServerMessageFilter
89
90
-> IO (Either Error Connection )
90
91
connectWith settings msgFilter =
91
- createRawConnection settings >>=
92
- either throwErrorInIO (\ rawConn ->
93
- authorize rawConn settings >>=
94
- either throwErrorInIO (\ params ->
95
- Right <$> buildConnection rawConn params msgFilter))
92
+ bracketOnError
93
+ (createRawConnection settings)
94
+ (either throwErrorInIO rClose)
95
+ (either throwErrorInIO performAuth)
96
+ where
97
+ performAuth rawConn = authorize rawConn settings >>= either
98
+ -- We should close connection on an authorization failure
99
+ (\ e -> rClose rawConn >> throwErrorInIO e)
100
+ (\ params -> Right <$> buildConnection rawConn params msgFilter)
96
101
97
102
-- | Authorizes on the server and reads connection parameters.
98
103
authorize
@@ -108,8 +113,8 @@ authorize rawConn settings = do
108
113
readAuthResponse = do
109
114
-- 4096 should be enough for the whole response from a server at
110
115
-- the startup phase.
111
- r <- rReceive rawConn 4096
112
- case runDecode decodeAuthResponse r of
116
+ resp <- rReceive rawConn 4096
117
+ case runDecode decodeAuthResponse resp of
113
118
Right (rest, r) -> case r of
114
119
AuthenticationOk ->
115
120
pure $ parseParameters rest
0 commit comments