8000 Session monad · postgres-haskell/postgres-wire@394027a · GitHub
[go: up one dir, main page]

Skip to content

Commit 394027a

Browse files
Session monad
1 parent 47e3488 commit 394027a

File tree

1 file changed

+125
-43
lines changed

1 file changed

+125
-43
lines changed

src/Database/PostgreSQL/Protocol/Connection.hs

Lines changed: 125 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
{-# language OverloadedLists #-}
22
{-# language OverloadedStrings #-}
3+
{-# language DeriveFunctor #-}
4+
{-# language GADTs #-}
5+
{-# language ApplicativeDo #-}
6+
{-# language ExistentialQuantification #-}
7+
{-# language TypeSynonymInstances #-}
8+
{-# language FlexibleInstances #-}
39
module Database.PostgreSQL.Protocol.Connection where
410

511

@@ -12,7 +18,8 @@ import Data.Foldable
1218
import Control.Applicative
1319
import Data.Monoid
1420
import Control.Concurrent
15-
import Data.Binary.Get (Decoder(..), runGetIncremental, pushChunk)
21+
import Data.Binary.Get ( runGetIncremental, pushChunk)
22+
import qualified Data.Binary.Get as BG (Decoder(..))
1623
import Data.Maybe (fromJust)
1724
import qualified Data.Vector as V
1825
import System.Socket hiding (connect, close)
@@ -27,10 +34,12 @@ import Database.PostgreSQL.Protocol.Settings
2734
import Database.PostgreSQL.Protocol.Encoders
2835
import Database.PostgreSQL.Protocol.Decoders
2936
import Database.PostgreSQL.Protocol.Types
37+
import Database.PostgreSQL.Protocol.StatementStorage
3038

3139

3240
type UnixSocket = Socket Unix Stream Unix
3341
-- data Connection = Connection (Socket Inet6 Stream TCP)
42+
-- TODO add statement storage
3443
data Connection = Connection UnixSocket ThreadId
3544

3645
address :: SocketAddress Unix
@@ -65,7 +74,7 @@ sendMessage sock msg = void $ do
6574
readAuthMessage :: B.ByteString -> IO ()
6675
readAuthMessage s =
6776
case pushChunk (runGetIncremental decodeAuthResponse) s of
68-
Done _ _ r -> case r of
77+
BG.Done _ _ r -> case r of
6978
AuthenticationOk -> putStrLn "Auth ok"
7079
_ -> error "Invalid auth"
7180
f -> error $ show s
@@ -80,15 +89,14 @@ receiverThread sock = forever $ do
8089
where
8190
decoder = runGetIncremental decodeServerMessage
8291
go str = case pushChunk decoder str of
83-
Done rest _ v -> do
92+
BG.Done rest _ v -> do
8493
print v
8594
unless (B.null rest) $ go rest
86-
Partial _ -> error "Partial"
87-
Fail _ _ e -> error e
95+
BG.Partial _ -> error "Partial"
96+
BG.Fail _ _ e -> error e
8897

89-
data QQuery = QQuery
90-
{ qName :: B.ByteString
91-
, qStmt :: B.ByteString
98+
data QQuery a = QQuery
99+
{ qStmt :: B.ByteString
92100
, qOids :: V.Vector Oid
93101
, qValues :: V.Vector B.ByteString
94102
} deriving Show
@@ -98,39 +106,113 @@ data QQuery = QQuery
98106
-- query3 = QQuery "test3" "SELECT $1 + $2" [23, 23] ["3", "3"]
99107
-- query4 = QQuery "test4" "SELECT $1 + $2" [23, 23] ["4", "3"]
100108
-- query5 = QQuery "test5" "SELECT $1 + $2" [23, 23] ["5", "3"]
101-
query1 = QQuery "test1" "select sum(v) from a" [] []
102-
query2 = QQuery "test2" "select sum(v) from a" [] []
103-
query3 = QQuery "test3" "select sum(v) from a" [] []
104-
query4 = QQuery "test4" "select sum(v) from a" [] []
105-
query5 = QQuery "test5" "select sum(v) from a" [] []
106-
107-
sendBatch :: Connection -> [QQuery] -> IO ()
108-
sendBatch (Connection s _) qs = do
109-
traverse sendSingle $ take 5 qs
110-
sendMessage s $ encodeClientMessage Sync
111-
where
112-
sendSingle q = do
113-
sendMessage s $ encodeClientMessage $
114-
Parse (qName q) (qStmt q) (qOids q)
115-
sendMessage s $ encodeClientMessage $
116-
Bind (qName q) (qName q) Text (qValues q) Text
117-
sendMessage s $ encodeClientMessage $ Execute (qName q)
118-
119-
120-
sendQuery :: Connection -> IO ()
121-
sendQuery (Connection s _) = do
122-
sendMessage s $ encodeClientMessage $ Parse "test" "SELECT $1 + $2" [23, 23]
123-
sendMessage s $ encodeClientMessage $
124-
Bind "test" "test" Text ["2", "3"] Text
125-
sendMessage s $ encodeClientMessage $ Execute "test"
126-
sendMessage s $ encodeClientMessage Sync
127-
128-
test :: IO ()
129-
test = do
130-
c <- connect defaultConnectionSettings
131-
-- sendQuery c
132-
getPOSIXTime >>= \t -> print "Start " >> print t
133-
sendBatch c [query1, query2, query3, query4, query5]
134-
threadDelay $ 5 * 1000 * 1000
135-
close c
109+
-- query1 = QQuery "test1" "select sum(v) from a" [] []
110+
-- query2 = QQuery "test2" "select sum(v) from a" [] []
111+
-- query3 = QQuery "test3" "select sum(v) from a" [] []
112+
-- query4 = QQuery "test4" "select sum(v) from a" [] []
113+
-- query5 = QQuery "test5" "select sum(v) from a" [] []
114+
115+
-- sendBatch :: Connection -> [QQuery] -> IO ()
116+
-- sendBatch (Connection s _) qs = do
117+
-- traverse sendSingle $ take 5 qs
118+
-- sendMessage s $ encodeClientMessage Sync
119+
-- where
120+
-- sendSingle q = do
121+
-- sendMessage s $ encodeClientMessage $
122+
-- Parse (qName q) (qStmt q) (qOids q)
123+
-- sendMessage s $ encodeClientMessage $
124+
-- Bind (qName q) (qName q) Text (qValues q) Text
125+
-- sendMessage s $ encodeClientMessage $ Execute (qName q)
126+
127+
128+
-- sendQuery :: Connection -> IO ()
129+
-- sendQuery (Connection s _) = do
130+
-- sendMessage s $ encodeClientMessage $ Parse "test" "SELECT $1 + $2" [23, 23]
131+
-- sendMessage s $ encodeClientMessage $
132+
-- Bind "test" "test" Text ["2", "3"] Text
133+
-- sendMessage s $ encodeClientMessage $ Execute "test"
134+
-- sendMessage s $ encodeClientMessage Sync
135+
136+
-- test :: IO ()
137+
-- test = do
138+
-- c <- connect defaultConnectionSettings
139+
-- -- sendQuery c
140+
-- getPOSIXTime >>= \t -> print "Start " >> print t
141+
-- sendBatch c [query1, query2, query3, query4, query5]
142+
-- threadDelay $ 5 * 1000 * 1000
143+
-- close c
144+
145+
146+
-- sendBatchAndSync :: IsQuery a => [a] -> Connection -> IO ()
147+
-- sendBatchAndSync = undefined
148+
149+
-- sendBatchAndFlush :: IsQuery a => [a] -> Connection -> IO ()
150+
-- sendBatchAndFlush = undefined
151+
152+
-- internal helper
153+
-- sendBatch :: IsQuery a => [a] -> Connection -> IO ()
154+
-- sendBatch = undefined
155+
156+
-- Session Monad
157+
--
158+
159+
data Request = forall a . Request (QQuery a)
160+
161+
query :: Decode a => QQuery a -> Session a
162+
query q = Send One [Request q] $ Receive Done
163+
164+
data Count = One | Many
165+
deriving (Eq, Show)
166+
167+
data Session a
168+
= Done a
169+
| forall r . Decode r => Receive (r -> Session a)
170+
| Send Count [Request] (Session a)
171+
172+
instance Functor Session where
173+
f `fmap` (Done a) = Done $ f a
174+
f `fmap` (Receive g) = Receive $ fmap f . g
175+
f `fmap` (Send n br c) = Send n br (f <$> c)
176+
177+
instance Applicative Session where
178+
pure = Done
179+
180+
f <*> x = case (f, x) of
181+
(Done g, Done y) -> Done (g y)
182+
(Done g, Receive next) -> Receive $ fmap g . next
183+
(Done g, Send n br c) -> Send n br (g <$> c)
184+
185+
(Send n br c, Done y) -> Send n br (c <*> pure y)
186+
(Send n br c, Receive next)
187+
-> Send n br $ c <*> Receive next
188+
(Send n1 br1 c1, Send n2 br2 c2)
189+
-> if n1 == One
190+
then Send n2 (br1 <> br2) (c1 <*> c2)
191+
else Send n1 br1 (c1 <*> Send n2 br2 c2)
192+
193+
(Receive next1, Receive next2) ->
194+
Receive $ (\g -> Receive $ (g <*> ) . next2) . next1
195+
(Receive next, Done y) -> Receive $ (<*> Done y) . next
196+
(Receive next, Send n br c)
197+
-> Receive $ (<*> Send n br c) . next
198+
199+
instance Monad Session where
200+
return = pure
201+
202+
m >>= f = case m of
203+
Done a -> f a
204+
Receive g -> Receive $ (>>=f) . g
205+
Send _n br c -> Send Many br (c >>= f)
206+
207+
(>>) = (*>)
208+
209+
-- Type classes
210+
class Decode a where
211+
decode :: String -> a
212+
213+
instance Decode Integer where
214+
decode = read
215+
216+
instance Decode String where
217+
decode = id
136218

0 commit comments

Comments
 (0)
0