|
| 1 | +{-# language GADTs #-} |
| 2 | +{-# language TypeFamilies #-} |
| 3 | +{-# language DataKinds #-} |
| 4 | +{-# language KindSignatures #-} |
| 5 | +{-# language ScopedTypeVariables #-} |
| 6 | +{-# language FlexibleInstances #-} |
| 7 | +{-# language FlexibleContexts #-} |
| 8 | +{-# language UndecidableInstances #-} |
| 9 | +{-# language ConstrainedClassMethods #-} |
1 | 10 | module Database.PostgreSQL.Protocol.Codecs.Decoders where
|
2 | 11 |
|
3 |
| -import Data.Bool |
| 12 | +-- import Data.Bool |
4 | 13 | import Data.Word
|
5 | 14 | import Data.Int
|
6 | 15 | import Data.Char
|
7 | 16 | import Control.Monad
|
8 | 17 | import qualified Data.ByteString as B
|
9 | 18 | import qualified Data.Vector as V
|
10 | 19 |
|
| 20 | +import Control.Monad |
| 21 | +import Control.Applicative.Free |
| 22 | +import Data.Proxy |
| 23 | +import Prelude hiding (bool) |
| 24 | + |
11 | 25 | import Database.PostgreSQL.Protocol.Store.Decode
|
12 | 26 | import Database.PostgreSQL.Protocol.Store.Encode
|
13 | 27 | import Database.PostgreSQL.Protocol.Types
|
14 | 28 |
|
| 29 | +{-# INLINE skipDataRowHeader #-} |
15 | 30 | skipDataRowHeader :: Decode ()
|
16 | 31 | skipDataRowHeader = skipBytes 7
|
17 | 32 |
|
| 33 | +{-# INLINE fieldLength #-} |
18 | 34 | fieldLength :: Decode Int
|
19 | 35 | fieldLength = fromIntegral <$> getInt32BE
|
20 | 36 |
|
| 37 | +{-# INLINE getNonNullable #-} |
21 | 38 | getNonNullable :: FieldDecoder a -> Decode a
|
22 |
| -getNonNullable dec = fieldLength >>= runFieldDecoder dec |
| 39 | +getNonNullable fdec = fieldLength >>= fdec |
23 | 40 |
|
| 41 | +{-# INLINE getNullable #-} |
24 | 42 | getNullable :: FieldDecoder a -> Decode (Maybe a)
|
25 |
| -getNullable dec = do |
| 43 | +getNullable fdec = do |
26 | 44 | len <- fieldLength
|
27 | 45 | if len == -1
|
28 | 46 | then pure Nothing
|
29 |
| - else Just <$!> runFieldDecoder dec len |
| 47 | + else Just <$!> fdec len |
30 | 48 |
|
31 | 49 | -- Field in composites Oid before value
|
32 |
| -compositeValue :: Decode a -> Decode a |
33 |
| -compositeValue dec = skipBytes 4 >> dec |
| 50 | +compositeValue :: Decode () |
| 51 | +compositeValue = skipBytes 4 |
34 | 52 |
|
| 53 | +-- Skips length of elements in composite |
35 | 54 | compositeHeader :: Decode ()
|
36 | 55 | compositeHeader = skipBytes 4
|
37 | 56 |
|
38 |
| -arrayData :: Int -> Decode a -> Decode (V.Vector a) |
39 |
| -arrayData len dec = undefined |
| 57 | +-- Dimensions, HasNull, Oid |
| 58 | +arrayHeader :: Decode () |
| 59 | +arrayHeader = skipBytes 12 |
| 60 | + |
| 61 | +arrayDimensions :: Int -> Decode (V.Vector Int) |
| 62 | +arrayDimensions depth = V.reverse <$> V.replicateM depth arrayDimSize |
| 63 | + where |
| 64 | + arrayDimSize = (fromIntegral <$> getInt32BE) <* getInt32BE |
| 65 | + |
| 66 | + |
| 67 | +arrayFieldDecoder :: Int -> (V.Vector Int -> Decode a) -> FieldDecoder a |
| 68 | +arrayFieldDecoder dims f _ = arrayHeader *> arrayDimensions dims >>= f |
40 | 69 |
|
41 | 70 | -- Public decoders
|
42 | 71 | -- | Decodes only content of a field.
|
43 |
| -newtype FieldDecoder a = FieldDecoder { runFieldDecoder :: Int -> Decode a } |
| 72 | +type FieldDecoder a = Int -> Decode a |
44 | 73 |
|
| 74 | +{-# INLINE int2 #-} |
45 | 75 | int2 :: FieldDecoder Int16
|
46 |
| -int2 = FieldDecoder $ \ _ -> getInt16BE |
| 76 | +int2 _ = getInt16BE |
47 | 77 |
|
| 78 | +{-# INLINE int4 #-} |
48 | 79 | int4 :: FieldDecoder Int32
|
49 |
| -int4 = FieldDecoder $ \ _ -> getInt32BE |
| 80 | +int4 _ = getInt32BE |
50 | 81 |
|
| 82 | +{-# INLINE int8 #-} |
51 | 83 | int8 :: FieldDecoder Int64
|
52 |
| -int8 = FieldDecoder $ \ _ -> getInt64BE |
| 84 | +int8 _ = getInt64BE |
53 | 85 |
|
| 86 | +{-# INLINE bool #-} |
54 | 87 | bool :: FieldDecoder Bool
|
55 |
| -bool = FieldDecoder $ \ _ -> (== 1) <$> getWord8 |
| 88 | +bool _ = (== 1) <$> getWord8 |
| 89 | + |
| 90 | +data FieldF r a |
| 91 | + = Single !(FieldDecoder a) |
| 92 | + | Row !(r a) |
| 93 | + |
| 94 | +{-# INLINE getFieldDec #-} |
| 95 | +getFieldDec :: FieldF CompositeValue a -> FieldDecoder a |
| 96 | +getFieldDec (Single fd) = fd |
| 97 | +getFieldDec (Row r) = composite r |
| 98 | + |
| 99 | +-- High level |
| 100 | +-- |
| 101 | + |
| 102 | +class PrimField a where |
| 103 | + |
| 104 | + primField :: RowDecoder r => FieldF r a |
| 105 | + |
| 106 | + {-# INLINE field #-} |
| 107 | + field :: RowDecoder r => r a |
| 108 | + field = getRowNonNullValue $ getFieldDec primField |
| 109 | + |
| 110 | + type IsArrayField a :: Bool |
| 111 | + type IsArrayField a = 'False |
| 112 | + |
| 113 | + type IsNullableField a :: Bool |
| 114 | + type IsNullableField a = 'False |
| 115 | + |
| 116 | + arrayDim :: Proxy a -> Int |
| 117 | + arrayDim _ = 0 |
| 118 | + |
| 119 | + asArrayData :: V.Vector Int -> Decode a |
| 120 | + asArrayData _ = runRowDecoder (field :: RowValue a) |
| 121 | + |
| 122 | +instance PrimField Int16 where |
| 123 | + primField = Single int2 |
| 124 | + |
| 125 | +instance PrimField Int32 where |
| 126 | + primField = Single int4 |
| 127 | + |
| 128 | +instance PrimField Int64 where |
| 129 | + primField = Single int8 |
| 130 | + |
| 131 | +instance PrimField Bool where |
| 132 | + primField = Single bool |
| 133 | + |
| 134 | +instance PrimField B.ByteString where |
| 135 | + primField = Single getByteString |
| 136 | + |
| 137 | +instance PrimField a => PrimField (Maybe a) where |
| 138 | + primField = undefined |
| 139 | + |
| 140 | + type IsNullableField (Maybe a) = 'True |
| 141 | + type IsArrayField (Maybe a) = IsArrayField a |
| 142 | + {-# INLINE field #-} |
| 143 | + field = getRowNullValue $ getFieldDec primField |
| 144 | + |
| 145 | +instance (IsAllowedArray (IsNullableField a) (IsArrayField a) ~ 'True, |
| 146 | + PrimField a) |
| 147 | + => PrimField (V.Vector a) where |
| 148 | + primField = Single $ arrayFieldDecoder |
| 149 | + (arrayDim (Proxy :: Proxy (V.Vector a))) |
| 150 | + asArrayData |
| 151 | + |
| 152 | + type IsArrayField (V.Vector a) = 'True |
| 153 | + arrayDim _ = arrayDim (Proxy :: Proxy a) + 1 |
| 154 | + |
| 155 | + asArrayData vec = V.replicateM (vec V.! arrayDim (Proxy :: Proxy a)) |
| 156 | + $ asArrayData vec |
| 157 | + |
| 158 | +type family IsAllowedArray (n :: Bool) (a :: Bool) :: Bool where |
| 159 | + IsAllowedArray 'True 'True = 'False |
| 160 | + IsAllowedArray _ _ = 'True |
| 161 | + |
| 162 | + |
| 163 | +-- TODO add array value |
| 164 | +newtype RowValue a = RowValue { unRowValue :: Decode a } |
| 165 | + deriving (Functor, Applicative, Monad) |
| 166 | +newtype CompositeValue a = CompositeValue { unCompositeValue :: Decode a } |
| 167 | + deriving (Functor, Applicative, Monad) |
| 168 | + |
| 169 | +class (Functor r, Applicative r, Monad r) => RowDecoder r where |
| 170 | + getRowNonNullValue :: FieldDecoder a -> r a |
| 171 | + getRowNullValue :: FieldDecoder a -> r (Maybe a) |
| 172 | + runRowDecoder :: r a -> Decode a |
| 173 | + |
| 174 | +instance RowDecoder RowValue where |
| 175 | + {-# INLINE getRowNonNullValue #-} |
| 176 | + getRowNonNullValue = RowValue . getNonNullable |
| 177 | + {-# INLINE getRowNullValue #-} |
| 178 | + getRowNullValue = RowValue . getNullable |
| 179 | + {-# INLINE runRowDecoder #-} |
| 180 | + runRowDecoder = unRowValue |
| 181 | + |
| 182 | +instance RowDecoder CompositeValue where |
| 183 | + {-# INLINE getRowNonNullValue #-} |
| 184 | + getRowNonNullValue = CompositeValue |
| 185 | + . fmap (compositeValue *>) getNonNullable |
| 186 | + {-# INLINE getRowNullValue #-} |
| 187 | + getRowNullValue = CompositeValue |
| 188 | + . fmap (compositeValue *>) getNullable |
| 189 | + {-# INLINE runRowDecoder #-} |
| 190 | + runRowDecoder = unCompositeValue |
| 191 | + |
| 192 | +instance (PrimField a1, PrimField a2, PrimField a3) |
| 193 | + => PrimField (a1, a2, a3) where |
| 194 | + |
| 195 | + {-# INLINE primField #-} |
| 196 | + primField = Row $ (,,) <$> field <*> field <*> field |
| 197 | + |
| 198 | +instance (PrimField a1, PrimField a2, PrimField a3, PrimField a4, |
| 199 | + PrimField a5, PrimField a6, PrimField a7, PrimField a8, |
| 200 | + PrimField a9, PrimField a10, PrimField a11, PrimField a12) |
| 201 | + => PrimField (a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) |
| 202 | + where |
| 203 | + {-# INLINE primField #-} |
| 204 | + primField = Row $ (,,,,,,,,,,,) <$> field <*> field <*> field <*> field |
| 205 | + <*> field <*> field <*> field <*> field |
| 206 | + <*> field <*> field <*> field <*> field |
| 207 | + |
| 208 | + |
| 209 | +composite :: CompositeValue a -> FieldDecoder a |
| 210 | +composite dec _ = compositeHeader *> runRowDecoder dec |
| 211 | + |
| 212 | +{-# INLINE rowDecoder #-} |
| 213 | +rowDecoder :: forall a. PrimField a => Decode a |
| 214 | +rowDecoder = case primField of |
| 215 | + Single f -> skipDataRowHeader *> runRowDecoder |
| 216 | + (getRowNonNullValue f :: RowValue a) |
| 217 | + Row r -> skipDataRowHeader *> runRowDecoder (r :: RowValue a) |
| 218 | + |
| 219 | + |
0 commit comments