{-# OPTIONS_GHC -Wno-deprecations #-}

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}

module Test.Utils (
  -- * Helper types
    JsonFragmentResponseMatcher(..)

  -- * Utility functions
  , (@??=)
  , containsJSON
  , gargMkRequest
  , getJSON
  , pending
  , pollUntilWorkFinished
  , postJSONUrlEncoded
  , protected
  , protectedJSON
  , protectedJSONWith
  , protectedNewError
  , protectedWith
  , shouldRespondWithFragment
  , shouldRespondWithFragmentCustomStatus
  , shouldRespondWithJSON
  , waitForTChanValue
  , waitForTSem
  , waitUntil
  , withValidLogin
  , withValidLoginA
  ) where

import Control.Concurrent.STM.TChan (TChan, readTChan)
import Control.Concurrent.STM.TSem (TSem, waitTSem)
import Control.Concurrent.STM.TVar (newTVarIO, writeTVar, readTVarIO)
import Control.Exception.Safe ()
import Control.Monad ()
import Data.Aeson.KeyMap qualified as KM
import Data.Aeson qualified as JSON
import Data.ByteString.Char8 qualified as B
import Data.ByteString.Lazy qualified as L
import Data.Map.Strict qualified as Map
import Data.Text.Encoding qualified as TE
import Data.Text.Lazy.Encoding qualified as TLE
import Data.Text.Lazy qualified as TL
import Data.Text qualified as T
import Data.TreeDiff
import Gargantext.API.Admin.Auth.Types (AuthRequest(..), AuthResponse, Token, authRes_token)
import Gargantext.API.Admin.Orchestrator.Types
import Gargantext.API.Routes.Types (xGargErrorScheme)
import Gargantext.Core.Config (LogConfig)
import Gargantext.Core.Notifications.Dispatcher.Types qualified as DT
import Gargantext.Core.Types.Individu (Username, GargPassword)
import Gargantext.Core.Worker.Types (JobInfo(..))
import Gargantext.Prelude
import Gargantext.System.Logging (withLogger, logMsg, LogLevel(..))
import Network.HTTP.Client (defaultManagerSettings, newManager)
import Network.HTTP.Client qualified as HTTP
import Network.HTTP.Types.Header (hAccept, hAuthorization, hContentType)
import Network.HTTP.Types (Header, Method, status200)
import Network.Wai.Handler.Warp (Port)
import Network.Wai.Test (SResponse(..))
import Network.WebSockets qualified as WS
import Prelude qualified
import Servant.Client.Core (BaseUrl)
import Servant.Client.Core.Request qualified as Client
import Servant.Client.Streaming (ClientEnv, baseUrlPort, mkClientEnv, parseBaseUrl, runClientM, makeClientRequest, defaultMakeClientRequest)
import System.Environment (lookupEnv)
import System.Timeout qualified as Timeout
import Test.API.Routes (auth_api)
import Test.Hspec.Expectations
import Test.Hspec.Wai.JSON (FromValue(..))
import Test.Hspec.Wai (MatchBody(..), WaiExpectation, WaiSession, request)
import Test.Hspec.Wai.Matcher (MatchHeader(..), ResponseMatcher(..), bodyEquals, formatHeader, match)
import Test.Tasty.HUnit (Assertion, assertBool)
import Test.Utils.Notifications (withWSConnection, millisecond)


-- | Runs the given 'Assertion' but treats any exception it throws as a
-- "pending" outcome instead of a failure: the reason and the rendered
-- exception are printed to stdout and the assertion then succeeds.
pending :: Prelude.String -> Assertion -> Assertion
pending reason act = catch act reportPending
  where
    reportPending :: SomeException -> IO ()
    reportPending e = do
      putStrLn $ "PENDING: " <> reason
      putStrLn (displayException e)


-- | Wrapper around a 'ResponseMatcher' used for JSON /fragment/ matching.
-- Values built through its 'FromValue' instance check the body with
-- 'containsJSON' (containment rather than equality).
newtype JsonFragmentResponseMatcher = JsonFragmentResponseMatcher
  { getJsonMatcher :: ResponseMatcher
  }

-- | Runs the request and checks that it answered with status 200 and that the
-- response satisfies the given fragment matcher. Be careful: with a matcher
-- built through 'FromValue' this does /not/ check that the body equals the
-- fragment, only that the body contains it (see 'containsJSON').
shouldRespondWithFragment :: HasCallStack
                          => WaiSession st SResponse
                          -> JsonFragmentResponseMatcher
                          -> WaiExpectation st
