{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-|
Module      : Gargantext.Core.LinearAlgebra.Operations
Description : Operations on matrices using massiv
Copyright   : (c) CNRS, 2017-Present
License     : AGPL + CECILL v3
Maintainer  : team@gargantext.org
Stability   : experimental
Portability : POSIX

-}

module Gargantext.Core.LinearAlgebra.Operations (
  -- * Conversion functions
    accelerate2MassivMatrix
  , accelerate2Massiv3DMatrix
  , massiv2AccelerateMatrix
  , massiv2AccelerateVector

  -- * Operations on matrices
  , (.*)
  , (.-)
  , diag
  , termDivNan
  , sumRows
  , dim
  , matrixEye
  , matrixIdentity
  , diagNull

  -- * Operations on delayed arrays
  , diagD
  , subD
  , mulD
  , termDivNanD
  , sumRowsD
  , safeDiv

  -- * Internals for testing
  , sumRowsReferenceImplementation
  , matMaxMini
  , sumM_go
  , sumMin_go
  ) where

import Data.Array.Accelerate qualified as Acc
import Data.List.Split qualified as Split
import Data.Massiv.Array (D, Matrix, Vector, Array)
import Data.Massiv.Array qualified as A
import Prelude
import Protolude.Safe (headMay)
import Data.Monoid

-- | Converts an accelerate matrix into a Massiv matrix.
--
-- The flat element list from 'Acc.toList' is cut into rows of @c@ elements,
-- where @c@ is the column count of the accelerate shape.
--
-- Shapes with zero rows or zero columns are built directly with their exact
-- shape: @Split.chunksOf 0@ yields an infinite list of empty chunks, so a
-- column-less matrix would otherwise never finish converting.
accelerate2MassivMatrix :: (A.Unbox a, Acc.Elt a) => Acc.Matrix a -> Matrix A.U a
accelerate2MassivMatrix m
  | r == 0 || c == 0 = A.resize' (A.Sz2 r c) (A.fromList A.Par [])
  | otherwise        = A.fromLists' @A.U A.Par $ Split.chunksOf c (Acc.toList m)
  where
    (Acc.Z Acc.:. r Acc.:. c) = Acc.arrayShape m

-- | Converts a massiv matrix into an accelerate matrix.
--
-- The shape is read back from the nested row lists:
--
-- * the row count is the number of rows;
-- * the column count is the length of the first row, or 0 when there are
--   no rows.
--
-- The rows are then concatenated in order to form the flat element list.
massiv2AccelerateMatrix :: (Acc.Elt a, A.Source r a) => Matrix r a -> Acc.Matrix a
massiv2AccelerateMatrix m = Acc.fromList (Acc.Z Acc.:. nRows Acc.:. nCols) (concat rows)
  where
    rows  = A.toLists2 m
    nRows = Prelude.length rows
    nCols = maybe 0 Prelude.length (headMay rows)

-- | Converts a massiv vector into an accelerate one.
--
-- The accelerate shape is the number of elements produced by 'A.toList'.
massiv2AccelerateVector :: (A.Source r a, Acc.Elt a) => A.Vector r a -> Acc.Vector a
massiv2AccelerateVector v = Acc.fromList (Acc.Z Acc.:. count) elems
  where
    elems = A.toList v
    count = Prelude.length elems

-- | Converts a rank-3 accelerate array into a massiv array of the same shape.
--
-- The flat element list from 'Acc.toList' is split into outer slices of
-- @c * z@ elements, and each slice into chunks of @z@ elements.
--
-- Shapes with any zero extent are built directly with their exact shape,
-- because @Split.chunksOf 0@ yields an infinite list of empty chunks and the
-- conversion would never terminate.
accelerate2Massiv3DMatrix :: (A.Unbox e, Acc.Elt e, A.Manifest r e)
                          => Acc.Array (Acc.Z Acc.:. Int Acc.:. Int Acc.:. Int) e
                          -> A.Array r A.Ix3 e
accelerate2Massiv3DMatrix m
  | r == 0 || c == 0 || z == 0 = A.resize' (A.Sz3 r c z) (A.fromList A.Par [])
  | otherwise =
      A.fromLists' A.Par $ map (Split.chunksOf z) $ Split.chunksOf (c * z) (Acc.toList m)
  where
    (Acc.Z Acc.:. r Acc.:. c Acc.:. z) = Acc.arrayShape m



-- | Extracts the main diagonal of a matrix as an unboxed vector.
--
-- The result has @min rows cols@ elements, so a non-square input (with more
-- rows than columns) no longer indexes past its last column.
diag :: (A.Unbox e, A.Manifest r e, A.Source r e, Num e) => Matrix r e -> Vector A.U e
diag matrix =
  let (A.Sz2 rows cols) = A.size matrix
      newSize = A.Sz1 (min rows cols)
  in A.makeArrayR A.U A.Seq newSize $ \(A.Ix1 i) -> matrix A.! A.Ix2 i i

