8000 Connection mode · postgres-haskell/postgres-wire@0840e39 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0840e39

Browse files
Connection mode
1 parent 422378e commit 0840e39

File tree

3 files changed

+65
-27
lines changed

3 files changed

+65
-27
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,4 @@ install:
4040

4141
script:
4242
# Build the package, its tests, and its docs and run the tests
43-
- stack --no-terminal test :postgres-wire-test-connection
43+
- stack --no-terminal test :postgres-wire-test

src/Database/PostgreSQL/Driver/Connection.hs

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE RecursiveDo #-}
12
module Database.PostgreSQL.Driver.Connection where
23

34

@@ -9,6 +10,7 @@ import Control.Monad
910
import Data.Traversable
1011
import Data.Foldable
1112
import Control.Applicative
13+
import Data.IORef
1214
import Data.Monoid
1315
import Control.Concurrent (forkIO, killThread, ThreadId, threadDelay)
1416
import Data.Binary.Get ( runGetIncremental, pushChunk)
@@ -32,6 +34,14 @@ import Database.PostgreSQL.Driver.Settings
3234
import Database.PostgreSQL.Driver.StatementStorage
3335
import Database.PostgreSQL.Driver.Types
3436

37+
data ConnectionMode
38+
-- | In this mode, all result's data is ignored
39+
= SimpleQueryMode
40+
-- | Usual mode
41+
| ExtendedQueryMode
42+
43+
defaultConnectionMode :: ConnectionMode
44+
defaultConnectionMode = ExtendedQueryMode
3545

3646
data Connection = Connection
3747
{ connRawConnection :: RawConnection
@@ -42,6 +52,7 @@ data Connection = Connection
4252
, connOutAllChan :: OutChan ServerMessage
4353
, connStatementStorage :: StatementStorage
4454
, connParameters :: ConnectionParameters
55+
, connMode :: IORef ConnectionMode
4556
}
4657

4758
type ServerMessageFilter = ServerMessage -> Bool
@@ -112,7 +123,10 @@ constructRawConnection s = RawConnection
112123
}
113124

114125
connect :: ConnectionSettings -> IO Connection
115-
connect settings = do
126+
connect settings = connectWith settings defaultFilter
127+
128+
connectWith :: ConnectionSettings -> ServerMessageFilter -> IO Connection
129+
connectWith settings msgFilter = do
116130
rawConn <- createRawConnection settings
117131
when (settingsTls settings == RequiredTls) $ handshakeTls rawConn
118132
authResult <- authorize rawConn settings
@@ -123,16 +137,20 @@ connect settings = do
123137
(inDataChan, outDataChan) <- newChan
124138
(inAllChan, outAllChan) <- newChan
125139
storage <- newStatementStorage
126-
127-
tid <- forkIO $ receiverThread rawConn inDataChan inAllChan
128-
pure Connection
129-
{ connRawConnection = rawConn
130-
, connReceiverThread = tid
131-
, connOutDataChan = outDataChan
132-
, connOutAllChan = outAllChan
133-
, connStatementStorage = storage
134-
, connParameters = connParams
135-
}
140+
modeRef <- newIORef defaultConnectionMode
141+
142+
tid <- forkIO $
143+
receiverThread msgFilter rawConn inDataChan inAllChan modeRef
144+
rec conn <- pure Connection
145+
{ connRawConnection = rawConn
146+
, connReceiverThread = tid
147+
, connOutDataChan = outDataChan
148+
, connOutAllChan = outAllChan
149+
, connStatementStorage = storage
150+
, connParameters = connParams
151+
, connMode = modeRef
152+
}
153+
pure conn
136154

137155
authorize
138156
:: RawConnection
@@ -145,7 +163,7 @@ authorize rawConn settings = do
145163
r <- rReceive rawConn 4096
146164
case pushChunk (runGetIncremental decodeAuthResponse) r of
147165
BG.Done rest _ r -> case r of
148-
AuthenticationOk -> do
166+
AuthenticationOk ->
149167
-- TODO parse parameters
150168
pure $ Right $ parseParameters rest
151169
AuthenticationCleartextPassword ->
@@ -168,7 +186,7 @@ authorize rawConn settings = do
168186
r <- rReceive rawConn 4096
169187
case pushChunk (runGetIncremental decodeAuthResponse) r of
170188
BG.Done rest _ r -> case r of
171-
AuthenticationOk -> do
189+
AuthenticationOk ->
172190
pure $ Right $ parseParameters rest
173191
AuthErrorResponse desc ->
174192
pure $ Left $ AuthPostgresError desc
@@ -209,11 +227,13 @@ sendMessage rawConn msg = void $ do
209227

