Commit 485666a2 authored by Alexandre Delanoë

[LEARN] Grid Search improved.

parent 06aa56b6
...@@ -60,7 +60,7 @@ import Gargantext.Database.Types.Node (CorpusId, ContactId) ...@@ -60,7 +60,7 @@ import Gargantext.Database.Types.Node (CorpusId, ContactId)
import Gargantext.Database.Utils -- (Cmd, CmdM) import Gargantext.Database.Utils -- (Cmd, CmdM)
import Gargantext.Prelude import Gargantext.Prelude
import Gargantext.API.Settings import Gargantext.API.Settings
import Gargantext.Text.Metrics import Gargantext.Text.Metrics (Scored(..))
import Gargantext.Viz.Graph hiding (Node)-- (Graph(_graph_metadata),LegendField(..), GraphMetadata(..),readGraphFromJson,defaultGraph) import Gargantext.Viz.Graph hiding (Node)-- (Graph(_graph_metadata),LegendField(..), GraphMetadata(..),readGraphFromJson,defaultGraph)
import Gargantext.Viz.Graph.Tools (cooc2graph) import Gargantext.Viz.Graph.Tools (cooc2graph)
import Servant import Servant
...@@ -69,7 +69,7 @@ import Test.QuickCheck.Arbitrary (Arbitrary, arbitrary) ...@@ -69,7 +69,7 @@ import Test.QuickCheck.Arbitrary (Arbitrary, arbitrary)
import qualified Data.Map as Map import qualified Data.Map as Map
import qualified Gargantext.Database.Node.Update as U (update, Update(..)) import qualified Gargantext.Database.Node.Update as U (update, Update(..))
{-- {-
import qualified Gargantext.Text.List.Learn as Learn import qualified Gargantext.Text.List.Learn as Learn
import qualified Data.Vector as Vec import qualified Data.Vector as Vec
--} --}
...@@ -408,10 +408,8 @@ getMetrics cId maybeListId tabType maybeLimit = do ...@@ -408,10 +408,8 @@ getMetrics cId maybeListId tabType maybeLimit = do
listType t m = maybe (panic errorMsg) fst $ Map.lookup t m listType t m = maybe (panic errorMsg) fst $ Map.lookup t m
errorMsg = "API.Node.metrics: key absent" errorMsg = "API.Node.metrics: key absent"
{-
let metrics' = Map.fromListWith (<>) $ map (\(Metric _ s1 s2 lt) -> (lt, [Vec.fromList [s1,s2]])) metrics
_ <- Learn.grid metrics'
--}
pure $ Metrics metrics pure $ Metrics metrics
...@@ -60,7 +60,7 @@ import Gargantext.Database.Types.Node -- (HyperdataDocument(..), NodeType(..), N ...@@ -60,7 +60,7 @@ import Gargantext.Database.Types.Node -- (HyperdataDocument(..), NodeType(..), N
import Gargantext.Database.Utils (Cmd, CmdM) import Gargantext.Database.Utils (Cmd, CmdM)
import Gargantext.Ext.IMT (toSchoolName) import Gargantext.Ext.IMT (toSchoolName)
import Gargantext.Prelude import Gargantext.Prelude
import Gargantext.Text.List (buildNgramsLists) import Gargantext.Text.List (buildNgramsLists,StopSize(..))
--import Gargantext.Text.Parsers (parseDocs, FileFormat) --import Gargantext.Text.Parsers (parseDocs, FileFormat)
import Gargantext.Text.Terms (TermType(..), tt_lang) import Gargantext.Text.Terms (TermType(..), tt_lang)
import Gargantext.Text.Terms (extractTerms) import Gargantext.Text.Terms (extractTerms)
...@@ -127,7 +127,7 @@ flowCorpusUser l userName corpusName ids = do ...@@ -127,7 +127,7 @@ flowCorpusUser l userName corpusName ids = do
-- User List Flow -- User List Flow
(_masterUserId, _masterRootId, masterCorpusId) <- getOrMkRootWithCorpus userMaster "" (_masterUserId, _masterRootId, masterCorpusId) <- getOrMkRootWithCorpus userMaster ""
ngs <- buildNgramsLists l 2 3 userCorpusId masterCorpusId ngs <- buildNgramsLists l 2 3 (StopSize 3) userCorpusId masterCorpusId
userListId <- flowList userId userCorpusId ngs userListId <- flowList userId userCorpusId ngs
printDebug "userListId" userListId printDebug "userListId" userListId
......
...@@ -25,20 +25,29 @@ Portability : POSIX ...@@ -25,20 +25,29 @@ Portability : POSIX
module Gargantext.Database.Lists where module Gargantext.Database.Lists where
import Control.Arrow (returnA) --import Control.Arrow (returnA)
--import Gargantext.API.Metrics
--import Gargantext.Core.Types.Individu (Username)
--import Gargantext.Database.Config (nodeTypeId)
--import Gargantext.Database.Schema.Node -- (HasNodeError, queryNodeTable)
--import Gargantext.Database.Schema.User -- (queryUserTable)
--import Gargantext.Database.Utils
--import Opaleye hiding (FromField)
--import Opaleye.Internal.QueryArr (Query)
import Gargantext.API.Ngrams (TabType(..))
import Gargantext.Core.Types -- (NodePoly(..), NodeCorpus, ListId) import Gargantext.Core.Types -- (NodePoly(..), NodeCorpus, ListId)
import Gargantext.Core.Types.Individu (Username) import Gargantext.Database.Flow (FlowCmdM)
import Gargantext.Database.Config (nodeTypeId)
import Gargantext.Database.Schema.Node -- (HasNodeError, queryNodeTable)
import Gargantext.Database.Schema.User -- (queryUserTable)
import Gargantext.Database.Utils
import Gargantext.Prelude hiding (sum, head) import Gargantext.Prelude hiding (sum, head)
import Opaleye hiding (FromField) import Gargantext.Text.Metrics (Scored(..))
import Opaleye.Internal.QueryArr (Query)
import Prelude hiding (null, id, map, sum) import Prelude hiding (null, id, map, sum)
import Servant (ServantErr)
import qualified Data.Map as Map
import qualified Data.Vector as Vec
import qualified Gargantext.Database.Metrics as Metrics
-- | To get all lists of a user -- | To get all lists of a user
-- /!\ lists of different types of corpora (Annuaire or Documents) -- /!\ lists of different types of corpora (Annuaire or Documents)
{-
listsWith :: HasNodeError err => Username -> Cmd err [Maybe ListId] listsWith :: HasNodeError err => Username -> Cmd err [Maybe ListId]
listsWith u = runOpaQuery (selectLists u) listsWith u = runOpaQuery (selectLists u)
where where
...@@ -53,7 +62,6 @@ listsWithJoin2 = leftJoin queryUserTable queryNodeTable cond12 ...@@ -53,7 +62,6 @@ listsWithJoin2 = leftJoin queryUserTable queryNodeTable cond12
where where
cond12 (u,n) = user_id u .== _node_userId n cond12 (u,n) = user_id u .== _node_userId n
{-
listsWithJoin3 :: Query (NodeRead, (UserRead, NodeReadNull)) listsWithJoin3 :: Query (NodeRead, (UserRead, NodeReadNull))
listsWithJoin3 = leftJoin3 queryUserTable queryNodeTable queryNodeTable cond12 cond23 listsWithJoin3 = leftJoin3 queryUserTable queryNodeTable queryNodeTable cond12 cond23
where where
...@@ -61,5 +69,23 @@ listsWithJoin3 = leftJoin3 queryUserTable queryNodeTable queryNodeTable cond12 c ...@@ -61,5 +69,23 @@ listsWithJoin3 = leftJoin3 queryUserTable queryNodeTable queryNodeTable cond12 c
cond12 (u,n) = user_id u .== _node_userId n cond12 (u,n) = user_id u .== _node_userId n
cond23 :: (NodeRead, (UserRead, NodeReadNull)) -> Column PGBool cond23 :: (NodeRead, (UserRead, NodeReadNull)) -> Column PGBool
cond23 (n1,(u,n2)) = (toNullable $ _node_id n1) .== _node_parentId n2 cond23 (n1,(u,n2)) = (toNullable $ _node_id n1) .== _node_parentId n2
--} --}
-- | Group the score vectors of a corpus' ngrams by their 'ListType'.
--
-- Every 'Scored' entry returned by 'Metrics.getMetrics'' is turned into a
-- two-component vector @[s1, s2]@ and filed under the list type found for
-- its term in the ngrams map returned alongside the scores. Vectors that
-- share a list type are concatenated with '<>'.
--
-- Panics if a scored term has no entry in the ngrams map.
learnMetrics' :: FlowCmdM env ServantErr m
              => CorpusId -> Maybe ListId -> TabType -> Maybe Int
              -> m (Map.Map ListType [Vec.Vector Double])
learnMetrics' cId maybeListId tabType maybeLimit = do
  (ngs', scores) <- Metrics.getMetrics' cId maybeListId tabType maybeLimit
  let
    -- The message names this module: the previous text ("API.Node.metrics")
    -- was copied from another module and sent debugging to the wrong place.
    errorMsg = "Database.Lists.learnMetrics': key absent"

    -- Map values are pairs; only the first component (the list type) is used.
    listTypeOf t = maybe (panic errorMsg) fst $ Map.lookup t ngs'

    -- One singleton list per term so that 'Map.fromListWith' can append them.
    toEntry (Scored t s1 s2) = (listTypeOf t, [Vec.fromList [s1, s2]])

  -- Disabled experiment kept for reference (needs the Learn import and a
  -- binding for metrics'):
  --   _ <- Learn.grid 100 110 metrics' metrics'
  pure $ Map.fromListWith (<>) (map toEntry scores)
...@@ -68,7 +68,6 @@ getLocalMetrics cId maybeListId tabType maybeLimit = do ...@@ -68,7 +68,6 @@ getLocalMetrics cId maybeListId tabType maybeLimit = do
pure (ngs, ngs', localMetrics myCooc) pure (ngs, ngs', localMetrics myCooc)
getNgramsCooc :: (FlowCmdM env ServantErr m) getNgramsCooc :: (FlowCmdM env ServantErr m)
=> CorpusId -> Maybe ListId -> TabType -> Maybe Limit => CorpusId -> Maybe ListId -> TabType -> Maybe Limit
-> m ( Map Text (ListType, Maybe Text) -> m ( Map Text (ListType, Maybe Text)
......
...@@ -50,3 +50,4 @@ selectRoot username = proc () -> do ...@@ -50,3 +50,4 @@ selectRoot username = proc () -> do
restrict -< _node_userId row .== (user_id users) restrict -< _node_userId row .== (user_id users)
returnA -< row returnA -< row
...@@ -25,6 +25,7 @@ import Gargantext.Core.Types (ListType(..), MasterCorpusId, UserCorpusId) ...@@ -25,6 +25,7 @@ import Gargantext.Core.Types (ListType(..), MasterCorpusId, UserCorpusId)
import Gargantext.Database.Metrics.NgramsByNode (getTficf', sortTficf, ngramsGroup, getNodesByNgramsUser, groupNodesByNgramsWith) import Gargantext.Database.Metrics.NgramsByNode (getTficf', sortTficf, ngramsGroup, getNodesByNgramsUser, groupNodesByNgramsWith)
import Gargantext.Database.Schema.Ngrams (NgramsType(..)) import Gargantext.Database.Schema.Ngrams (NgramsType(..))
import Gargantext.Database.Utils (Cmd) import Gargantext.Database.Utils (Cmd)
import Gargantext.Text.List.Learn (Model(..))
import Gargantext.Prelude import Gargantext.Prelude
--import Gargantext.Text.Terms (TermType(..)) --import Gargantext.Text.Terms (TermType(..))
import qualified Data.Char as Char import qualified Data.Char as Char
...@@ -33,11 +34,23 @@ import qualified Data.Map as Map ...@@ -33,11 +34,23 @@ import qualified Data.Map as Map
import qualified Data.Set as Set import qualified Data.Set as Set
import qualified Data.Text as Text import qualified Data.Text as Text
-- | Steps of the ngrams list builder.
--
-- The first step carries three 'Int' settings; the two later steps carry a
-- 'Model'.
--
-- NOTE(review): the first constructor ends with the capital letter @O@ while
-- its siblings end with @1@ and @N@ — confirm whether a zero was intended.
-- NOTE(review): the record selectors are partial ('withModel' is undefined on
-- the first constructor, the 'Int' selectors on the other two).
data NgramsListBuilder
  = BuilderStepO
      { stemSize :: Int
      , stemX    :: Int
      , stopSize :: Int
      }
  | BuilderStep1
      { withModel :: Model
      }
  | BuilderStepN
      { withModel :: Model
      }
-- | Minimum length a term must have: 'isStopTerm' treats any term whose text
-- is shorter than the wrapped 'Int' as a stop term.
--
-- A @newtype@ because the wrapper has exactly one constructor with exactly
-- one field; it keeps the type distinction at no runtime cost.
newtype StopSize = StopSize { unStopSize :: Int }
-- | TODO improve grouping functions of Authors, Sources, Institutes.. -- | TODO improve grouping functions of Authors, Sources, Institutes..
buildNgramsLists :: Lang -> Int -> Int -> UserCorpusId -> MasterCorpusId buildNgramsLists :: Lang -> Int -> Int -> StopSize -> UserCorpusId -> MasterCorpusId
-> Cmd err (Map NgramsType [NgramsElement]) -> Cmd err (Map NgramsType [NgramsElement])
buildNgramsLists l n m uCid mCid = do buildNgramsLists l n m s uCid mCid = do
ngTerms <- buildNgramsTermsList l n m uCid mCid ngTerms <- buildNgramsTermsList l n m s uCid mCid
othersTerms <- mapM (buildNgramsOthersList uCid identity) [Authors, Sources, Institutes] othersTerms <- mapM (buildNgramsOthersList uCid identity) [Authors, Sources, Institutes]
pure $ Map.unions $ othersTerms <> [ngTerms] pure $ Map.unions $ othersTerms <> [ngTerms]
...@@ -54,13 +67,13 @@ buildNgramsOthersList uCid groupIt nt = do ...@@ -54,13 +67,13 @@ buildNgramsOthersList uCid groupIt nt = do
] ]
-- TODO remove hard coded parameters -- TODO remove hard coded parameters
buildNgramsTermsList :: Lang -> Int -> Int -> UserCorpusId -> MasterCorpusId buildNgramsTermsList :: Lang -> Int -> Int -> StopSize -> UserCorpusId -> MasterCorpusId
-> Cmd err (Map NgramsType [NgramsElement]) -> Cmd err (Map NgramsType [NgramsElement])
buildNgramsTermsList l n m uCid mCid = do buildNgramsTermsList l n m s uCid mCid = do
candidates <- sortTficf <$> getTficf' uCid mCid NgramsTerms (ngramsGroup l n m) candidates <- sortTficf <$> getTficf' uCid mCid NgramsTerms (ngramsGroup l n m)
--printDebug "candidate" (length candidates) --printDebug "candidate" (length candidates)
let termList = toTermList (isStopTerm . fst) candidates let termList = toTermList ((isStopTerm s) . fst) candidates
--let termList = toTermList ((\_ -> False) . fst) candidates --let termList = toTermList ((\_ -> False) . fst) candidates
--printDebug "termlist" (length termList) --printDebug "termlist" (length termList)
...@@ -103,7 +116,7 @@ toTermList stop ns = map (toTermList' stop CandidateTerm) xs ...@@ -103,7 +116,7 @@ toTermList stop ns = map (toTermList' stop CandidateTerm) xs
a = 3 a = 3
b = 400 b = 400
isStopTerm :: Text -> Bool isStopTerm :: StopSize -> Text -> Bool
isStopTerm x = Text.length x < 3 || any isStopChar (Text.unpack x) isStopTerm (StopSize n) x = Text.length x < n || any isStopChar (Text.unpack x)
where where
isStopChar c = not (c `elem` ("- /()" :: [Char]) || Char.isAlpha c) isStopChar c = not (c `elem` ("- /()" :: [Char]) || Char.isAlpha c)
...@@ -57,21 +57,25 @@ trainList x y = (train x y) . trainList' ...@@ -57,21 +57,25 @@ trainList x y = (train x y) . trainList'
vecs2maps = map (IntMap.fromList . (zip [1..]) . Vec.toList) vecs2maps = map (IntMap.fromList . (zip [1..]) . Vec.toList)
predictList :: SVM.Model -> [Vec.Vector Double] -> IO [Maybe ListType] predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList m vs = map (fromListTypeId . round) <$> predict m vs predictList (ModelSVM m _ _) vs = map (fromListTypeId . round) <$> predict m vs
------------------------------------------------------------------------ ------------------------------------------------------------------------
data Model = ModelSVM { model :: SVM.Model } data Model = ModelSVM { modelSVM :: SVM.Model
, param1 :: Maybe Double
, param2 :: Maybe Double
}
--{-
instance SaveFile Model instance SaveFile Model
where where
saveFile' fp (ModelSVM m) = SVM.saveModel m fp saveFile' fp (ModelSVM m _ _) = SVM.saveModel m fp
instance ReadFile Model instance ReadFile Model
where where
readFile' fp = do readFile' fp = do
m <- SVM.loadModel fp m <- SVM.loadModel fp
pure $ ModelSVM m pure $ ModelSVM m Nothing Nothing
--}
------------------------------------------------------------------------ ------------------------------------------------------------------------
-- | TODO -- | TODO
-- shuffle list -- shuffle list
...@@ -80,43 +84,53 @@ instance ReadFile Model ...@@ -80,43 +84,53 @@ instance ReadFile Model
type Train = Map ListType [Vec.Vector Double] type Train = Map ListType [Vec.Vector Double]
type Tests = Map ListType [Vec.Vector Double] type Tests = Map ListType [Vec.Vector Double]
type Score = Double
type Param = Double
grid :: (MonadReader env m, MonadIO m, HasSettings env) grid :: (MonadReader env m, MonadIO m, HasSettings env)
=> (Train, Tests) -> m () -- Map (ListType, Maybe ListType) Int) => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
grid (m,_) = do grid s e tr te = do
let let
grid' :: (MonadReader env m, MonadIO m, HasSettings env) grid' :: (MonadReader env m, MonadIO m, HasSettings env)
=> Double -> Double => Double -> Double
-> Map ListType [Vec.Vector Double] -> Train
-> m (Double, (Double,Double)) -> [Tests]
grid' x y ls = do -> m (Score, Model)
model' <- liftIO $ trainList x y ls grid' x y tr' te' = do
--fp <- saveFile (ModelSVM model') model'' <- liftIO $ trainList x y tr'
--printDebug "file" fp
let (res, toGuess) = List.unzip $ List.concat let
$ map (\(k,vs) -> zip (repeat k) vs) model' = ModelSVM model'' (Just x) (Just y)
$ Map.toList ls
res' <- liftIO $ predictList model' toGuess score' :: [(ListType, Maybe ListType)] -> Map (Maybe Bool) Int
pure (score'' $ score' $ List.zip res res', (x,y)) score' = occurrencesWith (\(a,b) -> (==) <$> Just a <*> b)
{- score'' :: Map (Maybe Bool) Int -> Double
score :: [(ListType, Maybe ListType)] -> Map (ListType, Maybe ListType) Int score'' m'' = maybe 0 (\t -> (fromIntegral t)/total) (Map.lookup (Just True) m'')
score = occurrencesWith identity where
-} total = fromIntegral $ foldl (+) 0 $ Map.elems m''
score' :: [(ListType, Maybe ListType)] -> Map (Maybe Bool) Int getScore m t = do
score' = occurrencesWith (\(a,b) -> (==) <$> Just a <*> b) let (res, toGuess) = List.unzip $ List.concat
$ map (\(k,vs) -> zip (repeat k) vs)
score'' :: Map (Maybe Bool) Int -> Double $ Map.toList t
score'' m'' = maybe 0 (\t -> (fromIntegral t)/total) (Map.lookup (Just True) m'')
where res' <- liftIO $ predictList m toGuess
total = fromIntegral $ foldl (+) 0 $ Map.elems m'' pure $ score'' $ score' $ List.zip res res'
r <- List.take 10 . List.reverse score <- mapM (getScore model') te'
. (List.sortOn fst) pure (mean score, model')
<$> mapM (\(x,y) -> grid' x y m) [(x,y) | x <- [500..510], y <- [500..510]]
r <- head . List.reverse
printDebug "GRID SEARCH" r . (List.sortOn fst)
-- save best result <$> mapM (\(x,y) -> grid' x y tr te)
[(x,y) | x <- [s..e], y <- [s..e]]
printDebug "GRID SEARCH" (map fst r)
--printDebug "file" fp
--fp <- saveFile (ModelSVM model')
--save best result
pure $ snd <$> r
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment