{-| Module      : Graph.BAC.ProxemyOptim
Description : Proxemy
Copyright   : (c) CNRS, 2017-Present
License     : AGPL + CECILL v3
Maintainer  : team@gargantext.org
Stability   : experimental
Portability : POSIX

Confluence for Graph Clustering, B. Gaume and A. Delanoë

Code written in Haskell by A. Delanoë from Python Specifications by B.
Gaume.

-}

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DataKinds,
             GADTs,
             KindSignatures,
             ScopedTypeVariables,
             StandaloneDeriving,
             TypeOperators       #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE NoImplicitPrelude   #-}
{-# LANGUAGE RankNTypes          #-}


module Graph.BAC.ProxemyOptim
  where

--import Debug.SimpleReflect
import Data.IntMap (IntMap)
import Data.Maybe (isJust)
import Data.Proxy (Proxy(Proxy))
import Data.Reflection
import Eigen.Internal (CTriplet(..), Elem(..), toC, fromC, C(..), natToInt, Row(..), Col(..))
import Eigen.Matrix (sum, unsafeCoeff)
import Eigen.SparseMatrix (SparseMatrix, SparseMatrixXd, (!), toMatrix, _unsafeCoeff)
import GHC.TypeLits (KnownNat, Nat, SomeNat(SomeNat), type(+), natVal, sameNat, someNatVal)
import Graph.FGL
import Prelude (String, readLn)
import Protolude hiding (sum, natVal)
import qualified Eigen.Matrix as DenseMatrix
import qualified Data.Graph.Inductive              as DGI
import qualified Data.Graph.Inductive.PatriciaTree as DGIP
import qualified Data.List                         as List
import qualified Data.IntMap                       as Dict
import qualified Data.Vector                       as V
import qualified Data.Vector.Storable              as VS
import qualified Eigen.Matrix                      as DMatrix
import qualified Eigen.SparseMatrix                as SMatrix
import qualified Prelude                           as Prelude
import qualified Data.Set                          as Set

-- Main types
-- | Number of steps of a walk; 'proxemie' rejects values below 1.
type Length      = Int
-- | Whether 'adjacent' adds a diagonal entry for every node.
type IsReflexive = Bool
-- | A function giving, for a graph and one of its nodes, a list of nodes.
type NeighborsFilter a b = DGI.Gr a b -> Node -> [Node]
-- | Flag taken by 'matconf'; only the @False@ case is implemented there.
type RmEdge      = Bool

---------------------------------------------------------------
---------------------------------------------------------------
-- | Graphs are plain 'DGI.Gr' values.
type Graph a b = DGI.Gr a b
-- | A finite Graph is a Graph whose number of Nodes is known.
-- The size @n@ is a phantom type parameter: the constructor only stores the graph.
data FiniteGraph (n :: Nat) a b = FiniteGraph (Graph a b)

-- | A 'FiniteGraph' is rendered exactly like the graph it wraps.
instance (Show a, Show b) => Show (FiniteGraph n a b) where
  show (FiniteGraph wrapped) = Prelude.show wrapped


-- | Wrap a graph as a 'FiniteGraph' whose size index is fixed by the proxy.
-- The proxy only selects @n@; the graph itself is not inspected.
buildFiniteGraph :: Proxy n -> Graph a b -> FiniteGraph n a b
buildFiniteGraph _proxy = FiniteGraph

-- | Build a 'SparseVector' of size @n@ from @(row, column, value)@ triplets.
-- The proxy only fixes @n@; the triplets are handed to 'SMatrix.fromList' as is.
sparseVectorFromList :: KnownNat n => Proxy n -> [(Int, Int, Double)] -> SparseVector n
sparseVectorFromList _proxy = SMatrix.fromList

-- | Run a size-polymorphic continuation on a graph.
-- The node count is reified as a type-level natural, so the continuation
-- receives a 'FiniteGraph' indexed by the actual number of nodes.
withG :: (Show a, Show b)
      => Graph a b
      -> (forall n. KnownNat n => FiniteGraph n a b -> r)
      -> r
withG g f = reifyNat nodeCount (\proxy -> f (buildFiniteGraph proxy g))
  where
    nodeCount = fromIntegral (length (nodes g))

-- ∀ x. P x -> ∃ x. P x
-- | Existential wrapper: packs a 'FiniteGraph' with a proxy for its size @n@
-- and the matching 'KnownNat' evidence, so that @n@ does not appear in the type.
data T a b where
  T :: forall n a b. KnownNat n => Proxy n -> FiniteGraph n a b -> T a b
-- | Type indexed by a type-level natural; not used in the visible code.
data Dim (n  :: Nat) = Dim Nat
---------------------------------------------------------------
-- | Dimension of a graph: its number of nodes.
-- This is the same count that 'withG' and 'tab' reify at the type level.
getDim :: Graph a b -> Integer
getDim g = fromIntegral (length (nodes g))

-- | Dimension recorded in an existentially packed graph, read back from its proxy.
getDim' :: T a b -> Integer
getDim' (T sizeProxy _) = natVal sizeProxy

-- | Pack a graph with its node count reified as a type-level natural.
--
-- 'someNatVal' only fails on negative input, which a list length never is, so
-- the 'Nothing' branch is unreachable; it is spelled out to keep the match
-- exhaustive instead of relying on a partial @let Just … =@ pattern.
tab :: Graph a b -> T a b
tab g = case someNatVal (fromIntegral (length (nodes g))) of
  Just (SomeNat p) -> T p (FiniteGraph g)
  Nothing          -> panic "tab: negative node count (impossible)"

---------------------------------------------------------------
---------------------------------------------------------------
-- | Square dense matrix of 'Double's of size @n@.
type DenseMatrix      n = DenseMatrix.Matrix n n Double
-- | Square sparse matrices of size @n@. The three aliases share one
-- representation and only name the role the matrix plays.
type AdjacencyMatrix  n = SparseMatrix n n Double
type TransitionMatrix n = SparseMatrix n n Double
type ProxemyMatrix    n = SparseMatrix n n Double

-- | Sparse matrix with a single row and @n@ columns.
type SparseVector     n = SparseMatrix 1 n Double

-- | Dense matrices returned by 'matconf' and 'matmod' respectively.
type ConfluenceMatrix n = DenseMatrix n
type ModularityMatrix n = DenseMatrix n
-- | A dense matrix tagged with the measure it was computed with.
data SimilarityMatrix n = SimConf (ConfluenceMatrix n)
                        | SimMod  (ModularityMatrix n)
-- | Selects which measure 'clusteringOptim' computes.
data Similarity = Conf | Mod

-- | Shorthand for 'IntMap'.
type Dict = IntMap
-- | State of a clustering over items of type @a@.
data Clustering a = Clustering { parts :: Dict (Set a)  -- ^ sets of items; 'updateWith' keys this map by module id
                               , index :: Dict Int      -- ^ 'updateClustering' looks a node up here to get its module id
                               , score :: Double        -- ^ value 'updateClustering' compares to accept or reject a move
                               }

-- TODO
-- | Clustering modes; not used in the visible code.
data ClusteringMode = Part | Over | Both

-- | Aliases of 'Clustering' naming the input and output roles of 'make_clust_over'.
type StrictClustering   a = Clustering a
type OverlapsClustering a = Clustering a

-----
-- | Sparse adjacency matrix of a finite graph.
--
-- Every pair @(i, j)@ with @j@ among the neighbors of @i@ and @i /= j@ gets
-- the value 1. When the 'IsReflexive' flag is set, a 1 is also emitted on the
-- diagonal for every node; otherwise no diagonal triplet is produced.
adjacent :: KnownNat n
         => FiniteGraph n a b
         -> IsReflexive
         -> AdjacencyMatrix n
adjacent (FiniteGraph g) isReflexive =
  SMatrix.fromVector (VS.fromList (offDiagonal <> diagonal))
  where
    -- One triplet per (node, neighbor) pair, self-loops excluded.
    offDiagonal = [ CTriplet (toC src) (toC dst) 1.0
                  | src <- nodes g
                  , dst <- neighbors g src
                  , src /= dst
                  ]

    -- Optional self-loops.
    diagonal
      | isReflexive = [ CTriplet (toC v) (toC v) 1.0
                      | v <- nodes g
                      ]
      | otherwise   = []


-- | Turn an adjacency matrix into a transition matrix by scaling every stored
-- entry @(i, j)@ with the inverse of the @i@-th column sum.
--
-- NOTE(review): the sums are taken per column but indexed by the row @i@;
-- this presumably relies on the adjacency matrix being symmetric — confirm.
transition :: KnownNat n
           => AdjacencyMatrix  n
           -> TransitionMatrix n
transition m = SMatrix.imap scale m
  where
    inverseSums = sumWith Colonne (1 /) m
    scale i _ v = v * (VS.!) inverseSums i

-- | Axis used by 'sumWith': 'Ligne' sums each row, 'Colonne' sums each column
-- (French for "row" and "column").
data Direction = Ligne | Colonne

-- | Sum a square sparse matrix along one direction and post-process each sum.
--
-- For 'Colonne' the result holds one value per column, for 'Ligne' one value
-- per row; the function @f@ is applied to every sum before it is stored.
sumWith :: ( Elem a
           , Elem t
           , KnownNat n
           ) => Direction -> (t -> a) -> SparseMatrix n n t -> VS.Vector a
sumWith d f m = VS.fromList (map f totals)
  where
    -- Rows and columns have different matrix types, so each branch sums its
    -- own slices before the two meet as a plain list.
    totals = case d of
      Colonne -> [ sum (SMatrix.toMatrix col) | col <- SMatrix.getCols m ]
      Ligne   -> [ sum (SMatrix.toMatrix row) | row <- SMatrix.getRows m ]


-- | Proxemy matrix for walks of the given length: the transition matrix
-- raised to the power @l@ by repeated sparse multiplication.
--
-- Panics when @l < 1@.
--
-- The fold starts from @m@ (one step) and multiplies @l - 1@ more times, so
-- @l = 1@ returns @m@ itself. Folding over @[1 .. l]@ from the same seed
-- would yield @m^(l+1)@, one step more than requested.
proxemie :: KnownNat n
         => Length
        -> TransitionMatrix n
        -> ProxemyMatrix n
proxemie l m
  | l < 1     = panic "Length has to be >= 1"
  | otherwise = foldl' (\acc _ -> SMatrix.mul acc m) m [2 .. (l :: Int)]


---------------------------------------------------------------
-- | Confluence matrix computed from an adjacency matrix and a proxemy matrix.
--
-- For every stored entry @(x, y)@ of the proxemy matrix with @x < y@ the
-- value is @(p - q) / (p + q)@, where @p@ is the proxemy entry and @q@ is the
-- @x@-th column sum of the adjacency matrix divided by the total of all
-- column sums. Entries with @x >= y@ are set to 0 and then filled in by
-- 'symmetry'.
--
-- Only @RmEdge = False@ is implemented; @True@ panics.
matconf :: forall n. KnownNat n
        => RmEdge
        -> AdjacencyMatrix  n
        -> ProxemyMatrix    n
        -> ConfluenceMatrix n
matconf True  _a _p = panic "MatConf True: TODO but not needed for now"
matconf False a  p  = symmetry (toMatrix confmat)
  where
    degs    = sumWith Colonne identity a
    sumdeg  = VS.sum degs
    confmat = SMatrix.imap confluence p

    -- Only the strict upper triangle is computed here.
    confluence x y prox_y_x_length
      | x < y     = (prox_y_x_length - prox_y_x_infini)
                  / (prox_y_x_length + prox_y_x_infini)
      | otherwise = 0
      where
        prox_y_x_infini = (VS.!) degs x / sumdeg




-- | Make a dense matrix symmetric by copying its upper triangle onto the
-- lower one: entries with @row < col@ are kept, every other entry is read
-- from the transposed position of the input matrix.
symmetry :: KnownNat n => DenseMatrix n -> DenseMatrix n
symmetry m = DMatrix.imap mirror m
  where
    mirror row col upper
      | row < col = upper
      | otherwise = DMatrix.unsafeCoeff col row m

-- | Modularity matrix of a finite graph, built from its non-reflexive
-- adjacency matrix.
--
-- For every stored adjacency entry @(x, y)@ with @x < y@ the value is
-- @v - (rowSum x * colSum y) / (2 * ecount)@; other entries are 0 and are
-- then filled in by 'symmetry'.
--
-- NOTE(review): @ecount@ is the sum of all adjacency entries. If 'adjacent'
-- emits both directions of each edge, that sum is already twice the edge
-- count and the factor 2 may be redundant — confirm.
matmod :: forall n a b. KnownNat n => FiniteGraph n a b -> ModularityMatrix n
matmod fg = symmetry (toMatrix modmat)
  where
    a       = adjacent fg False
    sumRows = sumWith Ligne   identity a
    sumCols = sumWith Colonne identity a
    ecount  = sum (toMatrix a)
    modmat  = SMatrix.imap modularity a

    -- Only the strict upper triangle is computed here.
    modularity x y v
      | x < y     = v - ((VS.!) sumRows x * (VS.!) sumCols y) / (2 * ecount)
      | otherwise = 0

---------------------------------------------------------------
-- | Weighted edges @(node, node, weight)@ as returned by 'edges_confluence'.
type UnsortedEdges = [(Node, Node, Double)]
-- | Same representation, in the order produced by 'sort_edges'.
type SortedEdges   = [(Node, Node, Double)]

-- | Confluence of every stored entry of the adjacency matrix, as a list of
-- @(x, y, confluence)@ triplets.
--
-- For an entry @(x, y)@ with value @v@:
--
--   1. @v@ is subtracted from position @(x, y)@ of the adjacency matrix;
--   2. a unit row vector located at @y@ is multiplied @l@ times by that
--      matrix;
--   3. the coefficient at @x@ of the result is compared with the limit
--      value @(deg x - 1) / (sumdeg - 2)@ as @(p - q) / (p + q)@.
--
-- The graph and the transition matrix arguments are accepted but not used.
--
-- A 'SparseVector' has a single row, so the start vector is built at row 0
-- and the result is read at @(0, x)@. The numerator subtracts the limit
-- value; subtracting the walk value from itself would always give 0.
--
-- NOTE(review): only the @(x, y)@ entry is removed, not @(y, x)@, and the
-- walk runs on the modified adjacency matrix rather than on a transition
-- matrix — confirm both against the specification.
edges_confluence :: forall n a b. KnownNat n
                 => Length
                 -> FiniteGraph      n a b
                 -> AdjacencyMatrix  n
                 -> TransitionMatrix n
                 -> UnsortedEdges
edges_confluence l _fg am _tm = SMatrix.toList confmat
  where
    degs    = sumWith Colonne identity am
    sumdeg  = VS.sum degs
    confmat = SMatrix.imap edgeConfluence am

    edgeConfluence x y v =
      let am' = SMatrix.add am (SMatrix.fromList [(x, y, -v)])
          v'  = doProx l (sparseVectorFromList (Proxy @n) [(0, y, 1)]) am'

          prox_y_x_length = _unsafeCoeff 0 x v'
          prox_y_x_infini = ((VS.!) degs x - 1) / (sumdeg - 2)

       in (prox_y_x_length - prox_y_x_infini)
        / (prox_y_x_length + prox_y_x_infini)

    -- Multiply a row vector by the given matrix the requested number of times.
    doProx :: KnownNat n => Length -> SparseVector n -> TransitionMatrix n -> SparseVector n
    doProx steps start walk = foldl' (\acc _ -> SMatrix.mul acc walk) start [1 .. (steps :: Int)]


-- | Order edges by decreasing weight; edges sharing the same weight are
-- ordered among themselves by @x * n + y@.
--
-- TODO optimization
sort_edges :: Int
           -> UnsortedEdges
           -> SortedEdges
sort_edges n unsorted = List.concat tieBroken
  where
    weight :: forall a b c. (a,b,c) -> c
    weight (_,_,w) = w

    position (x,y,_) = x * n + y

    -- Ascending sort, then reversed to get the heaviest edges first.
    descending = List.reverse (List.sortOn weight unsorted)
    -- Runs of adjacent edges with equal weight.
    sameWeight = List.groupBy (\e1 e2 -> weight e1 == weight e2) descending
    tieBroken  = map (List.sortOn position) sameWeight


---------------------------------------------------------------

-- | Move node @y@ from module @modY@ into module @modX@ and rescore.
--
-- Both tuple arguments are @(module id, node)@. The result has:
--
--   * @parts@ with @y@ inserted into the set of @modX@ and deleted from the
--     set of @modY@;
--   * @index@ mapping @y@ to @modX@, the module it now belongs to;
--   * @score@ set to the sum of @f x' y'@ over every member @x'@ of the
--     updated @modX@ set and every member @y'@ of the updated @modY@ set.
--
-- The index and both part lookups use module ids, matching the keys of
-- @parts@. Recording @modY@ for the moved node, or looking the parts up by
-- node id, would contradict the move performed on @parts@.
--
-- NOTE(review): the previous score is ignored and an emptied @modY@ set is
-- kept in the map — confirm both are intended.
updateWith :: Clustering Node -> (Int -> Int -> Double) -> (Int,Int) -> (Int,Int) -> Clustering Node
updateWith (Clustering parts idx _) f (modX,_x) (modY,y) = Clustering parts' idx' score'
  where
    parts' = Dict.alter (del y) modY $ Dict.alter (add y) modX parts

    add n Nothing        = Just (Set.singleton n)
    add n (Just members) = Just (Set.insert n members)

    del _ Nothing        = Nothing
    del n (Just members) = Just (Set.delete n members)

    idx' = Dict.insert y modX idx

    px = fromMaybe Set.empty $ Dict.lookup modX parts'
    py = fromMaybe Set.empty $ Dict.lookup modY parts'
    score' = Prelude.sum [ f x'' y''
                 | x'' <- Set.toList px
                 , y'' <- Set.toList py
                 ]


