Skip to content

Commit 8566160

Browse files
Refactor Driver directory
1 parent 1b89885 commit 8566160

File tree

5 files changed

+71
-49
lines changed

5 files changed

+71
-49
lines changed

postgres-wire.cabal

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ library
5858
BangPatterns
5959
OverloadedStrings
6060
GeneralizedNewtypeDeriving
61+
LambdaCase
6162
cc-options: -O2 -Wall
6263

6364
test-suite postgres-wire-test-connection

src/Database/PostgreSQL/Driver/Connection.hs

+13-9
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ import Database.PostgreSQL.Driver.RawConnection
5656
-- | Public
5757
-- Connection parametrized by message type in chan.
5858
data AbsConnection mt = AbsConnection
59-
{ connRawConnection :: RawConnection
60-
, connReceiverThread :: Weak ThreadId
61-
, connStatementStorage :: StatementStorage
62-
, connParameters :: ConnectionParameters
63-
, connOutChan :: TQueue (Either ReceiverException mt)
59+
{ connRawConnection :: !RawConnection
60+
, connReceiverThread :: !(Weak ThreadId)
61+
, connStatementStorage :: !StatementStorage
62+
, connParameters :: !ConnectionParameters
63+
, connOutChan :: !(TQueue (Either ReceiverException mt))
6464
}
6565

6666
type Connection = AbsConnection DataMessage
@@ -122,15 +122,18 @@ connectCommon' settings msgFilter = connectWith settings $ \rawConn params ->
122122

123123
-- Low-level sending functions
124124

125+
{-# INLINE sendStartMessage #-}
125126
sendStartMessage :: RawConnection -> StartMessage -> IO ()
126127
sendStartMessage rawConn msg = void $
127128
rSend rawConn . runEncode $ encodeStartMessage msg
128129

129130
-- Only for testings and simple queries
131+
{-# INLINE sendMessage #-}
130132
sendMessage :: RawConnection -> ClientMessage -> IO ()
131133
sendMessage rawConn msg = void $
132134
rSend rawConn . runEncode $ encodeClientMessage msg
133135

136+
{-# INLINE sendEncode #-}
134137
sendEncode :: AbsConnection c -> Encode -> IO ()
135138
sendEncode conn = void . rSend (connRawConnection conn) . runEncode
136139

@@ -290,6 +293,11 @@ receiverThreadCommon rawConn chan msgFilter ntfHandler = go ""
290293
dispatchIfNotification (NotificationResponse ntf) handler = handler ntf
291294
dispatchIfNotification _ _ = pure ()
292295

296+
-- | Helper to read from queue.
297+
{-# INLINE writeChan #-}
298+
writeChan :: TQueue a -> a -> IO ()
299+
writeChan q = atomically . writeTQueue q
300+
293301
defaultNotificationHandler :: NotificationHandler
294302
defaultNotificationHandler = const $ pure ()
295303

@@ -332,7 +340,3 @@ defaultFilter msg = case msg of
332340
-- as result for `describe` message
333341
RowDescription{} -> True
334342

335-
-- | Helper to read from queue.
336-
writeChan :: TQueue a -> a -> IO ()
337-
writeChan q = atomically . writeTQueue q
338-

src/Database/PostgreSQL/Driver/Query.hs

+35-30
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@ module Database.PostgreSQL.Driver.Query
1212
, collectUntilReadyForQuery
1313
) where
1414

15-
import Data.Foldable
16-
import Data.Monoid
17-
import Data.Bifunctor
18-
import qualified Data.Vector as V
19-
import qualified Data.ByteString as B
2015
import Control.Concurrent.STM.TQueue (TQueue, readTQueue )
21-
import Control.Concurrent.STM (atomically)
16+
import Control.Concurrent.STM (atomically)
17+
import Data.Foldable (fold)
18+
import Data.Monoid ((<>))
19+
import Data.ByteString (ByteString)
20+
import Data.Vector (Vector)
2221

2322
import Database.PostgreSQL.Protocol.Encoders
2423
import Database.PostgreSQL.Protocol.Store.Encode
@@ -31,26 +30,30 @@ import Database.PostgreSQL.Driver.StatementStorage
3130

3231
-- Public
3332
data Query = Query
34-
{ qStatement :: B.ByteString
35-
, qValues :: [(Oid, Maybe Encode)]
36-
, qParamsFormat :: Format
37-
, qResultFormat :: Format
38-
, qCachePolicy :: CachePolicy
33+
{ qStatement :: !ByteString
34+
, qValues :: ![(Oid, Maybe Encode)]
35+
, qParamsFormat :: !Format
36+
, qResultFormat :: !Format
37+
, qCachePolicy :: !CachePolicy
3938
} deriving (Show)
4039

4140
-- | Public
41+
{- INLINE sendBatchAndFlush #-}
4242
sendBatchAndFlush :: Connection -> [Query] -> IO ()
4343
sendBatchAndFlush = sendBatchEndBy Flush
4444

4545
-- | Public
46+
{-# INLINE sendBatchAndSync #-}
4647
sendBatchAndSync :: Connection -> [Query] -> IO ()
4748
sendBatchAndSync = sendBatchEndBy Sync
4849

4950
-- | Public
51+
{-# INLINE sendSync #-}
5052
sendSync :: Connection -> IO ()
5153
sendSync conn = sendEncode conn $ encodeClientMessage Sync
5254

5355
-- | Public
56+
{-# INLINABLE readNextData #-}
5457
readNextData :: Connection -> IO (Either Error DataRows)
5558
readNextData conn =
5659
readChan (connOutChan conn) >>=
@@ -62,6 +65,7 @@ readNextData conn =
6265
DataReady -> throwIncorrectUsage
6366
"Expected DataRow message, but got ReadyForQuery"
6467

68+
{-# INLINABLE waitReadyForQuery #-}
6569
waitReadyForQuery :: Connection -> IO (Either Error ())
6670
waitReadyForQuery conn =
6771
readChan (connOutChan conn) >>=
@@ -77,6 +81,7 @@ waitReadyForQuery conn =
7781
DataReady -> pure $ Right ()
7882

7983
-- Helper
84+
{-# INLINE sendBatchEndBy #-}
8085
sendBatchEndBy :: ClientMessage -> Connection -> [Query] -> IO ()
8186
sendBatchEndBy msg conn qs = do
8287
batch <- constructBatch conn qs
@@ -90,28 +95,27 @@ constructBatch conn = fmap fold . traverse constructSingle
9095
pname = PortalName ""
9196
constructSingle q = do
9297
let stmtSQL = StatementSQL $ qStatement q
93-
(sname, parseMessage) <- case qCachePolicy q of
94-
AlwaysCache -> do
95-
mName <- lookupStatement storage stmtSQL
96-
case mName of
97-
Nothing -> do
98-
newName <- storeStatement storage stmtSQL
99-
pure (newName, encodeClientMessage $
100-
Parse newName stmtSQL (fst <$> qValues q))
101-
Just name -> pure (name, mempty)
102-
NeverCache -> do
103-
let newName = defaultStatementName
104-
pure (newName, encodeClientMessage $
105-
Parse newName stmtSQL (fst <$> qValues q))
106-
let bindMessage = encodeClientMessage $
107-
Bind pname sname (qParamsFormat q) (snd <$> qValues q)
98+
(stmtName, needParse) <- case qCachePolicy q of
99+
AlwaysCache -> lookupStatement storage stmtSQL >>= \case
100+
Nothing -> do
101+
newName <- storeStatement storage stmtSQL
102+
pure (newName, True)
103+
Just name ->
104+
pure (name, False)
105+
NeverCache -> pure (defaultStatementName, True)
106+
let parseMessage = if needParse
107+
then encodeClientMessage $
108+
Parse stmtName stmtSQL (fst <$> qValues q)
109+
else mempty
110+
bindMessage = encodeClientMessage $
111+
Bind pname stmtName (qParamsFormat q) (snd <$> qValues q)
108112
(qResultFormat q)
109113
executeMessage = encodeClientMessage $
110114
Execute pname noLimitToReceive
111115
pure $ parseMessage <> bindMessage <> executeMessage
112116

113117
-- | Public
114-
sendSimpleQuery :: ConnectionCommon -> B.ByteString -> IO (Either Error ())
118+
sendSimpleQuery :: ConnectionCommon -> ByteString -> IO (Either Error ())
115119
sendSimpleQuery conn q = do
116120
sendMessage (connRawConnection conn) $ SimpleQuery (StatementSQL q)
117121
(checkErrors =<<) <$> collectUntilReadyForQuery conn
@@ -122,8 +126,8 @@ sendSimpleQuery conn q = do
122126
-- | Public
123127
describeStatement
124128
:: ConnectionCommon
125-
-> B.ByteString
126-
-> IO (Either Error (V.Vector Oid, V.Vector FieldDescription))
129+
-> ByteString
130+
-> IO (Either Error (Vector Oid, Vector FieldDescription))
127131
describeStatement conn stmt = do
128132
sendEncode conn $
129133
encodeClientMessage (Parse sname (StatementSQL stmt) [])
@@ -135,7 +139,7 @@ describeStatement conn stmt = do
135139
sname = StatementName ""
136140
parseMessages msgs = case msgs of
137141
[ParameterDescription params, NoData]
138-
-> pure $ Right (params, V.empty)
142+
-> pure $ Right (params, mempty)
139143
[ParameterDescription params, RowDescription fields]
140144
-> pure $ Right (params, fields)
141145
xs -> maybe
@@ -160,5 +164,6 @@ findFirstError [] = Nothing
160164
findFirstError (ErrorResponse desc : _) = Just desc
161165
findFirstError (_ : xs) = findFirstError xs
162166

167+
{-# INLINE readChan #-}
163168
readChan :: TQueue a -> IO a
164169
readChan = atomically . readTQueue

src/Database/PostgreSQL/Driver/StatementStorage.hs

+22-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1-
module Database.PostgreSQL.Driver.StatementStorage where
2-
3-
import qualified Data.HashTable.IO as H
4-
import qualified Data.ByteString as B
1+
module Database.PostgreSQL.Driver.StatementStorage
2+
( StatementStorage
3+
, CachePolicy(..)
4+
, newStatementStorage
5+
, lookupStatement
6+
, storeStatement
7+
, getCacheSize
8+
, defaultStatementName
9+
) where
10+
11+
import Data.Monoid ((<>))
12+
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
13+
import Data.Word (Word)
14+
15+
import Data.ByteString (ByteString)
516
import Data.ByteString.Char8 (pack)
6-
import Data.Word (Word)
7-
import Data.IORef
17+
import qualified Data.HashTable.IO as H
818

919
import Database.PostgreSQL.Protocol.Types
1020

@@ -21,16 +31,17 @@ data CachePolicy
2131
newStatementStorage :: IO StatementStorage
2232
newStatementStorage = StatementStorage <$> H.new <*> newIORef 0
2333

34+
{-# INLINE lookupStatement #-}
2435
lookupStatement :: StatementStorage -> StatementSQL -> IO (Maybe StatementName)
2536
lookupStatement (StatementStorage table _) = H.lookup table
2637

27-
-- TODO place right name
2838
-- TODO info about exceptions and mask
39+
{-# INLINE storeStatement #-}
2940
storeStatement :: StatementStorage -> StatementSQL -> IO StatementName
3041
storeStatement (StatementStorage table counter) stmt = do
3142
n <- readIORef counter
3243
writeIORef counter $ n + 1
33-
let name = StatementName . pack $ show n
44+
let name = StatementName . (statementPrefix <>) . pack $ show n
3445
H.insert table stmt name
3546
pure name
3647

@@ -40,3 +51,6 @@ getCacheSize (StatementStorage _ counter) = readIORef counter
4051
defaultStatementName :: StatementName
4152
defaultStatementName = StatementName ""
4253

54+
statementPrefix :: ByteString
55+
statementPrefix = "_pw_statement_"
56+

src/Database/PostgreSQL/Protocol/Codecs/Numeric.hs

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
{-# language LambdaCase #-}
2-
31
module Database.PostgreSQL.Protocol.Codecs.Numeric
42
( scientificToNumeric
53
, numericToScientific

0 commit comments

Comments
 (0)