@@ -21,8 +21,8 @@ import Data.Binary.Get ( runGetIncremental, pushChunk)
21
21
import qualified Data.Binary.Get as BG (Decoder (.. ))
22
22
import Data.Maybe (fromJust )
23
23
import qualified Data.Vector as V
24
- import System.Socket hiding ( connect , close , Error ( .. ) )
25
- import qualified System.Socket as Socket (connect , close )
24
+ import System.Socket ( Socket , socket )
25
+ import qualified System.Socket as Socket (connect , close , send , receive )
26
26
import System.Socket.Family.Inet6
27
27
import System.Socket.Type.Stream
28
28
import System.Socket.Protocol.TCP
@@ -41,7 +41,7 @@ import Database.PostgreSQL.Types
41
41
type UnixSocket = Socket Unix Stream Unix
42
42
-- data Connection = Connection (Socket Inet6 Stream TCP)
43
43
data Connection = Connection
44
- { connSocket :: UnixSocket
44
+ { connRawConnection :: RawConnection
45
45
, connReceiverThread :: ThreadId
46
46
-- channel only for Data messages
47
47
, connOutDataChan :: OutChan (Either Error DataMessage )
@@ -64,6 +64,13 @@ data Error
64
64
data DataMessage = DataMessage [V. Vector B. ByteString ]
65
65
deriving (Show )
66
66
67
+ -- | Abstraction over raw socket connection or tls connection
68
+ data RawConnection = RawConnection
69
+ { rFlush :: IO ()
70
+ , rClose :: IO ()
71
+ , rSend :: B. ByteString -> IO ()
72
+ , rReceive :: Int -> IO B. ByteString
73
+ }
67
74
68
75
address :: SocketAddress Unix
69
76
address = fromJust $ socketAddressUnixPath " /var/run/postgresql/.s.PGSQL.5432"
@@ -73,15 +80,21 @@ connect settings = do
73
80
s <- socket
74
81
Socket. connect s address
75
82
sendStartMessage s $ consStartupMessage settings
76
- r <- receive s 4096 mempty
83
+ r <- Socket. receive s 4096 mempty
77
84
readAuthMessage r
78
85
79
86
(inDataChan, outDataChan) <- newChan
80
87
(inAllChan, outAllChan) <- newChan
81
- tid <- forkIO $ receiverThread s inDataChan inAllChan
88
+ let rawConnection = RawConnection
89
+ { rFlush = pure ()
90
+ , rClose = Socket. close s
91
+ , rSend = \ msg -> void $ Socket. send s msg mempty
92
+ , rReceive = \ n -> Socket. receive s n mempty
93
+ }
94
+ tid <- forkIO $ receiverThread rawConnection inDataChan inAllChan
82
95
storage <- newStatementStorage
83
96
pure Connection
84
- { connSocket = s
97
+ { connRawConnection = rawConnection
85
98
, connReceiverThread = tid
86
99
, connOutDataChan = outDataChan
87
100
, connOutAllChan = outAllChan
@@ -96,7 +109,7 @@ connect settings = do
96
109
close :: Connection -> IO ()
97
110
close conn = do
98
111
killThread $ connReceiverThread conn
99
- Socket. close $ connSocket conn
112
+ rClose $ connRawConnection conn
100
113
101
114
consStartupMessage :: ConnectionSettings -> StartMessage
102
115
consStartupMessage stg = StartupMessage
@@ -105,12 +118,12 @@ consStartupMessage stg = StartupMessage
105
118
sendStartMessage :: UnixSocket -> StartMessage -> IO ()
106
119
sendStartMessage sock msg = void $ do
107
120
let smsg = toStrict . toLazyByteString $ encodeStartMessage msg
108
- send sock smsg mempty
121
+ Socket. send sock smsg mempty
109
122
110
- sendMessage :: UnixSocket -> ClientMessage -> IO ()
111
- sendMessage sock msg = void $ do
123
+ sendMessage :: RawConnection -> ClientMessage -> IO ()
124
+ sendMessage rawConn msg = void $ do
112
125
let smsg = toStrict . toLazyByteString $ encodeClientMessage msg
113
- send sock smsg mempty
126
+ rSend rawConn smsg
114
127
115
128
readAuthMessage :: B. ByteString -> IO ()
116
129
readAuthMessage s =
@@ -121,15 +134,15 @@ readAuthMessage s =
121
134
f -> error $ show s
122
135
123
136
receiverThread
124
- :: UnixSocket
137
+ :: RawConnection
125
138
-> InChan (Either Error DataMessage )
126
139
-> InChan ServerMessage
127
140
-> IO ()
128
- receiverThread sock dataChan allChan = receiveLoop []
141
+ receiverThread rawConn dataChan allChan = receiveLoop []
129
142
where
130
143
receiveLoop :: [V. Vector B. ByteString ] -> IO ()
131
144
receiveLoop acc = do
132
- r <- receive sock 4096 mempty
145
+ r <- rReceive rawConn 4096
133
146
-- print r
134
147
go r acc >>= receiveLoop
135
148
@@ -220,7 +233,7 @@ data Query = Query
220
233
sendBatch :: Connection -> [Query ] -> IO ()
221
234
sendBatch conn = traverse_ sendSingle
222
235
where
223
- s = connSocket conn
236
+ s = connRawConnection conn
224
237
sname = StatementName " "
225
238
pname = PortalName " "
226
239
sendSingle q = do
@@ -230,10 +243,10 @@ sendBatch conn = traverse_ sendSingle
230
243
sendMessage s $ Execute pname noLimitToReceive
231
244
232
245
sendSync :: Connection -> IO ()
233
- sendSync conn = sendMessage (connSocket conn) Sync
246
+ sendSync conn = sendMessage (connRawConnection conn) Sync
234
247
235
248
sendFlush :: Connection -> IO ()
236
- sendFlush conn = sendMessage (connSocket conn) Flush
249
+ sendFlush conn = sendMessage (connRawConnection conn) Flush
237
250
238
251
readNextData :: Connection -> IO (Either Error DataMessage )
239
252
readNextData conn = readChan $ connOutDataChan conn
@@ -269,7 +282,7 @@ describeStatement conn stmt = do
269
282
sendMessage s Sync
270
283
parseMessages <$> waitReadyForQueryCollect conn
271
284
where
272
- s = connSocket conn
285
+ s = connRawConnection conn
273
286
sname = StatementName " "
274
287
parseMessages msgs = case msgs of
275
288
[ParameterDescription params, NoData ]
@@ -289,6 +302,7 @@ test :: IO ()
289
302
test = do
290
303
c <- connect defaultConnectionSettings
291
304
sendBatch c queries
305
+ sendSync c
292
306
readResults c $ length queries
293
307
readReadyForQuery c >>= print
294
308
close c
0 commit comments