{-# LANGUAGE QuasiQuotes          #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE OverloadedStrings    #-}
{-# LANGUAGE TupleSections        #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}

{-| Tests for the transactional DB API -}

module Test.Database.Transactions (
     tests
  ) where

import Control.Concurrent.Async (forConcurrently)
import Control.Exception.Safe
import Control.Exception.Safe qualified as Safe
import Control.Monad.Reader
import Control.Monad.Trans.Control
import Data.Pool
import Data.Profunctor.Product.TH (makeAdaptorAndInstance)
import Data.String
import Data.Text qualified as T
import Data.Text.Encoding qualified as TE
import Database.PostgreSQL.Simple qualified as PG
import Database.PostgreSQL.Simple.FromField
import Database.PostgreSQL.Simple.FromRow
import Database.PostgreSQL.Simple.Options qualified as Client
import Database.PostgreSQL.Simple.SqlQQ (sql)
import Database.PostgreSQL.Simple.ToField
import Database.Postgres.Temp qualified as Tmp
import Gargantext.API.Errors.Types (BackendInternalError)
import Gargantext.Database.Query.Table.Node.Error (errorWith)
import Gargantext.Database.Schema.Prelude (Table (..))
import Gargantext.Database.Transactional
import Gargantext.Prelude
import Opaleye (selectTable, requiredTableField, SqlInt4)
import Opaleye qualified as O
import Prelude qualified
import Shelly as SH
import System.Random.Stateful
import Test.Database.Types hiding (Counter)
import Test.Hspec
import Test.Tasty.HUnit hiding (assert)
import Text.RawString.QQ

--
-- For these tests we do not want to test the normal GGTX database queries, but rather
-- the foundational approach for the DBTx monad. Therefore we don't use the usual
-- 'withTestDB' code, but we rely on something very simple, a single table representing
-- counters with IDs, like so:
--
--   ID | counter_value
--   ---+--------------
--   1  | 0
--   2  | ...
--

-- | Identifier of a row in the test counter table, wrapping the integer @id@
-- column. The 'ToField' and 'FromField' instances let it be used directly as
-- a query parameter and decoded from a result column.
newtype CounterId = CounterId { _CounterId :: Int }
  deriving (Show, Eq, ToField, FromField)

-- | A counter row, polymorphic in the type of each column so that the same
-- record shape can describe both the Haskell-side value ('Counter') and the
-- Opaleye-side column description ('CounterOpa').
data Counter' id value = Counter
  { counterId    :: !id -- ^ The @id@ column (strict field).
  , counterValue :: value -- ^ The @counter_value@ column.
  }
  deriving (Show, Eq)

-- | A counter as seen from Haskell: a 'CounterId' together with an 'Int' value.
type Counter = Counter' CounterId Int

-- Template Haskell splice: generates the @pCounter@ adaptor (and, going by the
-- helper's name, the accompanying instance) for 'Counter''. It is used below
-- to build 'countersTable'.
$(makeAdaptorAndInstance "pCounter" ''Counter')

-- | The Opaleye view of a counter row: both columns are 'SqlInt4' fields.
type CounterOpa = Counter' (O.Field SqlInt4) (O.Field SqlInt4)

-- | Opaleye description of the test counter table: both the @id@ and the
-- @counter_value@ columns are declared as required fields.
countersTable :: Table CounterOpa CounterOpa
countersTable =
  Table "ggtx_test_counter_table" $
    pCounter Counter
      { counterId    = requiredTableField "id"
      , counterValue = requiredTableField "counter_value"
      }


-- | Newtype wrapper around 'TestMonadM' specialised to a 'DBHandle'
-- environment and 'BackendInternalError' errors; every instance listed below
-- is derived from the wrapped monad.
--
-- NOTE(review): 'runTestDBTxMonad' in this module takes a 'TestMonadM'
-- action directly rather than this wrapper -- confirm whether the wrapper is
-- still needed.
newtype TestDBTxMonad a = TestDBTxMonad { _TestDBTxMonad :: TestMonadM DBHandle BackendInternalError a }
  deriving ( Functor, Applicative, Monad
           , MonadReader DBHandle, MonadError BackendInternalError
           , MonadBase IO
           , MonadBaseControl IO
           , MonadFail
           , MonadIO
           , MonadMask
           , MonadCatch
           , MonadThrow
           )

-- | Run a test action in 'IO' by unwrapping it and supplying the given
-- 'DBHandle' as its reader environment.
runTestDBTxMonad :: DBHandle -> TestMonadM DBHandle BackendInternalError a -> IO a
runTestDBTxMonad env action = runReaderT (_TestMonad action) env

-- | Start a temporary Postgres instance, create a connection pool against it
-- and bootstrap the counter schema, returning the resulting 'DBHandle'.
--
-- Fails (via 'Prelude.fail') if the temporary database cannot be started.
-- If anything goes wrong /after/ the database has been started (creating the
-- pool or bootstrapping the schema), the resources acquired so far are
-- released before the exception is rethrown. This function is used as the
-- acquire step of a bracket, so the release action ('teardown') would
-- otherwise never run and the temporary database would leak.
setup :: IO DBHandle
setup = do
  res <- Tmp.startConfig tmpPgConfig
  case res of
    Left err -> Prelude.fail $ show err
    -- From here on the temporary DB is running: stop it on any failure.
    Right db -> flip Safe.onException (Tmp.stop db) $ do
      let idleTime = 60.0
      let maxResources = 2
      let poolConfig = defaultPoolConfig (PG.connectPostgreSQL (Tmp.toConnectionString db))
                                         PG.close
                                         idleTime
                                         maxResources
      pool <- newPool (setNumStripes (Just 2) poolConfig)
      -- Close any connection opened during bootstrap if it fails.
      bootstrapCounterDB db pool `Safe.onException` destroyAllResources pool
      pure $ DBHandle pool db
  where
    -- Default tmp-postgres configuration, overridden with the database name,
    -- user and password used by these tests.
    tmpPgConfig :: Tmp.Config
    tmpPgConfig = Tmp.defaultConfig <>
      Tmp.optionsToDefaultConfig mempty
        { Client.dbname   = pure dbName
        , Client.user     = pure dbUser
        , Client.password = pure dbPassword
        }

-- | Database role used by the counter tests.
dbUser :: String
dbUser = "ggtx_test_counter_db_user"

-- | Password assigned to 'dbUser'.
dbPassword :: String
dbPassword = "ggtx_test_counter_db_pwd"

-- | Name of the temporary database.
dbName :: String
dbName = "ggtx_test_counter_db"

-- | Schema-qualified name of the counter table.
dbTable :: String
dbTable = "public.ggtx_test_counter_table"

-- | Prepare a freshly started temporary database for the counter tests.
--
-- Using a connection taken from the pool it first sets the password of
-- 'dbUser'; it then writes 'counterDBSchema' to a file in a temporary
-- directory and feeds that file to @psql@ through a shell redirection
-- (the command is run with shelly's escaping disabled so that @<@ reaches
-- the shell unquoted). If the recorded exit code is non-zero, an 'IOError'
-- carrying the exit code and the captured output is thrown.
bootstrapCounterDB :: Tmp.DB -> Pool PG.Connection -> IO ()
bootstrapCounterDB tmpDB pool = withResource pool $ \conn -> do
  _ <- PG.execute_ conn alterUserQuery
  (psqlOutput, exitCode) <- shelly $ silently $ escaping False $
    withTmpDir $ \tmpDir -> do
      let schemaFile = tmpDir <> "/schema.sql"
      writefile schemaFile (T.pack counterDBSchema)
      out  <- SH.run "psql" (psqlArgs schemaFile)
      code <- lastExitCode
      pure (out, code)
  unless (exitCode == 0) $
    Safe.throwIO (Prelude.userError (show exitCode <> ": " <> T.unpack psqlOutput))
  where
    -- Statement assigning the test password to the test role.
    alterUserQuery :: PG.Query
    alterUserQuery =
      fromString $ "ALTER USER \"" <> dbUser <> "\" with PASSWORD '" <> dbPassword <> "'"

    -- Arguments for @psql@: the quoted connection string, then a shell
    -- redirection from the schema file.
    psqlArgs schemaFile =
      [ "-d"
      , "\"" <> TE.decodeUtf8 (Tmp.toConnectionString tmpDB) <> "\""
      , "<"
      , fromString schemaFile
      ]

-- | SQL script that creates the counter table ('dbTable'), hands its
-- ownership to 'dbUser' and seeds it with a single counter whose value is 42.
--
-- Every reference to the table goes through 'dbTable', so the table is named
-- in exactly one place within this script.
counterDBSchema :: String
counterDBSchema = [r|
  CREATE TABLE |] <> dbTable <> [r| (
      id SERIAL,
      counter_value INT NOT NULL DEFAULT 0,
      PRIMARY KEY (id)
  );
  ALTER TABLE |] <> dbTable <> [r| OWNER TO |] <> dbUser <> [r|;
  INSERT INTO |] <> dbTable <> [r|(counter_value) VALUES(42);
|]

-- | Run an action against a fresh counter database: 'setup' acquires the
-- handle and 'teardown' releases it, whether or not the action throws.
withTestCounterDB :: (DBHandle -> IO ()) -> IO ()
withTestCounterDB action = Safe.bracket setup teardown action

-- | Release everything held by a 'DBHandle': first destroy all pooled
-- connections, then stop the temporary database.
teardown :: DBHandle -> IO ()
teardown handle =
  destroyAllResources (_DBHandle handle) >> Tmp.stop (_DBTmp handle)

--
-- Helpers and transactions to work with counters
--

-- | Decode a counter from a two-column row: the id first, then the value.
instance PG.FromRow Counter where
  fromRow = do
    rowId    <- field
    rowValue <- field
    pure (Counter rowId rowValue)

-- | Fetch the counter with the given id.
--
-- Fails (via 'errorWith') unless exactly one row matches. The two failure
-- cases -- no matching row and more than one matching row -- are reported
-- with distinct messages.
getCounterById :: CounterId -> DBQuery BackendInternalError r Counter
getCounterById (CounterId cid) = do
  xs <- mkPGQuery [sql| SELECT * FROM public.ggtx_test_counter_table WHERE id = ?; |] (PG.Only cid)
  case xs of
    [c] -> pure c
    []  -> errorWith $ "getCounterById returned no result for id: " <> T.pack (show cid)
    rst -> errorWith $ "getCounterById returned more than one result: " <> T.pack (show rst)

-- | Insert a new counter with value 0 and return the inserted row.
insertCounter :: DBUpdate BackendInternalError Counter
insertCounter =
  mkPGUpdateReturningOne
    [sql| INSERT INTO public.ggtx_test_counter_table(counter_value) VALUES(0) RETURNING id, counter_value|]
    ()

-- | Overwrite the value of the counter with the given id and return the
-- updated row.
updateCounter :: CounterId -> Int -> DBUpdate BackendInternalError Counter
updateCounter cid newValue =
  mkPGUpdateReturningOne
    [sql| UPDATE public.ggtx_test_counter_table SET counter_value = ? WHERE id = ? RETURNING *|]
    (newValue, cid)

-- | Increment a counter by one. This is deliberately a composite operation:
-- the current value is read with a query and then written back, incremented,
-- with a separate update.
stepCounter :: CounterId -> DBUpdate BackendInternalError Counter
stepCounter cid = do
  current <- getCounterById cid
  updateCounter cid (counterValue current + 1)

--
-- MAIN TESTS
--

-- | The full spec. Every example is wrapped in 'withTestCounterDB', so each
-- one receives its own 'DBHandle'.
tests :: Spec
tests =
  parallel $
    around withTestCounterDB $
      describe "Database Transactions" $ do
        describe "Opaleye count queries" $
          it "Supports counting rows" opaCountQueries
        describe "Pure PG Queries" $
          it "Simple query works" simplePGQueryWorks
        describe "Pure PG Inserts" $
          it "Simple insert works" simplePGInsertWorks
        describe "Pure PG Updates" $
          it "Simple updates works" simplePGUpdateWorks
        describe "PG Queries and Updates" $
          it "Supports mixing queries and updates" mixQueriesAndUpdates
        describe "Rollback support" $
          it "can rollback in case of errors" testRollback
        describe "Read/Write Consistency" $
          it "should return a consistent state to different actors" testConsistency

-- | The counter seeded by the schema (id 1) can be read back and holds 42.
simplePGQueryWorks :: DBHandle -> Assertion
simplePGQueryWorks env = runTestDBTxMonad env $ do
  seeded <- runDBQuery (getCounterById (CounterId 1))
  liftIO (counterValue seeded `shouldBe` 42)

-- | Inserting a counter yields a row with id 2 (id 1 is the seeded counter)
-- and value 0.
simplePGInsertWorks :: DBHandle -> Assertion
simplePGInsertWorks env = runTestDBTxMonad env $ do
  inserted <- runDBTx insertCounter
  liftIO (inserted `shouldBe` Counter (CounterId 2) 0)

-- | Updating the seeded counter returns the row carrying the new value.
simplePGUpdateWorks :: DBHandle -> Assertion
simplePGUpdateWorks env = runTestDBTxMonad env $ do
  updated <- runDBTx (updateCounter (CounterId 1) 99)
  liftIO (updated `shouldBe` Counter (CounterId 1) 99)

-- | Queries and updates can be interleaved inside one transaction: two
-- counters are inserted, the first is read back unchanged and the second is
-- stepped once.
mixQueriesAndUpdates :: DBHandle -> Assertion
mixQueriesAndUpdates env = runTestDBTxMonad env $ do
  (untouched, stepped) <- runDBTx $ do
    firstCounter  <- insertCounter
    secondCounter <- insertCounter
    readBack   <- getCounterById (counterId firstCounter)
    afterStep  <- stepCounter (counterId secondCounter)
    pure (readBack, afterStep)
  liftIO $ do
    untouched `shouldBe` Counter (CounterId 2) 0
    stepped `shouldBe` Counter (CounterId 3) 1

-- | A transaction that fails at its last instruction must leave no trace:
-- the counter keeps the value it had before that transaction started.
testRollback :: DBHandle -> Assertion
testRollback env = runTestDBTxMonad env $ do
  initialCounter <- runDBTx $ do
    fresh <- insertCounter
    stepCounter (counterId fresh)
  liftIO $ counterValue initialCounter `shouldBe` 1

  -- Run a second transaction that steps the counter again and then fails at
  -- the very last instruction; the resulting exception is deliberately
  -- ignored, since only the database state afterwards matters here.
  Safe.handle ignoreFailure $ runDBTx $
    stepCounter (counterId initialCounter) >> errorWith "urgh"

  -- Check that the second 'stepCounter' didn't actually modify the counter's value.
  finalCounter <- runDBTx $ getCounterById (counterId initialCounter)
  liftIO $ counterValue finalCounter `shouldBe` 1
  where
    ignoreFailure :: Applicative f => SomeException -> f ()
    ignoreFailure _ = pure ()

-- | Concurrent actors all write to the /same/ counter. Each actor, after a
-- random delay, writes its own number and reads the counter back within the
-- same transaction; it must observe exactly the value it has just written.
testConsistency :: DBHandle -> Assertion
testConsistency env = do
  initialCounter <- runTestDBTxMonad env (runDBTx insertCounter)
  let actors = [ 1 .. 10 ] :: [Int]
      target = counterId initialCounter

  results <- forConcurrently actors $ \actor -> runTestDBTxMonad env $ do
    -- Random delay (between 100us and 2s) to shuffle the order of the writes.
    liftIO $ uniformRM (100, 2_000_000) globalStdGen >>= threadDelay
    runDBTx $ do
      _ <- updateCounter target actor
      getCounterById target

  -- Each actor should observe a consistent state.
  results `shouldBe` map (Counter (CounterId 2)) actors

-- | Counting rows through Opaleye: the table starts with one row and holds
-- three after two inserts performed in the same transaction as the count.
opaCountQueries :: DBHandle -> Assertion
opaCountQueries env = runTestDBTxMonad env $ do
  -- Only the master counter created alongside the schema exists at first.
  rowsBefore <- runDBTx (mkOpaCountQuery (selectTable countersTable))
  liftIO (rowsBefore @?= 1)

  rowsAfter <- runDBTx $ do
    _ <- insertCounter
    _ <- insertCounter
    mkOpaCountQuery (selectTable countersTable)
  liftIO (rowsAfter @?= 3)