shouldRespondWithFragment = shouldRespondWithFragmentCustomStatus 200

-- | Like 'shouldRespondWithFragment', but expects the given HTTP status code
-- instead of 200. Whatever status the supplied matcher carried is overridden.
shouldRespondWithFragmentCustomStatus :: HasCallStack
                                      => Int
                                      -> WaiSession st SResponse
                                      -> JsonFragmentResponseMatcher
                                      -> WaiExpectation st
shouldRespondWithFragmentCustomStatus status action (JsonFragmentResponseMatcher base) = do
  response <- action
  let expected = base { matchStatus = status }
  forM_ (match response expected) (liftIO . expectationFailure)


-- | Builds a fragment matcher from a JSON value. It expects status 200 and a
-- JSON @Content-Type@ header, and its body check is 'containsJSON' of the
-- value.
instance FromValue JsonFragmentResponseMatcher where
  fromValue v = JsonFragmentResponseMatcher (ResponseMatcher 200 [jsonContentType] (containsJSON v))
    where
      jsonContentType = MatchHeader checkContentType

      -- Accept "application/json", optionally followed by ";charset=utf-8"
      -- (whitespace around both parts is ignored). Anything else, including
      -- an absent header, is reported as a missing header.
      checkContentType headers _body
        | Just value <- Prelude.lookup "Content-Type" headers
        , isJSON value
        = Nothing
        | otherwise
        = Just $ Prelude.unlines [
              "missing header:"
            , formatHeader ("Content-Type", "application/json")
            ]

      isJSON value =
        let (rawMedia, rest) = B.break (== ';') value
            media            = trim rawMedia
            parameters       = trim (B.drop 1 rest)
        in media == "application/json" && parameters `elem` ["", "charset=utf-8"]

      trim = B.reverse . B.dropWhile isSpace . B.reverse . B.dropWhile isSpace

-- | Runs an action that produces a value, JSON-encodes that value and checks
-- the encoding against the given matcher, as if it were a @200 OK@ response
-- carrying a @Content-Type: application/json@ header.
--
-- The synthetic response must carry that header. Matchers built through the
-- 'FromValue' instance of 'JsonFragmentResponseMatcher' include a
-- @Content-Type@ check, and with no headers at all that check would always
-- fail with "missing header".
shouldRespondWithJSON :: (JSON.FromJSON a, JSON.ToJSON a, HasCallStack)
                      => WaiSession st a
                      -> JsonFragmentResponseMatcher
                      -> WaiExpectation st
shouldRespondWithJSON action matcher = do
  r <- action
  let synthetic = SResponse status200 [(hContentType, "application/json")] (JSON.encode r)
  forM_ (match synthetic (getJsonMatcher matcher)) (liftIO . expectationFailure)

-- | A body matcher that succeeds when the response body decodes to JSON
-- "containing" the expected value:
--
--   * for two objects, every top-level key of the expected object must be
--     present in the actual object with an /equal/ value (this is a shallow
--     check; nested values are compared with '==');
--
--   * any other pair of values must simply be equal.
--
-- If the body does not decode, or the check fails, the failure message is
-- the one produced by 'bodyEquals' against the encoded expected value.
containsJSON :: JSON.Value -> MatchBody
containsJSON expected = MatchBody checkBody
  where
    checkBody hdrs body = case JSON.decode body of
      Just actual | expected `isSubsetOf` actual -> Nothing
      _                                          -> reportMismatch hdrs body

    -- 'bodyEquals' is used only to produce the failure message.
    MatchBody reportMismatch = bodyEquals (JSON.encode expected)

    isSubsetOf :: JSON.Value -> JSON.Value -> Bool
    isSubsetOf (JSON.Object sub) (JSON.Object sup) = all (presentIn sup) (KM.toList sub)
    isSubsetOf x y                                 = x == y

    presentIn sup (key, value) = KM.lookup key sup == Just value


-- | Issues a request that carries an @Authorization: Bearer@ header for the
-- given token, plus the default JSON headers of 'protectedWith' and no extra
-- headers. Returns the raw response.
protected :: HasCallStack
          => Token
          -> Method
          -> ByteString
          -> L.ByteString
          -> WaiSession () SResponse
protected = protectedWith []

