{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-|
Module      : Gargantext.Core.LinearAlgebra.Distributional
Description : The "distributional" algorithm, fast and slow implementations
Copyright   : (c) CNRS, 2017-Present
License     : AGPL + CECILL v3
Maintainer  : team@gargantext.org
Stability   : experimental
Portability : POSIX

-}

module Gargantext.Core.LinearAlgebra.Distributional (
    distributional
  , logDistributional2

  -- * Internals for testing
  , distributionalReferenceImplementation
  ) where

import Data.Massiv.Array (D, Matrix, Vector, Array, Ix3, U, Ix2 (..), IxN (..))
import Data.Massiv.Array qualified as A
import Gargantext.Core.LinearAlgebra.Operations
import Prelude

-- | @distributional m@ returns the distributional distance between every
-- pair of terms, as a matrix.  The argument @m@ is the matrix
-- \([n_{ij}]_{i,j}\) where \(n_{ij}\) is the cooccurrence count between
-- term \(i\) and term \(j\).
--
-- With \(mi_{ij} = (n_{ij} / n_{jj}) \cdot (n_{ij} / n_{ii})\), entry
-- \((i, j)\) of the result is
--
-- \[
--   \frac{\sum_{k \neq i, k \neq j} \min(mi_{ik}, mi_{jk})}
--        {\sum_{k \neq i, k \neq j} mi_{ik}}
-- \]
--
-- where every division goes through @termDivNanD@ (see
-- "Gargantext.Core.LinearAlgebra.Operations" for how it treats a zero
-- denominator).
--
-- NOTE(review): the example outputs below use an Accelerate-style
-- @Matrix (Z :. rows :. cols)@ notation rather than massiv's 'Show' output,
-- and @theMatrixInt@ is not defined in this module, so they are illustrative
-- rather than runnable doctests as written; confirm before relying on them.
--
-- Basic example with a 3x3 matrix:
--
-- >>> theMatrixInt 3
-- Matrix (Z :. 3 :. 3)
--   [ 7, 4, 0,
--     4, 5, 3,
--     0, 3, 4]
--
-- >>> distributional $ theMatrixInt 3
-- Matrix (Z :. 3 :. 3)
--   [ 1.0, 0.0, 0.9843749999999999,
--     0.0, 1.0,                0.0,
--     1.0, 0.0,                1.0]
--
-- Basic example with a 4x4 matrix:
--
-- >>> theMatrixInt 4
-- Matrix (Z :. 4 :. 4)
--   [ 4, 1, 2, 1,
--     1, 4, 0, 0,
--     2, 0, 3, 3,
--     1, 0, 3, 3]
--
-- >>> distributional $ theMatrixInt 4
-- Matrix (Z :. 4 :. 4)
--   [                  1.0,                   0.0, 0.5714285714285715, 0.8421052631578947,
--                      0.0,                   1.0,                1.0,                1.0,
--     8.333333333333333e-2,             4.6875e-2,                1.0,               0.25,
--       0.3333333333333333, 5.7692307692307696e-2,                1.0,                1.0]
--
-- /IMPORTANT/: the diagonal of the input is extracted and broadcast back to
-- an @n x n@ grid (with @n = dim m@), so the input has to be a square matrix;
-- otherwise this function fails at runtime.
distributional :: forall r e. ( A.Manifest r e
                              , A.Manifest r Int
                              , A.Unbox e
                              , A.Source r Int
                              , A.Size r
                              , Ord e
                              , Fractional e
                              , Num e
                              )
                              => Matrix r Int
                              -> Matrix U e