210228

211229
receiverThread
212-
:: RawConnection
230+
:: ServerMessageFilter
231+
-> RawConnection
213232
-> InChan (Either Error DataMessage)
214233
-> InChan ServerMessage
234+
-> IORef ConnectionMode
215235
-> IO ()
216-
receiverThread rawConn dataChan allChan = receiveLoop []
236+
receiverThread msgFilter rawConn dataChan allChan modeRef = receiveLoop []
217237
where
218238
receiveLoop :: [V.Vector B.ByteString] -> IO()
219239
receiveLoop acc = do
@@ -226,8 +246,7 @@ receiverThread rawConn dataChan allChan = receiveLoop []
226246
go str acc = case pushChunk decoder str of
227247
BG.Done rest _ v -> do
228248
-- putStrLn $ "Received: " ++ show v
229-
-- TODO select filter
230-
when (defaultFilter v) $ writeChan allChan v
249+
when (msgFilter v) $ writeChan allChan v
231250
newAcc <- dispatch v acc
232251
if B.null rest
233252
then pure newAcc
@@ -262,6 +281,10 @@ receiverThread rawConn dataChan allChan = receiveLoop []
262281
-- do nothing on other messages
263282
dispatch _ acc = pure acc
264283

284+
-- | For testings purposes.
285+
filterAllowedAll :: ServerMessageFilter
286+
filterAllowedAll _ = True
287+
265288
defaultFilter :: ServerMessageFilter
266289
defaultFilter msg = case msg of
267290
-- PostgreSQL sends it only in startup phase
@@ -338,6 +361,23 @@ sendFlush conn = sendMessage (connRawConnection conn) Flush
338361
readNextData :: Connection -> IO (Either Error DataMessage)
339362
readNextData conn = readChan $ connOutDataChan conn
340363

364+
-- | Public
365+
sendSimpleQuery :: Connection -> B.ByteString -> IO (Either Error ())
366+
sendSimpleQuery conn q = withConnectionMode conn SimpleQueryMode $ \c -> do
367+
sendMessage (connRawConnection c) $ SimpleQuery (StatementSQL q)
368+
readReadyForQuery c
369+
370+
withConnectionMode
371+
:: Connection -> ConnectionMode -> (Connection -> IO a) -> IO a
372+
withConnectionMode conn mode handler = do
373+
oldMode <- readIORef ref
374+
atomicWriteIORef ref mode
375+
r <- handler conn
376+
atomicWriteIORef ref oldMode
377+
pure r
378+
where
379+
ref = connMode conn
380+
341381
-- | Public
342382
-- SHOULD BE called after every sended `Sync` message
343383
-- skips all messages except `ReadyForQuery`

tests/test.hs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import Control.Exception
2+
13
import Test.Tasty
24
import Test.Tasty.HUnit
35

@@ -24,15 +26,15 @@ query2 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["a", "3"] Text Text
2426
query3 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["3", "3"] Text Text
2527
query4 = Query "SELECT $1 + $2" [Oid 23, Oid 23] ["4", "3"] Text Text
2628

29+
withConnection :: (Connection -> IO a) -> IO a
30+
withConnection = bracket (connect defaultSettings) close
2731

2832
test :: IO ()
29-
test = do
30-
c <- connect defaultConnectionSettings
33+
test = withConnection $ \c -> do
3134
sendBatch c queries
3235
sendSync c
3336
readResults c $ length queries
3437
readReadyForQuery c >>= print
35-
close c
3638
where
3739
queries = [query1, query2, query3, query4 ]
3840
readResults c 0 = pure ()
@@ -44,18 +46,14 @@ test = do
4446
Right _ -> readResults c $ n - 1
4547

4648

47-
4849
testDescribe1 :: IO ()
49-
testDescribe1 = do
50-
c <- connect defaultConnectionSettings
50+
testDescribe1 = withConnection $ \c -> do
5151
r <- describeStatement c $ StatementSQL "start transaction"
5252
print r
53-
close c
5453

5554
testDescribe2 :: IO ()
56-
testDescribe2 = do
55+
testDescribe2 = withConnection $ \c -> do
5756
c <- connect defaultConnectionSettings
5857
r <- describeStatement c $ StatementSQL "select count(*) from a where v > $1"
5958
print r
60-
close c
6159

0 commit comments

Comments
 (0)
0