-- | Issues an authenticated request whose body is the given JSON value, and
-- decodes the response body as JSON into @a@. This is 'protectedJSONWith'
-- with no extra headers.
protectedJSON :: forall a. (JSON.FromJSON a, Typeable a, HasCallStack)
              => Token
              -> Method
              -> ByteString
              -> JSON.Value
              -> WaiSession () a
protectedJSON = protectedJSONWith []

-- | Issues an authenticated request whose body is the JSON encoding of the
-- given value. The extra headers are merged over the defaults, as in
-- 'protectedWith'. The response body is then decoded as JSON into @a@.
--
-- If decoding fails, the 'WaiSession' fails (via 'Prelude.fail'). The message
-- includes the target type, the decoder error and the raw response body, in
-- the same format 'postJSONUrlEncoded' uses, so unexpected server payloads
-- (e.g. error objects) show up directly in the test output.
protectedJSONWith :: forall a. (JSON.FromJSON a, Typeable a, HasCallStack)
                  => [Header]
                  -> Token
                  -> Method
                  -> ByteString
                  -> JSON.Value
                  -> WaiSession () a
protectedJSONWith hdrs tkn mth url jsonV = do
  SResponse{..} <- protectedWith hdrs tkn mth url (JSON.encode jsonV)
  case JSON.eitherDecode simpleBody of
    Left err -> Prelude.fail $ "protectedJSON failed when parsing " <> show (typeRep $ Proxy @a) <> ": " <> err
                            <> "\nPayload was: " <> (T.unpack . TL.toStrict . TLE.decodeUtf8 $ simpleBody)
    Right x  -> pure x

-- | Issues a request with default headers (@Accept@ and @Content-Type@ set to
-- JSON, and @Authorization: Bearer <token>@) merged with the extra headers.
-- When an extra header has the same name as a default one, the extra header
-- wins. Headers are deduplicated by name through a 'Map'.
protectedWith :: HasCallStack
              => [Header]
              -> Token
              -> Method -> ByteString -> L.ByteString -> WaiSession () SResponse
protectedWith extraHeaders tkn mth url payload =
  request mth url mergedHeaders payload
  where
    defaults = Map.fromList
      [ (hAccept, "application/json;charset=utf-8")
      , (hContentType, "application/json")
      , (hAuthorization, "Bearer " <> TE.encodeUtf8 tkn)
      ]
    -- 'Map.union' is left-biased, so the caller's headers take precedence.
    mergedHeaders = Map.toList (Map.fromList extraHeaders `Map.union` defaults)

-- | Like 'protected', but additionally sends the 'xGargErrorScheme' header
-- with the value @new@.
protectedNewError :: HasCallStack => Token -> Method -> ByteString -> L.ByteString -> WaiSession () SResponse
protectedNewError = protectedWith [(xGargErrorScheme, "new")]

-- | Issues an authenticated @GET@ with an empty body and the default JSON
-- headers of 'protectedWith', returning the raw response.
getJSON :: Token -> ByteString -> WaiSession () SResponse
getJSON tkn url = protectedWith [] tkn "GET" url mempty

-- | Issues an authenticated @POST@ whose body is the given URL-encoded form.
-- It overrides @Content-Type@ with @application/x-www-form-urlencoded@ and
-- decodes the JSON response into @a@. If decoding fails, the session fails
-- with a message that includes the raw response body.
postJSONUrlEncoded :: forall a. (JSON.FromJSON a, Typeable a, HasCallStack)
                   => Token
                   -> ByteString
                   -> L.ByteString
                   -> WaiSession () a
postJSONUrlEncoded tkn url queryPaths = do
  resp <- protectedWith [(hContentType, formUrlEncoded)] tkn "POST" url queryPaths
  let body = simpleBody resp
  case JSON.eitherDecode body of
    Right x  -> pure x
    Left err -> Prelude.fail (parseFailure body err)
  where
    formUrlEncoded = "application/x-www-form-urlencoded"

    parseFailure :: L.ByteString -> Prelude.String -> Prelude.String
    parseFailure body err =
         "postJSONUrlEncoded failed when parsing "
      <> show (typeRep $ Proxy @a)
      <> ": " <> err
      <> "\nPayload was: "
      <> (T.unpack . TL.toStrict . TLE.decodeUtf8 $ body)

