{-# LANGUAGE TypeApplications  #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

module Main where

import Control.DeepSeq
import Gargantext.Core.Types.Individu
import Gargantext.Core.Viz.Phylo
import Gargantext.Core.Viz.Phylo.API.Tools (readPhylo)
import Gargantext.Core.Viz.Phylo.PhyloMaker (toPhylo)
import Gargantext.Core.Viz.Phylo.PhyloTools
import Gargantext.Prelude.Crypto.Auth (createPasswordHash)
import Paths_gargantext
import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate as Accelerate
import qualified Data.Array.Accelerate.LLVM.Native as LLVM
import qualified Data.Array.Accelerate.Interpreter as Naive
import qualified Data.List.Split as Split
import qualified Data.Massiv.Array as Massiv
import qualified Gargantext.Core.LinearAlgebra as LA
import qualified Gargantext.Core.Methods.Matrix.Accelerate.Utils as Accelerate
import qualified Gargantext.Core.Methods.Similarities.Accelerate.Distributional as Accelerate
import qualified Numeric.LinearAlgebra.Data as HM
import Test.Tasty.Bench
import Data.Array.Accelerate ((:.))

-- | Phylo configuration applied (through 'setConfig' in 'main') to the phylo
-- loaded from the benchmark data before 'toPhylo' is measured.
phyloConfig :: PhyloConfig
phyloConfig = PhyloConfig
  { corpusPath     = "corpus.csv"
  , listPath       = "list.csv"
  , outputPath     = "data/"
  , corpusParser   = Csv { _csv_limit = 150000 }
  , listParser     = V4
  , phyloName      = "Phylo Name"
  , phyloScale     = 2
  , similarity     = WeightedLogJaccard { _wlj_sensibility = 0.5, _wlj_minSharedNgrams = 2 }
  , seaElevation   = Constante { _cons_start = 0.1, _cons_gap = 0.1 }
  , defaultMode    = True
  , findAncestors  = False
  , phyloSynchrony = benchSynchrony
  , phyloQuality   = Quality { _qua_granularity = 0.8, _qua_minBranch = 3 }
  , timeUnit       = Year { _year_period = 3, _year_step = 1, _year_matchingFrame = 5 }
  , clique         = MaxClique { _mcl_size = 5, _mcl_threshold = 1.0e-4, _mcl_filter = ByThreshold }
  , exportLabel    = benchLabels
  , exportSort     = ByHierarchy { _sort_order = Desc }
  , exportFilter   = [ ByBranchSize { _branch_size = 3.0 } ]
  }
  where
    -- Synchronisation settings, split out only because the record is long.
    benchSynchrony = ByProximityThreshold
      { _bpt_threshold   = 0.5
      , _bpt_sensibility = 0.0
      , _bpt_scope       = AllBranches
      , _bpt_strategy    = MergeAllGroups
      }
    -- One branch-level and one group-level label, both of size 2.
    benchLabels =
      [ BranchLabel { _branch_labelTagger = MostEmergentTfIdf,     _branch_labelSize = 2 }
      , GroupLabel  { _group_labelTagger  = MostEmergentInclusive, _group_labelSize  = 2 }
      ]

-- | Cell values shared by all the matrix / array fixtures: the integers
-- 1 to 10 000 in increasing order.
matrixValues :: [Int]
matrixValues = enumFromTo 1 10_000

-- | Side length of the square 2-D fixtures. The Accelerate fixture takes
-- @matrixDim * matrixDim@ values while the massiv one chunks /all/ of
-- 'matrixValues' into rows of this length, so the two only describe the same
-- matrix when @matrixDim * matrixDim == length matrixValues@.
matrixDim :: Int
matrixDim = 100

-- | Square Accelerate matrix of side 'matrixDim', filled from 'matrixValues'.
testMatrix :: A.Matrix Int
testMatrix = A.fromList sqShape matrixValues
  where
    sqShape = A.Z A.:. matrixDim A.:. matrixDim
{-# INLINE testMatrix #-}

-- | 20 x 20 x 20 Accelerate array filled from 'matrixValues' (needs at least
-- 20^3 = 8000 values; 'matrixValues' currently holds 10 000).
testVector :: A.Array (A.Z :. Int :. Int :. Int) Int
testVector = A.fromList cubeShape matrixValues
  where
    cubeShape = A.Z A.:. 20 A.:. 20 A.:. 20
{-# INLINE testVector #-}

-- | massiv counterpart of 'testMatrix': 'matrixValues' cut into rows of
-- 'matrixDim' elements, built with the 'Massiv.Par' computation strategy.
testMassivMatrix :: Massiv.Matrix Massiv.U Int
testMassivMatrix = Massiv.fromLists' Massiv.Par chunkedRows
  where
    chunkedRows = Split.chunksOf matrixDim matrixValues
{-# INLINE testMassivMatrix #-}

-- | massiv counterpart of 'testVector', obtained by converting it with
-- 'LA.accelerate2Massiv3DMatrix'.
testMassivVector :: Massiv.Array Massiv.U Massiv.Ix3 Int
testMassivVector = LA.accelerate2Massiv3DMatrix testVector
{-# INLINE testMassivVector #-}

-- | Benchmark entry point.
--
-- Every precomputed benchmark input is evaluated to normal form /before/
-- 'defaultMain' starts, so no measured iteration pays for loading, parsing
-- or building its input.
main :: IO ()
main = do
  -- 'force' alone under '<$>' only builds a lazy thunk; the bang on the bind
  -- evaluates that thunk, i.e. deep-evaluates the phylo right here instead of
  -- inside the first measured 'toPhylo' iterations.
  -- The larger bench-data/phylo/issue-290.json fixture is not loaded: no
  -- benchmark below uses it.
  !issue290PhyloSmall <- force . setConfig phyloConfig <$> (readPhylo =<< getDataFileName "bench-data/phylo/issue-290-small.json")
  -- 'let !x = force e' evaluates each array fixture to normal form up front.
  let !accInput       = force testMatrix
  let !accVector      = force testVector
  let !massivVector   = force testMassivVector
  let !(accDoubleInput :: Accelerate.Matrix Double) = force $ Naive.run $ Accelerate.map Accelerate.fromIntegral (Accelerate.use testMatrix)
  let !massivInput    = force testMassivMatrix
  let !(massivDoubleInput :: Massiv.Matrix Massiv.U Double) = force $ Massiv.computeP $ Massiv.map fromIntegral testMassivMatrix
  defaultMain
    [ bgroup "Benchmarks"
      [ bgroup "User creation" [
        bench "createPasswordHash"  $ whnfIO (createPasswordHash "rabbit")
      , bench "toUserHash"  $
          whnfIO (toUserHash $ NewUser "alfredo" "alfredo@well-typed.com" (GargPassword "rabbit"))
      ]
      , bgroup "Phylo" [
          bench "toPhylo (small)" $ nf toPhylo issue290PhyloSmall
      ]
      , bgroup "logDistributional2" [
          bench "Accelerate (Naive)" $ nf (Accelerate.logDistributional2With @Double Naive.run) accInput
      ,   bench "Accelerate (LLVM)"  $ nf (Accelerate.logDistributional2With @Double LLVM.run) accInput
      ,   bench "Massiv"  $ nf (LA.logDistributional2 @_ @Double) massivInput
      ]
      , bgroup "distributional" [
          bench "Accelerate (Naive)" $ nf (Accelerate.distributionalWith @Double Naive.run) accInput
      ,   bench "Accelerate (LLVM)"  $ nf (Accelerate.distributionalWith @Double LLVM.run) accInput
      ,   bench "Massiv (reference implementation)" $ nf (LA.distributionalReferenceImplementation @_ @Double) massivInput
      ,   bench "Massiv" $ nf (LA.distributional @_ @Double) massivInput
      ]
      , bgroup "diag" [
          bench "Accelerate (Naive)" $ nf (Naive.run . Accelerate.diag . Accelerate.use) accInput
      ,   bench "Accelerate (LLVM)"  $ nf (LLVM.run  . Accelerate.diag . Accelerate.use) accInput
      ,   bench "Massiv" $ nf (LA.diag @_) massivInput
      ]
      , bgroup "matrixIdentity" [
          bench "Accelerate (Naive)" $ nf (Naive.run . Accelerate.matrixIdentity @Double) 1000
      ,   bench "Accelerate (LLVM)"  $ nf (LLVM.run  . Accelerate.matrixIdentity @Double) 1000
      ,   bench "Massiv" $ nf (LA.matrixIdentity @Double) 1000
      ,   bench "HMatrix" $ nf (HM.ident @Double) 1000
      ]
      , bgroup "matrixEye" [
          bench "Accelerate (Naive)" $ nf (Naive.run . Accelerate.matrixEye @Double) 1000
      ,   bench "Accelerate (LLVM)"  $ nf (LLVM.run  . Accelerate.matrixEye @Double) 1000
      ,   bench "Massiv" $ nf (LA.matrixEye @Double) 1000
      ]
      , bgroup "matMaxMini" [
          bench "Accelerate (Naive)" $ nf (Naive.run . Accelerate.matMaxMini @Double . Accelerate.use) accDoubleInput
      ,   bench "Accelerate (LLVM)"  $ nf (LLVM.run  . Accelerate.matMaxMini @Double . Accelerate.use) accDoubleInput
      ,   bench "Massiv" $ nf LA.matMaxMini massivDoubleInput
      ]
      , bgroup "(.*)" [
          bench "Accelerate (Naive)" $ nf (\v -> Naive.run $ (Accelerate.use v) Accelerate..* (Accelerate.use v)) accDoubleInput
      ,   bench "Accelerate (LLVM)"  $ nf (\v -> LLVM.run $ (Accelerate.use v) Accelerate..* (Accelerate.use v)) accDoubleInput
      ,   bench "Massiv" $ nf (\v -> (v LA..* v) :: Massiv.Matrix Massiv.U Double) massivDoubleInput
      ]
      , bgroup "sumRows" [
          bench "Accelerate (Naive)" $ nf (Naive.run . Accelerate.sum . Accelerate.use) accVector
      ,   bench "Accelerate (LLVM)"  $ nf (LLVM.run . Accelerate.sum . Accelerate.use) accVector
      ,   bench "Massiv"             $ nf LA.sumRows massivVector
      ]
      , bgroup "sumMin_go" [
          bench "Accelerate (Naive)" $ nf (Naive.run . Accelerate.sumMin_go 100 . Accelerate.use) accDoubleInput
      ,   bench "Accelerate (LLVM)"  $ nf (LLVM.run . Accelerate.sumMin_go 100 . Accelerate.use) accDoubleInput
      ,   bench "Massiv"             $ nf (Massiv.compute @Massiv.U . LA.sumMin_go 100) massivDoubleInput
      ]
      , bgroup "termDivNan" [
          bench "Accelerate (Naive)" $
            nf (\m -> Naive.run $ Accelerate.termDivNan (Accelerate.use m) (Accelerate.use m)) accDoubleInput
      ,   bench "Accelerate (LLVM)" $
            nf (\m -> LLVM.run $ Accelerate.termDivNan (Accelerate.use m) (Accelerate.use m)) accDoubleInput
      ,   bench "Massiv" $ nf (\m -> LA.termDivNan @Massiv.U m m) massivDoubleInput
      ]
      ]
    ]