From af8d8c8c59e2b4d7949c51ccb4411122028ba43a Mon Sep 17 00:00:00 2001
From: Kai Zhang <kai@kzhang.org>
Date: Thu, 19 Apr 2018 14:33:52 -0700
Subject: [PATCH] rewrite cereal instances

---
 haskell-igraph.cabal     |  10 ++--
 src/IGraph.hs            | 107 ++++++++++++++++++++++++++++++++-------
 src/IGraph/Mutable.hs    |  17 +++----
 src/IGraph/Types.hs      |   2 +-
 stack.yaml               |   4 +-
 tests/Test/Attributes.hs |   6 ++-
 tests/Test/Basic.hs      |  14 +++--
 7 files changed, 119 insertions(+), 41 deletions(-)

diff --git a/haskell-igraph.cabal b/haskell-igraph.cabal
index eeb6288..5cdb4c3 100644
--- a/haskell-igraph.cabal
+++ b/haskell-igraph.cabal
@@ -7,7 +7,7 @@ license:             MIT
 license-file:        LICENSE
 author:              Kai Zhang
 maintainer:          kai@kzhang.org
-copyright:           (c) 2016-2017 Kai Zhang
+copyright:           (c) 2016-2018 Kai Zhang
 category:            Math
 build-type:          Simple
 cabal-version:       >=1.24
@@ -56,11 +56,12 @@ library
     build-depends: diagrams-lib, diagrams-cairo
 
   build-depends:
-      base >=4.0 && <5.0
-    , bytestring >=0.9
-    , bytestring-lexing >=0.5
+      base >= 4.0 && < 5.0
+    , bytestring >= 0.9
+    , bytestring-lexing >= 0.5
     , cereal
     , colour
+    , conduit >= 1.3.0
     , primitive
     , unordered-containers
     , hashable
@@ -95,6 +96,7 @@ test-suite tests
       base
     , haskell-igraph
     , cereal
+    , conduit >= 1.3.0
     , data-ordlist
     , matrices
     , tasty
diff --git a/src/IGraph.hs b/src/IGraph.hs
index 6705c5b..c435491 100644
--- a/src/IGraph.hs
+++ b/src/IGraph.hs
@@ -5,9 +5,12 @@ module IGraph
     , U(..)
     , D(..)
     , Graph(..)
+--    , encodeC
+ --   , decodeC
     , empty
     , mkGraph
     , fromLabeledEdges
+    , fromLabeledEdges'
 
     , unsafeFreeze
     , freeze
@@ -27,14 +30,18 @@ module IGraph
     , emap
     ) where
 
+import           Conduit
 import           Control.Arrow             ((***))
-import           Control.Monad             (forM, forM_, liftM)
+import           Control.Monad             (forM, forM_, liftM, unless, replicateM)
 import           Control.Monad.Primitive
 import           Control.Monad.ST          (runST)
+import qualified Data.ByteString           as B
 import           Data.Hashable             (Hashable)
 import qualified Data.HashMap.Strict       as M
 import qualified Data.HashSet              as S
+import           Data.List                 (sortBy)
 import           Data.Maybe
+import           Data.Ord                  (comparing)
 import           Data.Serialize
 import           Foreign                   (with)
 import           System.IO.Unsafe          (unsafePerformIO)