-- | Logs in to the server at @http://localhost@ on the given port with the
-- given credentials (via 'auth_api'), then runs the callback with a
-- 'ClientEnv' and the full 'AuthResponse'.
--
-- A failed login is rethrown as an 'IOError' built with 'Prelude.userError'.
-- The 'ClientEnv' handed to the callback builds its requests with
-- 'gargMkRequest', so requests are traced when the @GARG_DEBUG_LOGS@
-- environment variable is set (to any value). The login request itself is
-- not traced.
withValidLoginA :: (MonadFail m, MonadIO m)
                => Port
                -> Username
                -> GargPassword
                -> (ClientEnv -> AuthResponse -> m a)
                -> m a
withValidLoginA port ur pwd act = do
  clientEnv0 <- liftIO $ do
    baseUrl <- parseBaseUrl "http://localhost"
    manager <- newManager defaultManagerSettings
    pure $ mkClientEnv manager (baseUrl { baseUrlPort = port })
  loginResult <- liftIO $ runClientM (auth_api (AuthRequest ur pwd)) clientEnv0
  authRes <- case loginResult of
    Left err  -> liftIO $ throwIO $ Prelude.userError (show err)
    Right res -> pure res
  traceEnabled <- isJust <$> liftIO (lookupEnv "GARG_DEBUG_LOGS")
  act (clientEnv0 { makeClientRequest = gargMkRequest traceEnabled }) authRes

-- | Like 'withValidLoginA', but the callback only receives the session
-- 'Token' extracted from the 'AuthResponse'.
withValidLogin :: (MonadFail m, MonadIO m)
               => Port
               -> Username
               -> GargPassword
               -> (ClientEnv -> Token -> m a)
               -> m a
withValidLogin port ur pwd act = withValidLoginA port ur pwd withToken
  where
    withToken clientEnv authRes = act clientEnv (authRes ^. authRes_token)


-- | Builds the HTTP request with 'defaultMakeClientRequest'. When the flag is
-- 'True', the resulting request is also passed through 'traceShowId', so one
-- can see what the client actually sends to the server.
-- FIXME(adn) We cannot upgrade to servant-client 0.20 due to OpenAlex:
-- https://gitlab.iscpif.fr/gargantext/crawlers/openalex/blob/main/src/OpenAlex/ServantClientLogging.hs#L24
gargMkRequest :: Bool -> BaseUrl -> Client.Request -> IO HTTP.Request
gargMkRequest traceEnabled bu clientRq =
  traceIfEnabled <$> defaultMakeClientRequest bu clientRq
  where
    traceIfEnabled req
      | traceEnabled = traceShowId req
      | otherwise    = req


-- | Waits until the worker job described by the given 'JobInfo' reports
-- completion over the notifications WebSocket, then returns that same
-- 'JobInfo'.
--
-- How it works:
--
--   * a WebSocket connection to @127.0.0.1@ on the given 'Port' runs in a
--     background thread (via 'withAsync'). It subscribes to
--     'DT.UpdateWorkerProgress' for this job and sets a 'TVar' once it
--     receives a 'DT.NUpdateWorkerProgress' for the same 'JobInfo' whose log
--     has '_scst_remaining' equal to @Just 0@;
--
--   * the calling thread polls that 'TVar' every @50 * millisecond@ and gives
--     up after @waitSecs * 1000 * millisecond@ microseconds (60 seconds,
--     assuming 'millisecond' is 1 ms expressed in microseconds), aborting via
--     'panicTrace'.
--
-- The 'LogConfig' is only used to emit DEBUG log messages.
--
-- NOTE(review): the 'withAsync' handle is ignored, so if the WebSocket thread
-- dies (e.g. the connection cannot be established), the failure is presumably
-- not re-raised here and only shows up as the timeout. Confirm.
-- NOTE(review): the subscription is sent after connecting, so a completion
-- notification emitted before it is registered would presumably be missed.
pollUntilWorkFinished :: HasCallStack
                      => LogConfig
                      -> Port
                      -> JobInfo
                      -> WaiSession () JobInfo