distributional cooc = A.computeP ratio
  where
    nTerms :: Int
    nTerms = dim cooc

    gridSz = A.Sz2 nTerms nTerms
    cubeSz = A.Sz3 nTerms nTerms nTerms

    -- The integer counts lifted into the element type, kept delayed.
    coocD :: Matrix D e
    coocD = A.map fromIntegral cooc

    -- Self-cooccurrences n_ii, read from a materialised copy of the counts.
    selfCooc :: Vector U e
    selfCooc = diag (A.compute coocD :: Matrix U e)

    -- The diagonal broadcast back over (i, j):
    --   diagJ[i, j] = n_jj   and   diagI[i, j] = n_ii.
    diagJ, diagI :: Matrix D e
    diagJ = A.backpermute' gridSz (\(_ :. j) -> j) selfCooc
    diagI = A.backpermute' gridSz (\(i :. _) -> i) selfCooc

    scaledJ, scaledI :: Matrix D e
    scaledJ = termDivNanD coocD diagJ
    scaledI = termDivNanD coocD diagI

    -- mi[i, j] = (n_ij / n_jj) * (n_ij / n_ii).  Each entry is read many
    -- times through the 3-D views below, so it is computed once into an
    -- unboxed array and then re-delayed.
    mi :: Matrix D e
    mi = A.delay (A.compute @U (scaledJ `mulD` scaledI))

    -- 3-D views over (i, j, k):
    --   cubeI[i, j, k] = mi[i, k]   and   cubeJ[i, j, k] = mi[j, k].
    cubeI, cubeJ :: Array D Ix3 e
    cubeI = A.backpermute' cubeSz (\(i :> _ :. k) -> i :. k) mi
    cubeJ = A.backpermute' cubeSz (\(_ :> j :. k) -> j :. k) mi

    -- Sum along the innermost axis k, leaving out the terms with k == i or
    -- k == j.  Excluded terms are multiplied by 0 rather than skipped.
    sumOverOtherTerms :: Array D Ix3 e -> Matrix D e
    sumOverOtherTerms = A.ifoldlWithin' 1 step 0
      where
        step :: Ix3 -> e -> e -> e
        step (i :> j :. k) acc x = acc + x * keep
          where
            keep = if k /= i && k /= j then 1 else 0

    ratio :: Matrix D e
    ratio = termDivNanD (sumOverOtherTerms (A.zipWith min cubeI cubeJ))
                        (sumOverOtherTerms cubeI)

-- | A slower reference implementation of 'distributional'.  Every
-- intermediate array is spelled out, including an explicit 0/1 mask over
-- @(i, j, k)@, so that it reads close to the mathematical definition and can
-- be used to check the correctness of the optimised version.
--
-- The same proviso about the shape of the matrix applies: the input has to be
-- square.
distributionalReferenceImplementation :: forall r e.
                                      ( A.Manifest r e
                                      , A.Unbox e
                                      , A.Source r Int
                                      , A.Size r
                                      , Ord e
                                      , Fractional e
                                      , Num e
                                      )
                                      => Matrix r Int
                                      -> Matrix r e
distributionalReferenceImplementation cooc =
    A.computeP (termDivNanD numer denom)
  where
    nTerms :: Int
    nTerms = dim cooc

    gridSz = A.Sz2 nTerms nTerms
    cubeSz = A.Sz3 nTerms nTerms nTerms

    coocD :: Matrix D e
    coocD = A.map fromIntegral cooc

    -- The diagonal n_ii of the input, taken from a materialised copy.
    selfCooc :: Vector U e
    selfCooc = diag coocU
      where
        coocU :: Matrix U e
        coocU = A.compute coocD

    -- The diagonal broadcast back to a square n x n grid indexed by (i, j):
    --   diagJ[i, j] = n_jj   and   diagI[i, j] = n_ii.
    diagJ :: Matrix D e
    diagJ = A.backpermute' gridSz (\(_ :. j) -> j) selfCooc

    diagI :: Matrix D e
    diagI = A.backpermute' gridSz (\(i :. _) -> i) selfCooc

    scaledJ :: Matrix D e
    scaledJ = termDivNanD coocD diagJ

    scaledI :: Matrix D e
    scaledI = termDivNanD coocD diagI

    -- mi[i, j] = (n_ij / n_jj) * (n_ij / n_ii), materialised once.
    mi :: Matrix D e
    mi = A.delay (A.compute @U (scaledJ `mulD` scaledI))

    -- Replicating mi directly along one axis avoids building a separate
    -- permuted copy.
    -- cubeI[i, j, k] = mi[i, k]; Accelerate-style equivalent:
    --   replicate (constant (Z :. All :. n :. All)) mi
    cubeI :: Array D Ix3 e
    cubeI = A.backpermute' cubeSz (\(i :> _ :. k) -> i :. k) mi

    -- cubeJ[i, j, k] = mi[j, k]; Accelerate-style equivalent:
    --   replicate (constant (Z :. n :. All :. All)) mi
    cubeJ :: Array D Ix3 e
    cubeJ = A.backpermute' cubeSz (\(_ :> j :. k) -> j :. k) mi

    cubeMin :: Array D Ix3 e
    cubeMin = A.zipWith min cubeI cubeJ

    -- exclMask[i, j, k] = 0 when k == i or k == j, and 1 otherwise.
    exclMask :: Array A.D Ix3 e
    exclMask = A.makeArrayR A.D A.Seq cubeSz $ \(i :> j :. k) ->
      if k == i || k == j then 0 else 1

    numer :: Matrix D e
    numer = sumRowsD (cubeMin `mulD` exclMask)

    denom :: Matrix D e
    denom = sumRowsD (cubeI `mulD` exclMask)