-- | Try to merge along the edge @(x, y)@ and keep the result only if it
-- strictly improves the score.
--
-- The clustering is returned unchanged when @x == y@, when both nodes are
-- already in the same module, or when the candidate built by 'updateWith'
-- does not score higher than the current clustering.
--
-- 'updateWith' takes @(module id, node)@ pairs, so the tuples are built in
-- that order.
--
-- NOTE(review): a node missing from the index defaults to module 0, so two
-- unindexed nodes are treated as being in the same module and are never
-- merged — confirm how the index is meant to be seeded.
updateClustering :: Clustering Node -> (Int -> Int -> Double) -> Int -> Int -> Clustering Node
updateClustering c@(Clustering _ idx currentScore) f x y
  | x == y || modX == modY   = c
  | score c' > currentScore  = c'
  | otherwise                = c
  where
    modX = fromMaybe 0 $ Dict.lookup x idx
    modY = fromMaybe 0 $ Dict.lookup y idx
    c'   = updateWith c f (modX,x) (modY,y)

-- | Build a clustering by folding 'updateClustering' over the sorted edges,
-- starting from an empty clustering with score 0.
--
-- The scoring function given to 'updateClustering' reads the similarity
-- matrix at the pair of nodes it is asked about, so the score is computed
-- from the members of the parts rather than from the edge endpoints only.
make_clust_part :: KnownNat n
                => SortedEdges
                -> SimilarityMatrix n
                -> Clustering Node
make_clust_part se sm = foldl' step (Clustering Dict.empty Dict.empty 0) se
  where
    step c (e1,e2,_) = updateClustering c similarity e1 e2

    similarity x y = unsafeCoeff x y sm'

    -- Both constructors wrap the same dense matrix type.
    sm' = case sm of
      SimConf cm -> cm
      SimMod  mm -> mm


---------------------------------------------------------------
-- | Turn a strict clustering into an overlapping one.
--
-- Not implemented yet: calling it panics with an explicit message, in the
-- same way as the unimplemented branch of 'matconf'.
make_clust_over :: KnownNat n
                => SimilarityMatrix n
                -> StrictClustering   a
                -> OverlapsClustering a
make_clust_over _sm _strict = panic "make_clust_over: TODO, not implemented yet"

---------------------------------------------------------------
-- | Cluster a finite graph.
--
-- The reflexive adjacency matrix and its transition matrix are computed
-- once. The edges are then weighted by 'edges_confluence', ordered by
-- 'sort_edges' and folded into a clustering by 'make_clust_part', using the
-- similarity matrix selected by the 'Similarity' argument.
clusteringOptim :: forall n a b. KnownNat n
                => Length
                -> FiniteGraph n a b
                -> Similarity
                -> Clustering Node
clusteringOptim l fg@(FiniteGraph _) s = make_clust_part ranked similarity
  where
    reflexiveAdj = adjacent fg True
    walk         = transition reflexiveAdj

    ranked = sort_edges (natToInt @n)
                        (edges_confluence l fg reflexiveAdj walk)

    similarity = case s of
      Conf -> SimConf (matconf False reflexiveAdj (proxemie l walk))
      Mod  -> SimMod  (matmod fg)

