1
+ {-# LANGUAGE RecursiveDo #-}
1
2
module Database.PostgreSQL.Driver.Connection where
2
3
3
4
@@ -9,6 +10,7 @@ import Control.Monad
9
10
import Data.Traversable
10
11
import Data.Foldable
11
12
import Control.Applicative
13
+ import Data.IORef
12
14
import Data.Monoid
13
15
import Control.Concurrent (forkIO , killThread , ThreadId , threadDelay )
14
16
import Data.Binary.Get ( runGetIncremental , pushChunk )
@@ -32,6 +34,14 @@ import Database.PostgreSQL.Driver.Settings
32
34
import Database.PostgreSQL.Driver.StatementStorage
33
35
import Database.PostgreSQL.Driver.Types
34
36
37
+ data ConnectionMode
38
+ -- | In this mode, all result's data is ignored
39
+ = SimpleQueryMode
40
+ -- | Usual mode
41
+ | ExtendedQueryMode
42
+
43
+ defaultConnectionMode :: ConnectionMode
44
+ defaultConnectionMode = ExtendedQueryMode
35
45
36
46
data Connection = Connection
37
47
{ connRawConnection :: RawConnection
@@ -42,6 +52,7 @@ data Connection = Connection
42
52
, connOutAllChan :: OutChan ServerMessage
43
53
, connStatementStorage :: StatementStorage
44
54
, connParameters :: ConnectionParameters
55
+ , connMode :: IORef ConnectionMode
45
56
}
46
57
47
58
type ServerMessageFilter = ServerMessage -> Bool
@@ -112,7 +123,10 @@ constructRawConnection s = RawConnection
112
123
}
113
124
114
125
connect :: ConnectionSettings -> IO Connection
115
- connect settings = do
126
+ connect settings = connectWith settings defaultFilter
127
+
128
+ connectWith :: ConnectionSettings -> ServerMessageFilter -> IO Connection
129
+ connectWith settings msgFilter = do
116
130
rawConn <- createRawConnection settings
117
131
when (settingsTls settings == RequiredTls ) $ handshakeTls rawConn
118
132
authResult <- authorize rawConn settings
@@ -123,16 +137,20 @@ connect settings = do
123
137
(inDataChan, outDataChan) <- newChan
124
138
(inAllChan, outAllChan) <- newChan
125
139
storage <- newStatementStorage
126
-
127
- tid <- forkIO $ receiverThread rawConn inDataChan inAllChan
128
- pure Connection
129
- { connRawConnection = rawConn
130
- , connReceiverThread = tid
131
- , connOutDataChan = outDataChan
132
- , connOutAllChan = outAllChan
133
- , connStatementStorage = storage
134
- , connParameters = connParams
135
- }
140
+ modeRef <- newIORef defaultConnectionMode
141
+
142
+ tid <- forkIO $
143
+ receiverThread msgFilter rawConn inDataChan inAllChan modeRef
144
+ rec conn <- pure Connection
145
+ { connRawConnection = rawConn
146
+ , connReceiverThread = tid
147
+ , connOutDataChan = outDataChan
148
+ , connOutAllChan = outAllChan
149
+ , connStatementStorage = storage
150
+ , connParameters = connParams
151
+ , connMode = modeRef
152
+ }
153
+ pure conn
136
154
137
155
authorize
138
156
:: RawConnection
@@ -145,7 +163,7 @@ authorize rawConn settings = do
145
163
r <- rReceive rawConn 4096
146
164
case pushChunk (runGetIncremental decodeAuthResponse) r of
147
165
BG. Done rest _ r -> case r of
148
- AuthenticationOk -> do
166
+ AuthenticationOk ->
149
167
-- TODO parse parameters
150
168
pure $ Right $ parseParameters rest
151
169
AuthenticationCleartextPassword ->
@@ -168,7 +186,7 @@ authorize rawConn settings = do
168
186
r <- rReceive rawConn 4096
169
187
case pushChunk (runGetIncremental decodeAuthResponse) r of
170
188
BG. Done rest _ r -> case r of
171
- AuthenticationOk -> do
189
+ AuthenticationOk ->
172
190
pure $ Right $ parseParameters rest
173
191
AuthErrorResponse desc ->
174
192
pure $ Left $ AuthPostgresError desc
@@ -209,11 +227,13 @@ sendMessage rawConn msg = void $ do
209
227
210
228
211
229
receiverThread
212
- :: RawConnection
230
+ :: ServerMessageFilter
231
+ -> RawConnection
213
232
-> InChan (Either Error DataMessage )
214
233
-> InChan ServerMessage
234
+ -> IORef ConnectionMode
215
235
-> IO ()
216
- receiverThread rawConn dataChan allChan = receiveLoop []
236
+ receiverThread msgFilter rawConn dataChan allChan modeRef = receiveLoop []
217
237
where
218
238
receiveLoop :: [V. Vector B. ByteString ] -> IO ()
219
239
receiveLoop acc = do
@@ -226,8 +246,7 @@ receiverThread rawConn dataChan allChan = receiveLoop []
226
246
go str acc = case pushChunk decoder str of
227
247
BG. Done rest _ v -> do
228
248
-- putStrLn $ "Received: " ++ show v
229
- -- TODO select filter
230
- when (defaultFilter v) $ writeChan allChan v
249
+ when (msgFilter v) $ writeChan allChan v
231
250
newAcc <- dispatch v acc
232
251
if B. null rest
233
252
then pure newAcc
@@ -262,6 +281,10 @@ receiverThread rawConn dataChan allChan = receiveLoop []
262
281
-- do nothing on other messages
263
282
dispatch _ acc = pure acc
264
283
284
+ -- | For testings purposes.
285
+ filterAllowedAll :: ServerMessageFilter
286
+ filterAllowedAll _ = True
287
+
265
288
defaultFilter :: ServerMessageFilter
266
289
defaultFilter msg = case msg of
267
290
-- PostgreSQL sends it only in startup phase
@@ -338,6 +361,23 @@ sendFlush conn = sendMessage (connRawConnection conn) Flush
338
361
readNextData :: Connection -> IO (Either Error DataMessage )
339
362
readNextData conn = readChan $ connOutDataChan conn
340
363
364
+ -- | Public
365
+ sendSimpleQuery :: Connection -> B. ByteString -> IO (Either Error () )
366
+ sendSimpleQuery conn q = withConnectionMode conn SimpleQueryMode $ \ c -> do
367
+ sendMessage (connRawConnection c) $ SimpleQuery (StatementSQL q)
368
+ readReadyForQuery c
369
+
370
+ withConnectionMode
371
+ :: Connection -> ConnectionMode -> (Connection -> IO a ) -> IO a
372
+ withConnectionMode conn mode handler = do
373
+ oldMode <- readIORef ref
374
+ atomicWriteIORef ref mode
375
+ r <- handler conn
376
+ atomicWriteIORef ref oldMode
377
+ pure r
378
+ where
379
+ ref = connMode conn
380
+
341
381
-- | Public
342
382
-- SHOULD BE called after every sended `Sync` message
343
383
-- skips all messages except `ReadyForQuery`
0 commit comments