Commit b14c2506 authored by Alfredo Di Napoli's avatar Alfredo Di Napoli

Add PGUpdateReturning facility

parent 489968f6
...@@ -15,6 +15,7 @@ module Gargantext.Database.Transactional ( ...@@ -15,6 +15,7 @@ module Gargantext.Database.Transactional (
-- * Smart constructors -- * Smart constructors
, mkPGQuery , mkPGQuery
, mkPGUpdate , mkPGUpdate
, mkPGUpdateReturning
, mkOpaQuery , mkOpaQuery
, mkOpaUpdate , mkOpaUpdate
, mkOpaInsert , mkOpaInsert
...@@ -35,6 +36,7 @@ import Database.PostgreSQL.Simple.Transaction qualified as PG ...@@ -35,6 +36,7 @@ import Database.PostgreSQL.Simple.Transaction qualified as PG
import Gargantext.Database.Prelude import Gargantext.Database.Prelude
import Opaleye import Opaleye
import Prelude import Prelude
import qualified Control.Exception.Safe as Safe
data DBOperation = DBRead | DBWrite data DBOperation = DBRead | DBWrite
...@@ -48,6 +50,10 @@ data DBTransactionOp err (r :: DBOperation) next where ...@@ -48,6 +50,10 @@ data DBTransactionOp err (r :: DBOperation) next where
-- | A Postgres /write/, returning the number of affected rows. It can be used only in -- | A Postgres /write/, returning the number of affected rows. It can be used only in
-- 'DBWrite' transactions. -- 'DBWrite' transactions.
PGUpdate :: PG.ToRow a => PG.Query -> a -> (Int -> next) -> DBTransactionOp err DBWrite next PGUpdate :: PG.ToRow a => PG.Query -> a -> (Int -> next) -> DBTransactionOp err DBWrite next
-- | Unlike a 'PGUpdate' that returns the list of affected rows, this can be used
-- to write updates that returns a value via the \"RETURNING\" directive. It's the programmer's
-- responsibility to ensure that the SQL fragment contains it.
PGUpdateReturning :: (PG.ToRow q, PG.FromRow a) => PG.Query -> q -> (a -> next) -> DBTransactionOp err DBWrite next
-- | An Opaleye /read/, returning a list of results. The 'r' in the result is polymorphic -- | An Opaleye /read/, returning a list of results. The 'r' in the result is polymorphic
-- so that reads can be embedded in updates transactions. -- so that reads can be embedded in updates transactions.
OpaQuery :: Default FromFields fields a => Select fields -> ([a] -> next) -> DBTransactionOp err r next OpaQuery :: Default FromFields fields a => Select fields -> ([a] -> next) -> DBTransactionOp err r next
...@@ -76,6 +82,7 @@ instance Functor (DBTransactionOp err r) where ...@@ -76,6 +82,7 @@ instance Functor (DBTransactionOp err r) where
fmap f = \case fmap f = \case
PGQuery q params cont -> PGQuery q params (f . cont) PGQuery q params cont -> PGQuery q params (f . cont)
PGUpdate q a cont -> PGUpdate q a (f . cont) PGUpdate q a cont -> PGUpdate q a (f . cont)
PGUpdateReturning q a cont -> PGUpdateReturning q a (f . cont)
OpaQuery sel cont -> OpaQuery sel (f . cont) OpaQuery sel cont -> OpaQuery sel (f . cont)
OpaInsert ins cont -> OpaInsert ins (f . cont) OpaInsert ins cont -> OpaInsert ins (f . cont)
OpaUpdate upd cont -> OpaUpdate upd (f . cont) OpaUpdate upd cont -> OpaUpdate upd (f . cont)
...@@ -127,11 +134,19 @@ evalOp :: PG.Connection -> DBTransactionOp err r a -> DBTxCmd err a ...@@ -127,11 +134,19 @@ evalOp :: PG.Connection -> DBTransactionOp err r a -> DBTxCmd err a
evalOp conn = \case evalOp conn = \case
PGQuery qr q cc -> cc <$> liftBase (PG.query conn qr q) PGQuery qr q cc -> cc <$> liftBase (PG.query conn qr q)
PGUpdate qr a cc -> cc <$> liftBase (fromIntegral <$> PG.execute conn qr a) PGUpdate qr a cc -> cc <$> liftBase (fromIntegral <$> PG.execute conn qr a)
PGUpdateReturning qr a cc -> cc <$> liftBase (queryOne conn qr a)
OpaQuery sel cc -> cc <$> liftBase (runSelect conn sel) OpaQuery sel cc -> cc <$> liftBase (runSelect conn sel)
OpaInsert ins cc -> cc <$> liftBase (runInsert conn ins) OpaInsert ins cc -> cc <$> liftBase (runInsert conn ins)
OpaUpdate upd cc -> cc <$> liftBase (runUpdate conn upd) OpaUpdate upd cc -> cc <$> liftBase (runUpdate conn upd)
DBFail err -> throwError err DBFail err -> throwError err
queryOne :: (PG.ToRow q, PG.FromRow r) => PG.Connection -> PG.Query -> q -> IO r
queryOne conn q v = do
rs <- PG.query conn q v
case rs of
[x] -> pure x
_ -> Safe.throwIO $ userError "queryOne: more than one result returned. Have you used the 'RETURNING' directive?"
-- --
-- Smart constructors -- Smart constructors
-- --
...@@ -148,6 +163,9 @@ mkPGQuery q a = DBTx $ liftF (PGQuery q a id) ...@@ -148,6 +163,9 @@ mkPGQuery q a = DBTx $ liftF (PGQuery q a id)
mkPGUpdate :: PG.ToRow a => PG.Query -> a -> DBUpdate err Int mkPGUpdate :: PG.ToRow a => PG.Query -> a -> DBUpdate err Int
mkPGUpdate q a = DBTx $ liftF (PGUpdate q a id) mkPGUpdate q a = DBTx $ liftF (PGUpdate q a id)
mkPGUpdateReturning :: (PG.ToRow q, PG.FromRow a) => PG.Query -> q -> DBUpdate err a
mkPGUpdateReturning q a = DBTx $ liftF (PGUpdateReturning q a id)
mkOpaQuery :: Default FromFields fields a mkOpaQuery :: Default FromFields fields a
=> Select fields => Select fields
-> DBQuery err x [a] -> DBQuery err x [a]
......
...@@ -140,6 +140,10 @@ getCounterById (CounterId cid) = do ...@@ -140,6 +140,10 @@ getCounterById (CounterId cid) = do
[c] -> pure c [c] -> pure c
rst -> dbFail $ Prelude.userError ("getCounterId returned more than one result: " <> show rst) rst -> dbFail $ Prelude.userError ("getCounterId returned more than one result: " <> show rst)
insertCounter :: DBUpdate IOException Counter
insertCounter = do
mkPGUpdateReturning [sql| INSERT INTO public.ggtx_test_counter_table(counter_value) VALUES(0) RETURNING id, counter_value|] ()
-- --
-- MAIN TESTS -- MAIN TESTS
-- --
...@@ -147,10 +151,17 @@ getCounterById (CounterId cid) = do ...@@ -147,10 +151,17 @@ getCounterById (CounterId cid) = do
tests :: Spec tests :: Spec
tests = parallel $ around withTestCounterDB $ tests = parallel $ around withTestCounterDB $
describe "Database Transactions" $ do describe "Database Transactions" $ do
describe "Pure Queries" $ do describe "Pure PG Queries" $ do
it "Simple query works" simpleQueryWorks it "Simple query works" simplePGQueryWorks
describe "Pure PG Inserts" $ do
it "Simple insert works" simplePGInsertWorks
simpleQueryWorks :: DBHandle -> Assertion simplePGQueryWorks :: DBHandle -> Assertion
simpleQueryWorks env = flip runReaderT env $ runTestMonad $ do simplePGQueryWorks env = flip runReaderT env $ runTestMonad $ do
x <- runDBQuery $ getCounterById (CounterId 1) x <- runDBQuery $ getCounterById (CounterId 1)
liftIO $ counterValue x `shouldBe` 42 liftIO $ counterValue x `shouldBe` 42
simplePGInsertWorks :: DBHandle -> Assertion
simplePGInsertWorks env = flip runReaderT env $ runTestMonad $ do
x <- runDBTx $ insertCounter
liftIO $ x `shouldBe` (Counter (CounterId 2) 0)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment