Learn.hs 4.4 KB
{-|
Module      : Gargantext.Text.List.Learn
Description : Learn to make lists
Copyright   : (c) CNRS, 2018-Present
License     : AGPL + CECILL v3
Maintainer  : team@gargantext.org
Stability   : experimental
Portability : POSIX

CSV parser for Gargantext corpus files.

-}

{-# OPTIONS_GHC -fno-warn-orphans #-}

{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}

module Gargantext.Text.List.Learn
  where

import Control.Monad.Reader (MonadReader)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Gargantext.API.Settings
import Data.Map (Map)
import Data.Maybe (maybe)
import Gargantext.Core.Types.Main (ListType(..), listTypeId, fromListTypeId)
import Gargantext.Prelude
import Gargantext.Prelude.Utils
import Gargantext.Text.Metrics.Count (occurrencesWith)
import qualified Data.IntMap as IntMap
import qualified Data.List   as List
import qualified Data.Map    as Map
import qualified Data.SVM    as SVM
import qualified Data.Vector as Vec

------------------------------------------------------------------------
train :: Double -> Double -> SVM.Problem -> IO SVM.Model
train x y = (SVM.train (SVM.CSvc x) (SVM.RBF y))

predict :: SVM.Model -> [Vec.Vector Double] -> IO [Double]
predict m vs = mapM (predict' m) vs
  where
    predict' m' vs' = SVM.predict m' (IntMap.fromList $ (zip [1..]) $ Vec.toList vs')

------------------------------------------------------------------------
trainList :: Double -> Double -> Map ListType [Vec.Vector Double] -> IO SVM.Model
trainList x y = (train x y) . trainList'
  where
    trainList' :: Map ListType [Vec.Vector Double] -> SVM.Problem
    trainList' = mapVec2problem . (Map.mapKeys (fromIntegral . listTypeId))

    mapVec2problem :: Map Double [Vec.Vector Double] -> SVM.Problem
    mapVec2problem = List.concat . (map (\(a,as) -> zip (repeat a) as)) . Map.toList . (Map.map vecs2maps)

    vecs2maps :: [Vec.Vector Double] -> [IntMap.IntMap Double]
    vecs2maps = map (IntMap.fromList . (zip [1..]) . Vec.toList)


predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList (ModelSVM m _ _) vs = map (fromListTypeId . round) <$> predict m vs

------------------------------------------------------------------------
data Model = ModelSVM { modelSVM :: SVM.Model
                      , param1 :: Maybe Double
                      , param2 :: Maybe Double
                      }
--{-
instance SaveFile Model
  where
    saveFile' fp (ModelSVM m _ _) = SVM.saveModel m fp

instance ReadFile Model
  where
    readFile' fp = do
      m <- SVM.loadModel fp
      pure $ ModelSVM m Nothing Nothing
--}
------------------------------------------------------------------------
-- | TODO
-- shuffle list
-- split list : train / test
-- grid parameters on best result on test

type Train = Map ListType [Vec.Vector Double]
type Tests = Map ListType [Vec.Vector Double]
type Score = Double
type Param = Double

grid :: (MonadReader env m, MonadIO m, HasSettings env)
     => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
grid _ _ _ []  = panic "Gargantext.Text.List.Learn.grid : empty test data"
grid s e tr te = do
  let
    grid' :: (MonadReader env m, MonadIO m, HasSettings env)
          => Double -> Double
          -> Train
          -> [Tests]
          -> m (Score, Model)
    grid' x y tr' te' = do
      model'' <- liftIO $ trainList x y tr'
      
      let
        model' = ModelSVM model'' (Just x) (Just y)

        score' :: [(ListType, Maybe ListType)] -> Map (Maybe Bool) Int
        score' = occurrencesWith (\(a,b) -> (==) <$> Just a <*> b)

        score'' :: Map (Maybe Bool) Int -> Double
        score'' m'' = maybe 0 (\t -> (fromIntegral t)/total) (Map.lookup (Just True) m'')
          where
            total = fromIntegral $ foldl (+) 0 $ Map.elems m''

        getScore m t = do
          let (res, toGuess) = List.unzip
                             $ List.concat
                             $ map (\(k,vs) -> zip (repeat k) vs)
                             $ Map.toList t
        
          res' <- liftIO $ predictList m toGuess
          pure $ score'' $ score' $ List.zip res res' 
      
      score <- mapM (getScore model') te'
      pure (mean score, model')

  r <- head . List.reverse
            . (List.sortOn fst)
           <$> mapM (\(x,y) -> grid' x y tr te)
                         [(x,y) | x <- [s..e], y <- [s..e]]

  printDebug "GRID SEARCH" (map fst r)
  --printDebug "file" fp
  --fp <- saveFile (ModelSVM model')
  --save best result
  pure $ snd <$> r