-- | Log-based distributional measure.
--
-- Runs @logDistributional'@ on the cooccurrence matrix @m@ (with the number
-- of terms taken from @dim m@), post-processes the result with @matMaxMini@
-- and then @diagNull@, and computes the final manifest array.
logDistributional2 :: (A.Manifest r e
                      , A.Unbox e
                      , A.Source r Int
                      , A.Shape r Ix2
                      , Num e
                      , Ord e
                      , A.Source r e
                      , Fractional e
                      , Floating e
                      )
                   => Matrix r Int
                   -> Matrix r e
logDistributional2 cooc =
    A.computeP . diagNull nTerms . matMaxMini $ logDistributional' nTerms cooc
  where
    nTerms = dim cooc

-- | Core of @logDistributional2@, before its @matMaxMini@ / @diagNull@
-- post-processing.
--
-- @logDistributional' n m@ takes the number of terms @n@ and the
-- cooccurrence matrix \(m = [n_{ij}]_{i,j}\), and computes:
--
-- * @total@, the sum of all entries of @m@;
--
-- * \(s\), the sums produced by @sumRowsD@ of @m@ minus its diagonal part
--   (presumably \(s_i = \sum_{j \neq i} n_{ij}\));
--
-- * the scaled ratio \(x_{ij} = (n_{ij} / (s_i s_j)) \cdot total\) (division
--   via @safeDiv@), mapped to 0 when it is below 1 and to \(\log x_{ij}\)
--   otherwise (so in particular 0 whenever \(n_{ij} = 0\)), then multiplied
--   element-wise by @matrixEye n@ to give @mi@;
--
-- and returns @termDivNan (sumMin_go n mi) (sumM_go n mi)@.
--
-- NOTE(review): @matrixIdentity n@ is assumed to be the identity matrix and
-- @matrixEye n@ to zero the diagonal (so \(mi_{ii} = 0\)); both are defined
-- outside this module, so confirm.
logDistributional' :: forall r e.
                   ( A.Manifest r e
                   , A.Unbox e
                   , A.Source r Int
                   , A.Shape r Ix2
                   , Num e
                   , Ord e
                   , A.Source r e
                   , Fractional e
                   , Floating e
                   )
                   => Int
                   -> Matrix r Int
                   -> Matrix r e
logDistributional' n cooc = termDivNan sumMin sumM
  where
    counts :: Matrix U e
    counts = A.compute (A.map fromIntegral cooc)

    total :: e
    total = A.sum counts

    -- counts multiplied element-wise by matrixIdentity n, i.e. only the
    -- diagonal part is kept.
    diagPart :: Matrix D e
    diagPart = counts `mulD` matrixIdentity n

    rowSums :: Vector U e
    rowSums = A.compute (sumRowsD (counts `subD` diagPart))

    -- sI[i, j] = s_i (constant along each row, i.e. s copied into every
    -- column) and sJ[i, j] = s_j (s copied into every row).
    sI, sJ :: Matrix D e
    sI = A.backpermute' (A.Sz2 n n) (\(i :. _) -> i) rowSums
    sJ = A.backpermute' (A.Sz2 n n) (\(_ :. j) -> j) rowSums

    -- marginals[i, j] = s_i * s_j, the outer product of s with itself.
    marginals :: Matrix D e
    marginals = sI `mulD` sJ

    -- log of total * n_ij / (s_i * s_j), clamped to 0 below 1.
    logScore :: e -> e -> e
    logScore nij sisj
      | scaled < 1 = 0
      | otherwise  = log scaled
      where
        scaled = (nij `safeDiv` sisj) * total

    logScores :: Matrix D e
    logScores = A.zipWith logScore counts marginals

    mi :: Matrix U e
    mi = A.computeP (matrixEye n `mulD` logScores)

    sumMin, sumM :: Matrix U e
    sumMin = sumMin_go n mi
    sumM   = sumM_go n mi