-- | Delayed version of 'diag': the main diagonal of a matrix as a delayed
-- vector.
--
-- The vector has @min rows cols@ elements, so non-square inputs never
-- back-permute to an out-of-range @(i, i)@ index.
diagD :: (A.Source r e, A.Size r) => Matrix r e -> Vector A.D e
diagD matrix =
  let (A.Sz2 rows cols) = A.size matrix
  in A.backpermute' (A.Sz1 (min rows cols)) (\i -> i A.:. i) matrix

-- | Term-by-term division of two matrices, computed into a manifest array.
--
-- Wherever the divisor is 0 the result is 0 (see 'safeDiv'). For
-- floating-point types, plain '/' would instead give NaN or an infinity.
termDivNan :: (A.Manifest r3 a, A.Source r1 a, A.Source r2 a, Eq a, Fractional a)
           => Matrix r1 a
           -> Matrix r2 a
           -> Matrix r3 a
termDivNan m1 m2 = A.compute (termDivNanD m1 m2)

-- | Delayed term-by-term division; every pair of cells goes through
-- 'safeDiv', so a zero divisor yields 0.
termDivNanD :: (A.Source r1 a, A.Source r2 a, Eq a, Fractional a)
            => Matrix r1 a
            -> Matrix r2 a
            -> Matrix D a
termDivNanD = A.zipWith safeDiv

-- | Division that returns 0 when the divisor is 0 instead of dividing.
safeDiv :: (Eq a, Fractional a) => a -> a -> a
safeDiv numerator denominator
  | denominator == 0 = 0
  | otherwise        = numerator / denominator
{-# INLINE safeDiv #-}

-- | Manifest version of 'sumRowsD'. It sums away dimension 1 and computes the
-- result into the same representation @r@ as the input.
sumRows :: ( A.Index (A.Lower ix)
           , A.Index ix
           , A.Source r e
           , A.Manifest r e
           , A.Strategy r
           , A.Size r
           , Num e
           ) => Array r ix e
             -> Array r (A.Lower ix) e
sumRows arr = A.compute (sumRowsD arr)

-- | Delayed sum along dimension 1 (massiv's dimension numbering), which
-- lowers the rank by one.
--
-- NOTE(review): for an 'A.Ix3' input this is expected to agree with
-- 'sumRowsReferenceImplementation', which sums the innermost index. That
-- relies on massiv numbering the innermost dimension as 1; confirm against
-- the tests.
sumRowsD :: ( A.Index (A.Lower ix)
            , A.Index ix
            , A.Source r e
            , Num e
            ) => Array r ix e
              -> Array D (A.Lower ix) e
sumRowsD matrix = A.map getSum (A.foldlWithin' 1 step mempty matrix)
  where
    -- Accumulate in the 'Sum' monoid, then unwrap with 'getSum'.
    step acc x = acc <> Sum x

-- | Straightforward reference for summing a rank-3 array over its last
-- index. It is kept for testing the fold-based 'sumRowsD' / 'sumRows'.
--
-- Entry @(i, j)@ of the result is the sum of @matrix ! (i :> j :. k)@ over
-- every @k@ in the last dimension. The result keeps the computation strategy
-- of the input.
sumRowsReferenceImplementation :: ( A.Load r A.Ix2 e
                                  , A.Source r e
                                  , A.Manifest r e
                                  , A.Strategy r
                                  , A.Size r
                                  , Num e
                                  ) => Array r A.Ix3 e
                                    -> Array r A.Ix2 e
sumRowsReferenceImplementation matrix =
  A.makeArray (A.getComp matrix) (A.Sz2 rows cols) cellSum
  where
    (A.Sz3 rows cols depth) = A.size matrix
    -- 1-D view of the last axis at position (i, j).
    fiber i j = A.backpermute' (A.Sz1 depth) (\k -> i A.:> j A.:. k) matrix
    cellSum (i A.:. j) = A.sum (fiber i j)

-- | Cell-by-cell multiplication (Hadamard product), computed into a
-- manifest array. See 'mulD' for the delayed version.
(.*) :: (A.Manifest r3 a, A.Source r1 a, A.Source r2 a, A.Index ix, Num a)
     => Array r1 ix a
     -> Array r2 ix a
     -> Array r3 ix a
m1 .* m2 = A.compute (mulD m1 m2)

-- | Delayed cell-by-cell multiplication of two arrays.
mulD :: (A.Source r1 a, A.Source r2 a, A.Index ix, Num a)
     => Array r1 ix a
     -> Array r2 ix a
     -> Array D ix a
mulD = A.zipWith (*)

-- | Cell-by-cell subtraction, computed into a manifest array. See 'subD'
-- for the delayed version.
(.-) :: (A.Manifest r3 a, A.Source r1 a, A.Source r2 a, A.Index ix, Num a)
     => Array r1 ix a
     -> Array r2 ix a
     -> Array r3 ix a
m1 .- m2 = A.compute (subD m1 m2)

-- | Delayed cell-by-cell subtraction: each result cell is the first array's
-- cell minus the second's.
subD :: (A.Source r1 a, A.Source r2 a, A.Index ix, Num a)
     => Array r1 ix a
     -> Array r2 ix a
     -> Array D ix a
subD = A.zipWith (-)


-- | Get the dimension of a /square/ matrix.
--
-- Only the column count is read, so for a non-square matrix this returns
-- the number of columns.
dim :: A.Size r => Matrix r a -> Int
dim m =
  let (A.Sz2 _rows cols) = A.size m
  in cols

-- | Keeps only the entries strictly greater than the largest row minimum
-- and replaces every other entry with 0.
matMaxMini :: (A.Unbox a, A.Source r a, Ord a, Num a, A.Shape r A.Ix2) => Matrix r a -> Matrix A.U a
matMaxMini m = A.compute (A.map keepAbove m)
  where
    -- Minimum of each row, taken from the nested-list view of the matrix.
    rowMinima = [ minimum row | row <- A.toLists m ]
    -- The largest of those row minima.
    threshold = maximum rowMinima
    keepAbove x
      | x > threshold = x
      | otherwise     = 0

-- | Builds an @n × n@ matrix whose entry @(i, j)@ is the sum of
-- @mi ! (i :. k)@ over every @k@ in @[0, n)@ except @k == i@ and @k == j@.
--
-- The excluded positions contribute 0, and the terms are summed in
-- increasing @k@.
sumM_go :: (A.Unbox a, A.Manifest r a, Num a, A.Load r A.Ix2 a) => Int -> Matrix r a -> Matrix A.U a
sumM_go n mi = A.makeArrayR A.U A.Seq (A.Sz2 n n) cell
  where
    cell (i A.:. j) = Prelude.sum (map (term i j) [0 .. n - 1])
    term i j k
      | k == i || k == j = 0
      | otherwise        = mi A.! (i A.:. k)

-- | Builds an @n × n@ matrix whose entry @(i, j)@ is the sum, over every
-- @k@ in @[0, n)@ other than @i@ and @j@, of
-- @min (mi ! (i :. k)) (mi ! (j :. k))@.
--
-- The excluded positions contribute 0, and the terms are summed in
-- increasing @k@.
sumMin_go :: (A.Unbox a, A.Manifest r a, Num a, Ord a, A.Load r A.Ix2 a) => Int -> Matrix r a -> Matrix A.U a
sumMin_go n mi = A.makeArrayR A.U A.Seq (A.Sz2 n n) cell
  where
    cell (i A.:. j) = Prelude.sum (map (term i j) [0 .. n - 1])
    term i j k
      | k == i || k == j = 0
      | otherwise        = min (mi A.! (i A.:. k)) (mi A.! (j A.:. k))

-- | An @n × n@ matrix with 0 on the main diagonal and 1 everywhere else.
--
-- Despite the name, this is /not/ the identity matrix (that is
-- 'matrixIdentity'). It is the mask that 'diagNull' multiplies with to zero
-- out a diagonal.
matrixEye :: (A.Unbox e, Num e) => Int -> Matrix A.U e
matrixEye n = A.makeArrayR A.U A.Seq (A.Sz2 n n) offDiagonal
  where
    offDiagonal (i A.:. j)
      | i == j    = 0
      | otherwise = 1
{-# INLINE matrixEye #-}
{-# SPECIALIZE matrixEye :: Int -> Matrix A.U Double #-}

-- | The @n × n@ identity matrix: 1 on the main diagonal, 0 elsewhere.
matrixIdentity :: (A.Unbox e, Num e) => Int -> Matrix A.U e
matrixIdentity n = A.makeArrayR A.U A.Seq (A.Sz2 n n) onDiagonal
  where
    onDiagonal (i A.:. j)
      | i /= j    = 0
      | otherwise = 1
{-# INLINE matrixIdentity #-}
{-# SPECIALIZE matrixIdentity :: Int -> Matrix A.U Double #-}

-- | Zeroes the main diagonal of a matrix by multiplying it cell by cell
-- with 'matrixEye' @n@ (0 on the diagonal, 1 elsewhere).
--
-- NOTE(review): the diagonal is cleared by multiplication, not replacement.
-- For floating-point types, a NaN or infinite diagonal entry therefore
-- becomes NaN rather than 0.
diagNull :: (A.Unbox e, A.Source r e, Num e) => Int -> Matrix r e -> Matrix A.U e
diagNull n m = A.compute masked
  where
    mask   = matrixEye n
    masked = A.zipWith (*) m mask
