{-| Module      : Gargantext.Core.Viz.Graph.ProxemyOptim
Description : Proxemy
Copyright   : (c) CNRS, 2017-Present
License     : AGPL + CECILL v3
Maintainer  : team@gargantext.org
Stability   : experimental
Portability : POSIX

(/!\ To be published soon, confidential document for now. After
publication it will be integrated to the backend with usual license --
see above)

Article: /Confluence for Graph Clustering/, by B. Gaume, A. Delanoë and A. Mestanogullari

-}

{-# LANGUAGE FlexibleContexts
           , DataKinds
           , GADTs
           , KindSignatures
           , ScopedTypeVariables
           , StandaloneDeriving
           , TypeOperators
           , TypeApplications
           , NoImplicitPrelude
           , RankNTypes
           , MultiParamTypeClasses
           , BangPatterns
#-}

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TupleSections #-}
module Graph.BAC.ProxemyOptim
  where

import Data.IntMap (IntMap)
import Data.Maybe (isJust, fromJust)
import Data.Proxy (Proxy(Proxy))
import Data.Reflection
import Data.Semigroup
import GHC.TypeLits (KnownNat, Nat, SomeNat(SomeNat), type(+), natVal, sameNat, someNatVal)
import Graph.FGL
import Graph.Types
import Prelude (String, readLn, error, id)
import Protolude hiding (sum, natVal)
import qualified Data.Graph.Inductive              as DGI
import qualified Data.Graph.Inductive.PatriciaTree as DGIP
import qualified Data.List                         as List
import qualified Data.Vector                       as V
import qualified Data.Vector.Storable              as VS
import qualified Data.Vector.Unboxed               as VU
import qualified Prelude                           as Prelude
import qualified Data.Set                          as Set

import qualified Data.Matrix.Sparse.Static         as SMatrix
import qualified Numeric.LinearAlgebra.Static      as DMatrix
import qualified Data.Vector.Unboxed as UV
import qualified Numeric.LinearAlgebra.Devel       as DMatrix
import qualified Data.Vector.Sparse.Static         as SVector
import qualified Data.IntMap.Strict as Dict
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector.Unboxed.Mutable as MVU
import qualified Data.IntSet as IntSet
import qualified Data.IntMap.Strict as IntMap
import qualified Data.Map.Strict as Map

----------------------------------------------------------------

-- | Debugging helper: emits (via 'trace') the minimum, the maximum and the
-- number of distinct values of @xs@, tagged with @label@, then returns @a@.
--
-- For an empty list the reported bounds are @(maxBound, minBound)@.
traceMaxIndices :: forall a t. (Bounded t, Ord t, Show t) => String -> [t] -> a -> a
traceMaxIndices label xs a = trace message a
  where
    message = "[" ++ label ++ "] (min, max, # of ints) = " ++ show (lo, hi, Set.size seen)
    (lo, hi, seen) = foldl' step (maxBound :: t, minBound :: t, Set.empty) xs
    step (curLo, curHi, acc) x = (min curLo x, max curHi x, Set.insert x acc)

-- | Debugging helper: traces the range of node indices appearing as
-- endpoints of the keys of an adjacency map, then returns its last argument.
traceAdjMapIndices :: String -> Map (Int, Int) x -> a -> a
traceAdjMapIndices label adjacency = traceMaxIndices label endpoints
  where
    endpoints = concatMap (\(src, dst) -> [src, dst]) (Map.keys adjacency)

-- | Debugging helper: traces the range of the keys of @m@ and then the range
-- of its values (both tagged with @s@), and returns @a@.
traceDicoIndices :: (Ord x, Show x, Bounded x) => String -> IntMap x -> a -> a
traceDicoIndices s m a = traceKeys (traceValues a)
  where
    traceKeys   = traceMaxIndices (s ++ " dico keys => ") (Dict.keys m)
    traceValues = traceMaxIndices (s ++ " dico vals => ") (Dict.elems m)

-- | Cluster a weighted adjacency map with fixed settings: random walks of
-- length 3, beta = 0 and no garbage-collection pass.
--
-- The distinct node labels are renumbered to dense indices @0 .. k-1@ to
-- build the graph. The resulting cluster assignment is mapped back to the
-- original labels. Edge weights are stored in the graph, but 'graphMatrix'
-- puts 1.0 on every entry, so they do not influence the result.
defaultClustering :: Map (Int, Int) Double -> [ClusterNode]
defaultClustering adjmap = withG graph $ \fg ->
  case clusteringOptim walkLength fg labelToId beta runGC of
    Clust _ assignment _ ->
      [ ClusterNode (idToLabel Dict.! node) cluster
      | (node, cluster) <- Dict.toList assignment
      ]
  where
    runGC      = False
    beta       = 0.0
    walkLength = 3
    -- (dense index, label) for every distinct label, in ascending label order
    numbered   = zip [0 ..] . Set.toList . Set.fromList $
      concatMap (\(a, b) -> [a, b]) (Map.keys adjmap)
    idToLabel  = Dict.fromList numbered
    labelToId  = Dict.fromList [ (lbl, i) | (i, lbl) <- numbered ]
    graph      = DGI.mkGraph numbered
      [ (labelToId Dict.! a, labelToId Dict.! b, w) | ((a, b), w) <- Map.toList adjmap ]

{-# INLINE clusteringOptim #-}
-- | Full clustering pipeline on a 'FiniteGraph'.
--
-- 1. Build the reflexive adjacency matrix.
-- 2. Compute the transition and proxemy matrices from it.
-- 3. Score and sort every edge by confluence.
-- 4. Run the greedy clustering.
--
-- The cluster -> nodes map is also inverted into a node -> cluster index.
clusteringOptim :: forall n a b. (KnownNat n, Ord a, Show a, Bounded a)
                => Length            -- ^ length of the random walks
                -> FiniteGraph n a b -- ^ graph to compute clusters for
                -> Dict a            -- ^ not used by this function
                -> Double            -- ^ beta
                -> Bool              -- ^ True = run GC, False = don't
                -> Clust a
clusteringOptim l fg@(FiniteGraph _) _dico beta gc =
  case runClustering gc beta adjacency proxemy sortedEdges of
    (clusters, total) -> Clust clusters (invert clusters) total
  where
    adjacency   = graphMatrix fg True
    transitions = transition adjacency
    proxemy     = proxemie l transitions
    sortedEdges = sort_edges (natToInt @n) (edges_confluence l fg adjacency transitions)
    -- cluster id -> set of nodes  ==>  node -> cluster id
    invert = Dict.foldMapWithKey
      (\clusterId members ->
         Dict.fromList [ (node, clusterId) | node <- IntSet.toList members ])

-- | Sparse adjacency matrix of a graph, with 1.0 on every stored entry.
--
-- The entry list is:
--
-- * one @(v, v)@ per node when @reflexive@ is set;
-- * followed by one @(v, w)@ for every @w@ returned by @neighbors g v@.
--
-- NOTE(review): if 'neighbors' lists a node several times, the outcome
-- depends on how 'SMatrix.fromList' handles duplicate coordinates.
graphMatrix
  :: forall (n :: Nat) a b.
     KnownNat n
  => FiniteGraph n a b -> Bool -> SMatrix.Matrix n n Double
graphMatrix (FiniteGraph g) reflexive = SMatrix.fromList (selfLoops ++ edgeEntries)
  where
    selfLoops
      | reflexive = [ (v, v, 1.0) | v <- nodes g ]
      | otherwise = []
    edgeEntries = [ (v, w, 1.0) | v <- nodes g, w <- neighbors g v ]

-- | Transition matrix: every stored entry @(i, j)@ of @m@ is replaced by
-- @1 / nnzCol m j@, the inverse of the number of stored entries in its
-- column. The sparsity pattern is unchanged.
transition
  :: KnownNat n => SMatrix.Matrix n n Double -> SMatrix.Matrix n n Double
transition m = SMatrix.imap inverseColumnCount m
  where
    inverseColumnCount _ col _ = 1 / fromIntegral (SMatrix.nnzCol m col)

-- | Which similarity a clustering is based on. Presumably 'Conf' stands for
-- confluence and 'Mod' for modularity, mirroring the 'SimConf' / 'SimMod'
-- constructors of 'SimilarityMatrix' (TODO confirm). NOTE(review): not
-- referenced elsewhere in this module.
data Similarity = Conf | Mod
-- | Length of the random walks, i.e. number of transition steps.
type Length      = Int

-- | A graph tagged with a type-level natural @n@.
--
-- The intended invariant is that @n@ is the number of nodes of the wrapped
-- graph. The constructor does not enforce it; 'withG' builds values with
-- @n@ set that way.
data FiniteGraph (n :: Nat) a b = FiniteGraph (Graph a b)

-- | Whether self-loops are added. NOTE(review): not referenced elsewhere in
-- this module ('graphMatrix' takes a plain 'Bool' for this).
type IsReflexive = Bool
-- | Presumably a function selecting the neighbours of a node (TODO confirm).
-- NOTE(review): not referenced elsewhere in this module.
type NeighborsFilter a b = DGI.Gr a b -> Node -> [Node]
-- | First argument of 'matconf'; only 'False' is implemented there.
type RmEdge      = Bool
---------------------------------------------------------------
-- Data Structure
--
-- Type-level-sized aliases; @n@ is the size shared with the 'FiniteGraph'.

-- Sparse vector (Data.Vector.Sparse.Static) of n Doubles.
type VectorS          n = SVector.V n Double
-- Dense 1 x n matrix (a row vector).
type VectorD          n = DMatrix.L 1 n

-- Sparse n x n matrices. 'MatrixS' is defined outside this module; the
-- code here uses these aliases interchangeably with @SMatrix.Matrix n n Double@.
type AdjacencyMatrix  n = MatrixS n
type TransitionMatrix n = MatrixS n
type ProxemyMatrix    n = MatrixS n

-- Dense n x n matrices. 'MatrixD' is defined outside this module; 'matconf'
-- builds one with the 'DMatrix.L' constructor.
type ConfluenceMatrix n = MatrixD n
type ModularityMatrix n = MatrixD n

-- | A dense similarity matrix, tagged with the measure it holds.
data SimilarityMatrix n = SimConf !(ConfluenceMatrix n)
                        | SimMod  !(ModularityMatrix n)

-- | Both constructors render the same way: the underlying hmatrix matrix,
-- without the static-size wrapper. Unwrapping both cases keeps the output
-- format independent of which measure is stored.
instance KnownNat n => Show (SimilarityMatrix n) where
  show (SimConf (DMatrix.L (DMatrix.Dim (DMatrix.Dim m)))) = show m
  show (SimMod  (DMatrix.L (DMatrix.Dim (DMatrix.Dim m)))) = show m


---
-- | Proxemy matrix: @tm@ after @l - 1@ successive applications of
-- @SMatrix.mul tm@, i.e. the @l@-th power of the transition matrix.
-- Any @l <= 1@ returns @tm@ unchanged. The matrix is forced first.
proxemie :: KnownNat n
         => Length
         -> SMatrix.Matrix n n Double
         -> ProxemyMatrix n
proxemie l !tm
  | l <= 1    = tm
  | otherwise = powers Prelude.!! (l - 1)
  where
    powers = iterate (SMatrix.mul tm) tm

---------------------------------------------------------------
-- | Dense, symmetric confluence matrix of a graph.
--
-- Notation:
--
-- * @P@ is the densified proxemy @p@;
-- * @d_x@ is the number of non-zeros in column @x@ of @a@;
-- * @S@ is the number of non-zeros of @a@.
--
-- For every pair @x < y@:
--
-- > conf(x, y) = conf(y, x) = (P(x,y) - d_x/S) / (P(x,y) + d_x/S)
--
-- The diagonal stays at 0. As in 'computeConfluences', a zero @S@ makes the
-- @d_x/S@ term 0, and a zero denominator makes the confluence 0, so the
-- matrix never contains NaN or Infinity.
--
-- Only the @False@ 'RmEdge' variant is implemented; @True@ panics.
matconf :: forall n. KnownNat n
        => RmEdge
        -> SMatrix.Matrix n n Double
        -> ProxemyMatrix    n
        -> ConfluenceMatrix n
matconf False a p = seq a $ seq p $ confmat
  where
    vcount  = natToInt @n
    sumdeg  = fromIntegral (SMatrix.nonZeros a)
    p' = SMatrix.densify p
    confmat = DMatrix.L . DMatrix.Dim . DMatrix.Dim $ DMatrix.runSTMatrix $ do
      m <- DMatrix.newMatrix 0 vcount vcount
      forM_ [0..(vcount-1)] $ \x -> do
        let deg_x = fromIntegral (SMatrix.nnzCol a x)
            -- does not depend on y: computed once per row
            prox_y_x_infini = if sumdeg == 0 then 0 else deg_x / sumdeg
        forM_ [(x+1)..(vcount-1)] $ \y -> do
            let prox_y_x_length = p' `DMatrix.at` (x,y)
                denominator = prox_y_x_length + prox_y_x_infini
                conf | denominator == 0 = 0
                     | otherwise = (prox_y_x_length - prox_y_x_infini) / denominator
            DMatrix.unsafeWriteMatrix m x y conf
            DMatrix.unsafeWriteMatrix m y x conf
      return m
matconf True _a _p = panic "MatConf True: TODO but not needed for now"

-- | Pairwise score of nodes @x@ and @y@ used when merging clusters.
--
-- Notation:
--
-- * @d_v@ is the number of non-zeros in column @v@ of @adj@;
-- * @S@ is the total number of non-zeros of @adj@;
-- * @E2 = S - n@;
-- * @P@ is @prox@.
--
-- The score is built from:
--
-- > conf = (P(x,y) - d_x/S) / (P(x,y) + d_x/S)
-- > adj(x,y) > 0 :  conf + (1 - (d_x-1)*(d_y-1)/E2)
-- > otherwise    :  conf - (d_x-1)*(d_y-1)/E2 - beta*(1-conf)
--
-- No guard against zero denominators is applied.
confAt
  :: forall (n :: Nat).
     KnownNat n
  => Double -- ^ beta
  -> SMatrix.Matrix n n Double -- ^ adjacency
  -> SMatrix.Matrix n n Double -- ^ proxemie
  -> Int -> Int -> Double
confAt beta adj prox x y
  | adj `SMatrix.at` (x, y) > 0 = conf + (1 - degProduct / twiceEdges)
  | otherwise                   = conf - degProduct / twiceEdges - beta * (1 - conf)
  where
    degOf v    = fromIntegral (SMatrix.nnzCol adj v)
    degX       = degOf x
    degY       = degOf y
    totalNnz   = fromIntegral (SMatrix.nonZeros adj)
    -- S - n removes the diagonal of the (reflexive) adjacency matrix, which
    -- leaves 2 * |E| when every diagonal entry is present
    twiceEdges = totalNnz - fromIntegral (natVal (Proxy :: Proxy n))
    degProduct = (degX - 1) * (degY - 1)
    pLength    = prox `SMatrix.at` (x, y)
    pInfinite  = degX / totalNnz
    conf       = (pLength - pInfinite) / (pLength + pInfinite)

---------------------------------------------------------------
-- | Scored edges @(x, y, score)@, in no particular order.
type UnsortedEdges = [(Node, Node, Double)]
-- | Same representation as 'UnsortedEdges'. By convention it is ordered as
-- 'sort_edges' produces it: decreasing score, ties broken by @x*n + y@.
type SortedEdges   = [(Node, Node, Double)]

-- | Short alias for 'Edge'. NOTE(review): not referenced elsewhere in this
-- module.
type X = Edge

-- | Compute the confluence of every edge of a list. Results are keyed by the
-- edge as given, with its original labels.
--
-- Labels are first renumbered to dense indices @0 .. k-1@, where @k@ is the
-- number of distinct labels, and the type-level matrix size is @k@. Sizing
-- by @maximum label + 1@ instead would crash on an empty edge list (the
-- maximum of nothing is 'minBound') and on negative labels, and would index
-- out of bounds whenever the largest label is below @k - 1@.
--
-- For an edge @(x, y)@:
--
-- 1. A unit vector at @y@ is multiplied @l@ times by a copy of the transition
--    matrix. In that copy, columns @x@ and @y@ are changed with
--    'SMatrix.withColChangeExcept' (value @1 / (deg - 1)@).
-- 2. The value reaching @x@ is compared with
--    @(deg x - 1) / (nnz - 2)@ through @(P - Q) / (P + Q)@.
--
-- Zero denominators yield 0.
computeConfluences
  :: Length -- ^ length of the random walk
  -> [(Int, Int)] -- ^ list of edges
  -> Bool -- ^ reflexive?
  -> Map (Int, Int) Double
computeConfluences l edges reflexive =
  reifyNat (fromIntegral nodeCount) $ \(Proxy :: Proxy n) ->
    let
      xs :: [(Int, Int, Double)]
      xs =
        concatMap (\(i, j) -> [(i, j, 1.0), (j, i, 1.0)]) edges' ++
        (if reflexive
          then [ (i, i, 1.0) | i <- [0..(nodeCount - 1)] ]
          else []
        )
      am :: SMatrix.Matrix n n Double
      am = SMatrix.fromList xs
      tm = transition am
      sumdeg_m2 = fromIntegral (SMatrix.nonZeros am - 2)
      go x y =
        let
          !deg_x_m1 = fromIntegral (SMatrix.nnzCol am x - 1)
          !deg_y_m1 = fromIntegral (SMatrix.nnzCol am y - 1)
          v         = SMatrix.asColumn (SVector.singleton y 1)
          v'        =
            SMatrix.withColChangeExcept x (1/deg_x_m1) y tm $ \tm' ->
              SMatrix.withColChangeExcept y (1/deg_y_m1) x tm' $ \tm'' ->
                iterate (SMatrix.mul tm'') v Prelude.!! l
          prox_y_x_length = SMatrix.extractCol v' 0 SVector.! x
          prox_y_x_infini = if sumdeg_m2 == 0 then 0 else deg_x_m1 / sumdeg_m2
          denominator = (prox_y_x_length + prox_y_x_infini)
        in
          if denominator == 0
             then 0
             else (prox_y_x_length - prox_y_x_infini) / denominator
    in
      -- each original edge paired with its renumbered endpoints
      Map.fromList (zipWith (\e (i, j) -> (e, go i j)) edges edges')

  where
    -- distinct labels in ascending order; a label's position is its index
    nodeLabels = Set.toList $ Set.fromList $ foldMap (\(a, b) -> [a, b]) edges
    nodeCount  = length nodeLabels
    dictIDs    = Dict.fromList (zip nodeLabels [0..])
    edges'     = map (\(a, b) -> (dictIDs Dict.! a, dictIDs Dict.! b)) edges

-- | Confluence of every edge of the graph, in the order given by 'edges'.
--
-- For an edge @(x, y)@:
--
-- 1. A unit vector at @y@ is multiplied @l@ times by a copy of @tm@ whose
--    columns @x@ and @y@ are changed with 'SMatrix.withColChangeExcept'.
--    The value used is @1 / (deg - 1)@, where deg is the number of non-zeros
--    of that column in @am@.
-- 2. The value @P@ reaching @x@ is compared with
--    @Q = (deg x - 1) / (nnz am - 2)@ through @(P - Q) / (P + Q)@.
--
-- As in 'computeConfluences', zero denominators give 0 instead of NaN. A NaN
-- score would make the comparisons in 'sort_edges' inconsistent.
edges_confluence :: forall n a b.
                    KnownNat n
                 => Length
                 -> FiniteGraph n a b
                 -> SMatrix.Matrix n n Double -- adjacency
                 -> SMatrix.Matrix n n Double -- transition
                 -> UnsortedEdges
edges_confluence l (FiniteGraph g) am tm = map confluenceOf (edges g)

  where
      sumdeg_m2 = fromIntegral (SMatrix.nonZeros am - 2)

      confluenceOf (x, y) =
          let !deg_x_m1 = fromIntegral (SMatrix.nnzCol am x - 1)
              !deg_y_m1 = fromIntegral (SMatrix.nnzCol am y - 1)
              v         = SMatrix.asColumn (SVector.singleton y 1)
              v'        =
                SMatrix.withColChangeExcept x (1/deg_x_m1) y tm $ \tm' ->
                  SMatrix.withColChangeExcept y (1/deg_y_m1) x tm' $ \tm'' ->
                    iterate (SMatrix.mul tm'') v Prelude.!! l
              prox_y_x_length = SMatrix.extractCol v' 0 SVector.! x
              prox_y_x_infini = if sumdeg_m2 == 0 then 0 else deg_x_m1 / sumdeg_m2
              denominator     = prox_y_x_length + prox_y_x_infini
              conf | denominator == 0 = 0
                   | otherwise        = (prox_y_x_length - prox_y_x_infini) / denominator
          in seq conf (x, y, conf)

-- | Sort scored edges by decreasing score.
--
-- Scores closer than @1e-12@ are treated as equal. Such ties are ordered by
-- ascending @x * n + y@.
--
-- NOTE(review): the tolerance makes the comparison non-transitive
-- (@a ~ b@ and @b ~ c@ do not imply @a ~ c@), which 'List.sortBy' does not
-- formally support. It is kept as-is to preserve current results.
sort_edges :: Int
           -> UnsortedEdges
           -> SortedEdges
sort_edges n = List.sortBy byScoreThenPosition
  where
    byScoreThenPosition e1 e2 = scoreOrder e1 e2 <> comparing linearIndex e1 e2
    scoreOrder (_, _, c1) (_, _, c2)
      | abs (c2 - c1) < 10 ** (-12) = EQ
      | otherwise                   = comparing Down c1 c2
    linearIndex (x, y, _) = x * n + y


---------------------------------------------------------------

-- | One part (cluster) of a clustering under construction.
data PartData = PartData
  { partElems :: !IntSet
    -- ^ node indices belonging to the part
  , partScore :: {-# UNPACK #-} !Double
    -- ^ score of the part: 0 for a fresh singleton, otherwise the value
    -- computed by 'scoreMerge' when the part was formed
  } deriving (Show, Eq)

-- | Score of the union of two optional parts.
--
-- If only one part is present, its own score is returned; if neither is,
-- the result is 0. Otherwise the result is both scores plus @f i j@ for
-- every @i@ of the second part and @j@ of the first. The terms are summed
-- in ascending order of @i@, then @j@.
scoreMerge :: (Node -> Node -> Double) -> Maybe PartData -> Maybe PartData -> Double
scoreMerge _ Nothing  Nothing  = 0
scoreMerge _ Nothing  (Just q) = partScore q
scoreMerge _ (Just p) Nothing  = partScore p
scoreMerge f (Just p) (Just q) =
  foldl' (+) (partScore p + partScore q)
    [ f i j
    | i <- IntSet.toAscList (partElems q)
    , j <- IntSet.toAscList (partElems p)
    ]

-- | Union of two optional parts.
--
-- When both are present, the result holds the union of their nodes and the
-- given score. When only one is present, it is returned unchanged and the
-- score argument is ignored. When neither is present, the result is
-- 'Nothing'.
partsMerge :: Maybe PartData -> Maybe PartData -> Double -> Maybe PartData
partsMerge mp mq s = case (mp, mq) of
  (Just p,  Just q ) -> Just (PartData (IntSet.union (partElems p) (partElems q)) s)
  (Just p,  Nothing) -> Just p
  (Nothing, other  ) -> other

-- | Mutable clustering state used by 'runClustering' inside 'ST'.
data MClustering s =
  MClustering { mparts :: V.MVector s (Maybe PartData)
                -- ^ part slots, indexed by part id; 'Nothing' once the part
                -- has been merged into another one
              , mindex :: VU.MVector s Int
                -- ^ node -> id of the part that contains it
              , mscore :: VU.MVector s Double -- 1-entry array, total score
              , mnumcl :: VU.MVector s Int    -- 1-entry array, #clusters
              }

-- | Initial mutable clustering for @n@ nodes. Node @i@ is alone in part @i@
-- with score 0; the total score is 0 and the cluster count is @n@.
newMClustering :: Int -> ST s (MClustering s)
newMClustering n =
  MClustering
    <$> V.thaw (V.generate n singletonPart)
    <*> VU.thaw (VU.enumFromN 0 n)
    <*> MVU.replicate 1 0.0
    <*> MVU.replicate 1 n
  where
    singletonPart i = Just (PartData (IntSet.singleton i) 0)

-- | Offer one edge @(x, y)@ to the clustering: tentatively merge the parts
-- holding @x@ and @y@ and keep the merge when the total score does not
-- decrease.
--
-- Self-loops and edges inside an existing part are ignored. When a merge is
-- accepted:
--
-- * the merged part goes into @x@'s slot and @y@'s slot is cleared;
-- * the total score is updated and the cluster count decremented;
-- * every node of @y@'s former part is re-indexed to @x@'s part.
clusteringStep
  :: KnownNat n
  => Double -- ^ beta
  -> SMatrix.Matrix n n Double -- ^ adjacency
  -> SMatrix.Matrix n n Double -- ^ proxemie
  -> MClustering s -> (Node, Node) -> ST s ()
clusteringStep beta adj prox mclust (x, y) = do
  let partsVec = mparts mclust
      indexVec = mindex mclust
      totalRef = mscore mclust
  partOfX <- MVU.unsafeRead indexVec x
  partOfY <- MVU.unsafeRead indexVec y
  when (x /= y && partOfX /= partOfY) $ do
    px       <- MV.unsafeRead partsVec partOfX
    py       <- MV.unsafeRead partsVec partOfY
    oldTotal <- MVU.unsafeRead totalRef 0
    let mergedScore = scoreMerge (confAt beta adj prox) px py
        newTotal    = oldTotal - maybe 0 partScore px - maybe 0 partScore py + mergedScore
    when (newTotal >= oldTotal) $ do
      MV.unsafeWrite partsVec partOfX (partsMerge px py mergedScore)
      MV.unsafeWrite partsVec partOfY Nothing
      MVU.unsafeWrite totalRef 0 newTotal
      MVU.unsafeModify (mnumcl mclust) pred 0
      forM_ (maybe [] (IntSet.toList . partElems) py) $ \node ->
        MVU.unsafeWrite indexVec node partOfX

-- | Optional "garbage collector" pass, run by 'runClustering' after the
-- greedy edge-by-edge merging.
--
-- 1. The surviving parts of @mclust@ are copied to dense indices
--    @0 .. nclust-1@.
-- 2. For every pair @i < j@, @mat_delta@ stores the gain of merging them:
--    the sum of 'confAt' over all @(x, y)@ with @x@ in part @i@ and @y@ in
--    part @j@. Every other entry holds @-maxDouble@, so it is never chosen.
-- 3. While the best gain among live clusters is strictly positive, that
--    pair is merged. The lower index absorbs the higher one: the gains of
--    the absorbing cluster are updated and the absorbed column is masked.
--
-- Returns the resulting clusters, keyed by compacted index, and the initial
-- total score plus all accepted gains. When there are no parts at all, no
-- merge round is attempted; calling 'maximumBy' on an empty list would
-- otherwise crash.
clusteringCollector
  :: forall s (n :: Nat).
     KnownNat n
  => Double -- ^ beta
  -> SMatrix.Matrix n n Double -- ^ adjacency
  -> SMatrix.Matrix n n Double -- ^ proxemie
  -> MClustering s
  -> ST s (Dict IntSet, Double)
clusteringCollector beta adj prox mclust = do
  -- number of parts still alive (not merged away)
  nclust <- MV.foldl'
    (\n_acc mpart ->
       if isNothing mpart
       then n_acc
       else 1+n_acc
    )
    0
    (mparts mclust)
  -- copy the surviving parts into a dense vector
  newClusts <- MV.unsafeNew nclust
  let go new_i _old_i Nothing  = return new_i
      go new_i _old_i (Just p) = do
        MV.unsafeWrite newClusts new_i (Just p)
        return (new_i+1)
  _ <- MV.ifoldM' go 0 (mparts mclust)
  -- upper triangle: merge gains; everything else: -maxDouble (never chosen)
  mat_delta <- DMatrix.newMatrix (negate maxDouble) nclust nclust
  forM_ [0..(nclust-1)] $ \i ->
    forM_ [(i+1)..(nclust-1)] $ \j -> do
      DMatrix.unsafeWriteMatrix mat_delta i j 0
      part_i <- MV.unsafeRead newClusts i
      part_j <- MV.unsafeRead newClusts j
      forPart part_i $ \x ->
        forPart part_j $ \y ->
          DMatrix.modifyMatrix mat_delta i j $ \a -> a + confAt beta adj prox x y
  delta0 <- MVU.unsafeRead (mscore mclust) 0
  let clusts = IntSet.fromList [0..(nclust-1)]
      -- best column j > i of row i as (j, gain); (-1, -maxDouble) if none
      argmaxRow i = do
        foldM (\acc@(_, !max_v) j -> do
                  v <- DMatrix.unsafeReadMatrix mat_delta i j
                  if v > max_v then return (j, v) else return acc
              )
              (-1, negate maxDouble)
              [(i+1)..(nclust-1)]
      -- best (i, j, gain) among the live clusters; cs must be non-empty
      bestPair cs = fmap (maximumBy (comparing (\(_, _, c) -> c))) $
        forM (IntSet.toAscList cs) $ \i ->
          (\(j, v) -> (i, j, v)) <$> argmaxRow i
      fusionRound !cs !merges !delta
        | IntSet.null cs = return (cs, merges, delta)
        | otherwise = do
            (good_i, good_j, delta') <- bestPair cs
            if delta' > 0 then do
              case IntSet.split good_i cs of
                (before_i, after_i) -> do
                  forM_ (IntSet.toList before_i) $ \i -> do
                    delta_ij <- DMatrix.unsafeReadMatrix mat_delta i good_j
                    DMatrix.modifyMatrix mat_delta i good_i (+ delta_ij)
                  forM_ (IntSet.toList after_i) $ \i -> do
                    let (x, y) = if i < good_j then (i, good_j) else (good_j, i)
                    when (x < y) $ do
                      delta_xy <- DMatrix.unsafeReadMatrix mat_delta x y
                      DMatrix.modifyMatrix mat_delta good_i i (+ delta_xy)
                  -- good_j is gone: mask its column in every row above it
                  case IntSet.split good_j cs of
                    (before_j, _) -> forM_ (IntSet.toList before_j) $ \i ->
                      DMatrix.unsafeWriteMatrix mat_delta i good_j (negate maxDouble)
              let merges' = IntMap.insertWith IntSet.union good_i
                              (IntSet.insert good_j $ fromMaybe IntSet.empty $ IntMap.lookup good_j merges)
                              merges
              fusionRound (IntSet.delete good_j cs)
                          merges'
                          (delta + delta')
            else return (cs, merges, delta)
  (cs, merges, finalDelta) <- fusionRound clusts IntMap.empty delta0
  let groups = [ (i, maybe [] IntSet.toList (IntMap.lookup i merges) )
               | i <- IntSet.toList cs
               ]
  clustsDict <- fmap IntMap.unions . forM (IntSet.toList cs) $ \i ->
    IntMap.singleton i . maybe IntSet.empty partElems <$> MV.unsafeRead newClusts i
  c <- foldGroups groups clustsDict $ \dict (i, js) -> do
    sets_i <- traverse (fmap (maybe IntSet.empty partElems) . MV.unsafeRead newClusts) js
    return $! IntMap.insertWith IntSet.union i (IntSet.unions sets_i) dict
  return (c, finalDelta)

  where maxDouble = encodeFloat m k :: Double
        b = floatRadix  (0 :: Double)
        e = floatDigits (0 :: Double)
        (_, e') = floatRange (0 :: Double)
        m = b^e - 1
        k = e' - e
        foldGroups gs cs f = foldM f cs gs
        forPart Nothing _  = return ()
        forPart (Just p) f =
          IntSet.foldr (\i acc -> f i >> acc) (return ()) (partElems p)

-- | Result of a clustering run. The type parameter @a@ is not used by any
-- field.
data Clust a = Clust
  { cparts :: !(Dict IntSet)
    -- ^ cluster id -> set of node indices
  , cindex :: (Dict Int)
    -- ^ node index -> cluster id (the inverse of 'cparts'); lazy, unlike
    -- the other fields
  , cscore :: !Double
    -- ^ final total score
  } deriving (Show, Eq)

-- | Greedy clustering.
--
-- Every edge of the sorted list is offered, in order, to 'clusteringStep'.
-- Afterwards, either the 'clusteringCollector' pass is run, or the surviving
-- parts are returned keyed by their slot index, together with the total
-- score.
runClustering
  :: forall (n :: Nat). KnownNat n
  => Bool -- ^ do we run the 'garbage collector'?
  -> Double -- ^ beta
  -> SMatrix.Matrix n n Double -- ^ adjacency
  -> SMatrix.Matrix n n Double -- ^ proxemie
  -> SortedEdges
  -> (Dict IntSet, Double)
runClustering gc beta adj prox se = runST $ do
  mclust <- newMClustering nodeCount
  mapM_ (\(x, y, _) -> clusteringStep beta adj prox mclust (x, y)) se
  if not gc
    then do
      frozen <- V.unsafeFreeze (mparts mclust)
      total  <- MVU.unsafeRead (mscore mclust) 0
      let alive = Dict.fromList
            [ (slot, members)
            | (slot, Just (PartData members _)) <- zip [0 ..] (V.toList frozen)
            ]
      return (alive, total)
    else clusteringCollector beta adj prox mclust

  where nodeCount = fromIntegral $ natVal (Proxy :: Proxy n)

-- All the code below is unused for now.

-- | Pure clustering representation.
--
-- NOTE: the selectors 'parts', 'index', 'score' and 'mode' are partial: they
-- fail on 'Clusterings'. Likewise, 'strict' and 'over' fail on
-- 'ClusteringIs'.
data Clustering a
  = ClusteringIs { parts :: !(Dict (Set a))
                   -- ^ part id -> members
                 , index :: !(Dict Int)
                   -- ^ member -> part id
                 , score :: !(Double)
                 , mode  :: !ClusteringMode
                 }
  | Clusterings { strict :: !(StrictClustering   a)
                , over   :: !(OverlapsClustering a)
                }
  deriving (Show, Eq)

-- | Kind of clustering held by 'ClusteringIs'. Only 'Part' is produced in
-- this module. Presumably 'Over' means overlapping clusters and 'Both' both
-- kinds (TODO confirm).
data ClusteringMode = Part | Over | Both
  deriving (Show, Eq)

-- | Alias of 'Clustering'; by its name, a non-overlapping clustering.
type StrictClustering   a = Clustering a
-- | Alias of 'Clustering'; by its name, an overlapping clustering.
type OverlapsClustering a = Clustering a

-- | Build a partition by offering every sorted edge, in order, to
-- 'updateClustering'. The starting point is one singleton part @{i}@ for each
-- @i@ in @[0 .. n-1]@. The pair score is twice the similarity-matrix entry.
make_clust_part :: forall n
                 . KnownNat n
                => SortedEdges
                -> SimilarityMatrix n
                -> Clustering Node
make_clust_part se sm = foldl' offerEdge initial se
  where
    offerEdge c (e1, e2, _) = updateClustering c pairScore e1 e2
    pairScore x y = 2 * (simMat `DMatrix.at` (x, y))
    initial   = ClusteringIs initParts initIndex 0 Part
    nodeIds   = [0 .. natToInt @n - 1]
    initParts = Dict.fromAscList [ (i, Set.singleton i) | i <- nodeIds ]
    initIndex = Dict.fromAscList [ (i, i) | i <- nodeIds ]
    simMat    = case sm of
      SimConf cm -> cm
      SimMod  mm -> mm

-- | Try to move node @y@ into the part of node @x@ (see 'updateWith').
-- The result is kept only if its score is not lower than the current one.
--
-- Nothing changes when @x == y@ or when both nodes already share a part.
-- Fails with "modX" / "modY" if a node is missing from the index. Only the
-- 'ClusteringIs' constructor is handled.
updateClustering :: Clustering Node
                 -> (Int -> Int -> Double)
                 -> Int -> Int -> Clustering Node
updateClustering c@(ClusteringIs _ ix currentScore _) f x y
  | x == y || modX == modY      = c
  | score moved >= currentScore = moved
  | otherwise                   = c
  where
    modX  = fromMaybe (error "modX") (Dict.lookup x ix)
    modY  = fromMaybe (error "modY") (Dict.lookup y ix)
    moved = updateWith c f x modX y modY

-- | Move node @y@ from part @modY@ to part @modX@ and recompute the score
-- from scratch.
--
-- * Part @modY@ is dropped if it becomes empty; part @modX@ is created if
--   missing.
-- * The index maps @y@ to @modX@.
-- * The new score is the sum of @f a b@ over all pairs @a < b@ inside each
--   part, and the mode is set to 'Part'.
--
-- The @x@ argument is not used. Only 'ClusteringIs' is handled.
updateWith :: Clustering Node
           -> (Int -> Int -> Double)
           -> Int -> Int
           -> Int -> Int
           -> Clustering Node
updateWith (ClusteringIs ps ix _ _) f _x modX y modY =
  ClusteringIs parts' idx' score' Part
    where
      parts' = Dict.alter (Just . maybe (Set.singleton y) (Set.insert y)) modX
             $ Dict.alter (>>= withoutY) modY ps

      withoutY members
        | Set.null rest = Nothing
        | otherwise     = Just rest
        where rest = Set.delete y members

      idx' = Dict.insert y modX ix

      score' = getSum $
        foldMap (\members -> fold [ Sum (f a b)
                                  | a <- Set.toList members
                                  , b <- Set.toList members
                                  , a < b
                                  ]
                ) parts'

---------------------------------------------------------------

------------------------------
-- | Build an overlapping clustering from a strict one.
--
-- Not implemented yet. Using it fails with an explicit, named error instead
-- of an anonymous 'undefined'.
make_clust_over :: KnownNat n
                => SimilarityMatrix n
                -> StrictClustering   a
                -> OverlapsClustering a
make_clust_over = Prelude.error "make_clust_over: not implemented yet"

---------------------------------------------------------------
-- | Some Utils

-- | Tag a graph with the type-level natural @n@ carried by the proxy.
--
-- This is the main trick of the module. Once the graph size is a type-level
-- 'KnownNat', the same @n@ indexes the graph, the matrices and the vectors.
-- Size mismatches between them are then rejected at compile time, so less
-- testing is needed at this step. The caller must pick @n@ equal to the node
-- count ('withG' does).
buildFiniteGraph :: Proxy n -> Graph a b -> FiniteGraph n a b
buildFiniteGraph _ = FiniteGraph

-- | Run a size-indexed computation on a graph.
--
-- The node count of @g@ is reflected once as a 'KnownNat' @n@. The graph is
-- then wrapped as a @'FiniteGraph' n@ and passed to @f@.
withG :: (Show a, Show b)
      => Graph a b
      -> (forall n. KnownNat n => FiniteGraph n a b -> r)
      -> r
withG g f = reifyNat nodeCount (\proxy -> f (buildFiniteGraph proxy g))
  where
    nodeCount = fromIntegral (length (nodes g))

-- | Value of the type-level natural @n@ as an 'Int'. Use it with a type
-- application, e.g. @natToInt \@n@.
natToInt :: forall (n :: Nat). KnownNat n => Int
natToInt = fromInteger (natVal (Proxy @n))