pollUntilWorkFinished log_cfg port ji = do
  -- Overall timeout, in seconds.
  let waitSecs = 60
  -- Set to True by the WebSocket listener once this job is seen as finished.
  isFinishedTVar <- liftIO $ newTVarIO False
  let wsConnect =
        withWSConnection ("127.0.0.1", port) $ \conn -> do
          -- A settling delay used to be applied here; it is currently disabled:
          -- threadDelay (100 * millisecond)
          -- subscribe to notifications about this job
          let topic = DT.UpdateWorkerProgress ji
          WS.sendTextData conn $ JSON.encode (DT.WSSubscribe topic)
          -- Never returns on its own; the thread is stopped when the
          -- 'withAsync' scope below exits.
          forever $ do
            d <- WS.receiveData conn
            let dec = JSON.decode d :: Maybe DT.Notification
            case dec of
              -- Messages that do not decode as a 'DT.Notification' are ignored.
              Nothing -> pure ()
              Just (DT.NUpdateWorkerProgress ji' jl) -> do
                withLogger log_cfg $ \ioL ->
                  logMsg ioL DEBUG $ "[pollUntilWorkFinished] received " <> show ji' <> ", " <> show jl
                -- Progress for other jobs, or unfinished progress, is ignored.
                if ji' == ji && isFinished jl
                then do
                  withLogger log_cfg $ \ioL ->
                    logMsg ioL DEBUG $ "[pollUntilWorkFinished] FINISHED! " <> show ji'
                  atomically $ writeTVar isFinishedTVar True
                else
                  pure ()
              -- Any other notification kind is ignored.
              _ -> pure ()

  -- Poll the flag until it is set or the timeout expires. Leaving the
  -- 'withAsync' scope cancels the WebSocket listener.
  liftIO $ withAsync wsConnect $ \_ -> do
    mRet <- Timeout.timeout (waitSecs * 1000 * millisecond) $ do
      let go = do
            finished <- readTVarIO isFinishedTVar
            if finished
              then do
                withLogger log_cfg $ \ioL ->
                  logMsg ioL DEBUG $ "[pollUntilWorkFinished] JOB FINISHED: " <> show ji
                return True
              else do
                threadDelay (50 * millisecond)
                go
      go
    case mRet of
      Nothing -> panicTrace $ "[pollUntilWorkFinished] timed out while waiting to finish job " <> show ji
      Just _ -> return ji


  where
    -- A job counts as finished only when its log reports '_scst_remaining'
    -- of exactly @Just 0@; a 'Nothing' count never counts as finished.
    isFinished (JobLog { .. }) = _scst_remaining == Just 0

-- | Like HUnit's '@?=', but with a nicer error message: on failure it shows a
-- compact tree diff ('ediff') between the expected and the actual value.
(@??=) :: (HasCallStack, ToExpr a, Eq a) => a -> a -> Assertion
(@??=) actual expected = assertBool diffReport (expected == actual)
  where
    diffReport = show (ansiWlEditExprCompact (ediff expected actual))


-- | Re-runs the predicate every 50 ms until it returns 'True' or the given
-- number of milliseconds has elapsed. The predicate is then run one final
-- time, and the expectation fails if that last result is 'False'.
waitUntil :: HasCallStack => IO Bool -> Int -> Expectation
waitUntil pred' timeoutMs = do
  -- Whether the polling loop timed out does not matter: the final check
  -- below decides the outcome either way.
  _ <- Timeout.timeout (timeoutMs * 1000) pollLoop
  satisfied <- pred'
  unless satisfied $ expectationFailure "Predicate test failed"
  where
    pollLoop = do
      ok <- pred'
      unless ok $ do
        threadDelay 50000
        pollLoop


-- | Reads the next value from the channel, waiting at most the given number
-- of milliseconds. It aborts (via 'panicTrace') if nothing arrives in time,
-- or if the value read differs from the expected one. Exactly one value is
-- consumed; a non-matching value is an immediate failure, not skipped.
waitForTChanValue :: (HasCallStack, Eq a, Show a) => TChan a -> a -> Int -> IO ()
waitForTChanValue tchan expected timeoutMs = do
  outcome <- Timeout.timeout (timeoutMs * 1000) $ do
    received <- atomically (readTChan tchan)
    unless (received == expected) $
      panicTrace $ "[waitForTChanValue] v != expected (" <> show received <> " != " <> show expected <> ")"
  case outcome of
    Nothing -> panicTrace $ "[waitForTChanValue] timeout when waiting for " <> show expected <> " on tchan"
    Just () -> pure ()


-- | Waits on the semaphore for at most the given number of milliseconds,
-- aborting via 'panicTrace' if it could not be acquired in time.
waitForTSem :: HasCallStack => TSem -> Int -> IO ()
waitForTSem tsem timeoutMs = do
  acquired <- Timeout.timeout (timeoutMs * 1000) (atomically (waitTSem tsem))
  case acquired of
    Just () -> pure ()
    Nothing -> panicTrace $ "[waitForTSem] timeout when waiting TSem"