@@ -105,12 +112,15 @@ class MGraph d => Graph d where
             else Nothing
     {-# INLINE edgeLabMaybe #-}
 
+    getEdgeByEid :: LGraph d v e -> Int -> Edge
+    getEdgeByEid gr@(LGraph g _) i = unsafePerformIO $ igraphEdge g i
+    {-# INLINE getEdgeByEid #-}
+
     edgeLabByEid :: Serialize e => LGraph d v e -> Int -> e
     edgeLabByEid (LGraph g _) i = unsafePerformIO $
         igraphHaskellAttributeEAS g edgeAttr i >>= fromBS
     {-# INLINE edgeLabByEid #-}
 
-
 instance Graph U where
     isDirected = const False
     isD = const False
@@ -119,31 +129,48 @@ instance Graph D where
     isDirected = const True
     isD = const True
 
-instance (Graph d, Serialize v, Serialize e, Hashable v, Eq v) => Serialize (LGraph d v e) where
-    put gr = do
-        put nlabs
-        put es
-        put elabs
-      where
-        nlabs = map (nodeLab gr) $ nodes gr
-        es = edges gr
-        elabs = map (edgeLab gr) es
-    get = do
-        nlabs <- get
-        es <- get
-        elabs <- get
-        return $ mkGraph nlabs $ zip es elabs
+instance (Graph d, Serialize v, Serialize e, Hashable v, Eq v)
+    => Serialize (LGraph d v e) where
+        put gr = do
+            put $ nNodes gr
+            go (nodeLab gr) (nNodes gr) 0
+            put $ nEdges gr
+            go (\i -> (getEdgeByEid gr i, edgeLabByEid gr i)) (nEdges gr) 0
+          where
+            go f n i | i >= n = return ()
+                     | otherwise = put (f i) >> go f n (i+1)
+        get = do
+            nn <- get
+            nds <- replicateM nn get
+            ne <- get
+            es <- replicateM ne get
+            return $ mkGraph nds es
+
+        {-
+encodeC :: (Monad m, Graph d, Serialize v, Serialize e, Hashable v, Eq v)
+        => LGraph d v e -> ConduitT i B.ByteString m ()
+encodeC gr = do
+    sourcePut $ put (M.toList $ _labelToNode gr)
+    yieldMany (edges gr) .| mapC (\e -> (e, edgeLab gr e)) .| conduitPut put
+
+decodeC :: ( PrimMonad m, MonadThrow m, Graph d
+           , Serialize v, Serialize e, Hashable v, Eq v )
+        => ConduitT B.ByteString o m (LGraph d v e)
+decodeC = do
+    labelToId <- M.fromList <$> sinkGet get
+    conduitGet2 get .| deserializeGraphFromEdges 10000 labelToId
+    -}
 
 empty :: (Graph d, Hashable v, Serialize v, Eq v, Serialize e)
       => LGraph d v e
 empty = runST $ new 0 >>= unsafeFreeze
 
 mkGraph :: (Graph d, Hashable v, Serialize v, Eq v, Serialize e)
-        => [v] -> [(Edge, e)] -> LGraph d v e
+        => [v] -> [LEdge e] -> LGraph d v e
 mkGraph vattr es = runST $ do
     g <- new 0
-    addLNodes n vattr g
-    addLEdges (map (\((fr,to),x) -> (fr,to,x)) es) g
+    addLNodes vattr g
+    addLEdges es g
     unsafeFreeze g
   where
     n = length vattr
@@ -157,6 +184,48 @@ fromLabeledEdges es = mkGraph labels es'
     labels = S.toList $ S.fromList $ concat [ [a,b] | ((a,b),_) <- es ]
     labelToId = M.fromList $ zip labels [0..]
 
+-- | Deserialize a graph.
+fromLabeledEdges' :: (PrimMonad m, Graph d, Hashable v, Serialize v, Eq v, Serialize e)
+                  => Int -- ^ buffer size
+                  -> a    -- ^ Input, usually a file
+                  -> (a -> ConduitT () ((v, v), e) m ())  -- ^ deserialize the input into a stream of edges
+                  -> m (LGraph d v e)
+fromLabeledEdges' bufferN input mkConduit = do
+    (labelToId, _) <- runConduit $ mkConduit input .| foldlC f (M.empty, 0::Int)
+    let getId x = M.lookupDefault undefined x labelToId
+    runConduit $ mkConduit input .|
+        mapC (\((v1, v2), e) -> ((getId v1, getId v2), e)) .|
+        deserializeGraph bufferN
+            (fst $ unzip $ sortBy (comparing snd) $ M.toList labelToId)
+  where
+    f acc ((v1, v2), _) = add v1 $ add v2 acc
+      where
+        add v (m, i) = if v `M.member` m
+            then (m, i)
+            else (M.insert v i m, i + 1)
+
+deserializeGraph :: ( PrimMonad m, Graph d, Hashable v, Serialize v
+                    , Eq v, Serialize e )
+                 => Int -- ^ buffer size
+                 -> [v]
+                 -> ConduitT (LEdge e) o m (LGraph d v e)
+deserializeGraph bufferN nds = mkChunks bufferN .| buildGraph
+  where
+    buildGraph = do
+        gr <- new 0
+        addLNodes nds gr
+        mapM_C (\es -> addLEdges es gr)
+        unsafeFreeze gr
+    mkChunks n = do
+        isEmpty <- nullC
+        unless isEmpty $ do
+            go 0 >>= yield
+            mkChunks n
+      where
+        go i | i >= n = return []
+             | otherwise = await >>= maybe (return []) (\x -> fmap (x :) $ go (i+1))
+{-# INLINE deserializeGraph #-}
+
 unsafeFreeze :: (Hashable v, Eq v, Serialize v, PrimMonad m)
              => MLGraph (PrimState m) d v e -> m (LGraph d v e)
 unsafeFreeze (MLGraph g) = unsafePrimToPrim $ do
diff --git a/src/IGraph/Mutable.hs b/src/IGraph/Mutable.hs
index 09a8532..15b5c44 100644
--- a/src/IGraph/Mutable.hs
+++ b/src/IGraph/Mutable.hs
@@ -43,15 +43,14 @@ class MGraph d where
     addNodes n (MLGraph g) = unsafePrimToPrim $ igraphAddVertices g n nullPtr
 
     addLNodes :: (Serialize v, PrimMonad m)
-              => Int  -- ^ the number of new vertices add to the graph
-              -> [v]  -- ^ vertices' labels
+              => [v]  -- ^ vertices' labels
               -> MLGraph (PrimState m) d v e -> m ()
-    addLNodes n labels (MLGraph g)
-        | n /= length labels = error "addLVertices: incorrect number of labels"
-        | otherwise = unsafePrimToPrim $ withVertexAttr $ \vattr ->
-            asBSVector labels $ \bsvec -> with (mkStrRec vattr bsvec) $ \ptr -> do
-                vptr <- fromPtrs [castPtr ptr]
-                withVectorPtr vptr (igraphAddVertices g n . castPtr)
+    addLNodes labels (MLGraph g) = unsafePrimToPrim $ withVertexAttr $
+        \vattr -> asBSVector labels $ \bsvec -> with (mkStrRec vattr bsvec) $
+            \ptr -> do vptr <- fromPtrs [castPtr ptr]
+                       withVectorPtr vptr (igraphAddVertices g n . castPtr)
+      where
+        n = length labels
 
     delNodes :: PrimMonad m => [Int] -> MLGraph (PrimState m) d v e -> m ()
     delNodes ns (MLGraph g) = unsafePrimToPrim $ do
@@ -74,7 +73,7 @@ class MGraph d where
             vptr <- fromPtrs [castPtr ptr]
             withVectorPtr vptr (igraphAddEdges g vec . castPtr)
       where
-        (xs, vs) = unzip $ map ( \(a,b,v) -> ([fromIntegral a, fromIntegral b], v) ) es
+        (xs, vs) = unzip $ map ( \((a,b),v) -> ([fromIntegral a, fromIntegral b], v) ) es
 
     delEdges :: PrimMonad m => [(Int, Int)] -> MLGraph (PrimState m) d v e -> m ()
 
diff --git a/src/IGraph/Types.hs b/src/IGraph/Types.hs
index 24d0147..8092dbb 100644
--- a/src/IGraph/Types.hs
+++ b/src/IGraph/Types.hs
@@ -7,7 +7,7 @@ import           IGraph.Internal.Graph
 
 type Node = Int
 type Edge = (Node, Node)
-type LEdge a = (Int, Int, a)
+type LEdge a = (Edge, a)
 
 data U = U
 data D = D
diff --git a/stack.yaml b/stack.yaml
index e0e5b65..93eba36 100644
--- a/stack.yaml
+++ b/stack.yaml
@@ -1,10 +1,10 @@
 flags:
   haskell-igraph:
-    graphics: true
+    graphics: false
 
 packages:
     - '.'
 
 extra-deps: []
 
-resolver: lts-10.10
+resolver: nightly-2018-04-19
diff --git a/tests/Test/Attributes.hs b/tests/Test/Attributes.hs
index d56febb..289314a 100644
--- a/tests/Test/Attributes.hs
+++ b/tests/Test/Attributes.hs
@@ -2,6 +2,7 @@ module Test.Attributes
     ( tests
     ) where
 
+import           Conduit
 import           Control.Monad
 import           Control.Monad.ST
 import           Data.List
@@ -53,4 +54,7 @@ serializeTest = testCase "serialize test" $ do
             Left msg -> error msg
             Right r  -> r
         es' = map (\(a,b) -> ((nodeLab gr' a, nodeLab gr' b), edgeLab gr' (a,b))) $ edges gr'
-    assertBool "" $ sort (map show es) == sort (map show es')
+    gr'' <- runConduit $ encodeC gr .| decodeC :: IO (LGraph D NodeAttr EdgeAttr)
+    let es'' = map (\(a,b) -> ((nodeLab gr'' a, nodeLab gr'' b), edgeLab gr'' (a,b))) $ edges gr''
+    assertBool "" $ sort (map show es) == sort (map show es') &&
+        sort (map show es) == sort (map show es'')
diff --git a/tests/Test/Basic.hs b/tests/Test/Basic.hs
index 8bb82d7..bde87cf 100644
--- a/tests/Test/Basic.hs
+++ b/tests/Test/Basic.hs
@@ -10,6 +10,7 @@ import           System.IO.Unsafe
 import           Test.Tasty
 import           Test.Tasty.HUnit
 import           Test.Utils
+import Conduit
 
 import           IGraph
 import           IGraph.Mutable
@@ -39,14 +40,17 @@ graphCreationLabeled :: TestTree
 graphCreationLabeled = testGroup "Graph creation -- with labels"
     [ testCase "" $ assertBool "" $ nNodes gr == n && nEdges gr == m
     , testCase "" $ edgeList @=? (sort $ map (\(fr,to) ->
-        (nodeLab gr fr, nodeLab gr to)) $ edges gr)
+        ((nodeLab gr fr, nodeLab gr to), edgeLab gr (fr, to))) $ edges gr)
+    , testCase "" $ edgeList @=? (sort $ map (\(fr,to) ->
+       ((nodeLab gr' fr, nodeLab gr' to), edgeLab gr' (fr, to))) $ edges gr')
     ]
   where
-    edgeList = sort $ map (\(a,b) -> (show a, show b)) $ unsafePerformIO $
-        randEdges 10000 1000
-    n = length $ nubSort $ concatMap (\(a,b) -> [a,b]) edgeList
+    edgeList = zip (sort $ map (\(a,b) -> (show a, show b)) $ unsafePerformIO $
+        randEdges 10000 1000) $ repeat 1
+    n = length $ nubSort $ concatMap (\((a,b),_) -> [a,b]) edgeList
     m = length edgeList
-    gr = fromLabeledEdges $ zip edgeList $ repeat () :: LGraph D String ()
+    gr = fromLabeledEdges edgeList :: LGraph D String Int
+    gr' = runST $ fromLabeledEdges' 10 edgeList yieldMany :: LGraph D String Int
 
 graphEdit :: TestTree
 graphEdit = testGroup "Graph editing"
-- 
2.21.0