From af8d8c8c59e2b4d7949c51ccb4411122028ba43a Mon Sep 17 00:00:00 2001 From: Kai Zhang <kai@kzhang.org> Date: Thu, 19 Apr 2018 14:33:52 -0700 Subject: [PATCH] rewrite cereal instances --- haskell-igraph.cabal | 10 ++-- src/IGraph.hs | 107 ++++++++++++++++++++++++++++++++------- src/IGraph/Mutable.hs | 17 +++---- src/IGraph/Types.hs | 2 +- stack.yaml | 4 +- tests/Test/Attributes.hs | 6 ++- tests/Test/Basic.hs | 14 +++-- 7 files changed, 119 insertions(+), 41 deletions(-) diff --git a/haskell-igraph.cabal b/haskell-igraph.cabal index eeb6288..5cdb4c3 100644 --- a/haskell-igraph.cabal +++ b/haskell-igraph.cabal @@ -7,7 +7,7 @@ license: MIT license-file: LICENSE author: Kai Zhang maintainer: kai@kzhang.org -copyright: (c) 2016-2017 Kai Zhang +copyright: (c) 2016-2018 Kai Zhang category: Math build-type: Simple cabal-version: >=1.24 @@ -56,11 +56,12 @@ library build-depends: diagrams-lib, diagrams-cairo build-depends: - base >=4.0 && <5.0 - , bytestring >=0.9 - , bytestring-lexing >=0.5 + base >= 4.0 && < 5.0 + , bytestring >= 0.9 + , bytestring-lexing >= 0.5 , cereal , colour + , conduit >= 1.3.0 , primitive , unordered-containers , hashable @@ -95,6 +96,7 @@ test-suite tests base , haskell-igraph , cereal + , conduit >= 1.3.0 , data-ordlist , matrices , tasty diff --git a/src/IGraph.hs b/src/IGraph.hs index 6705c5b..c435491 100644 --- a/src/IGraph.hs +++ b/src/IGraph.hs @@ -5,9 +5,12 @@ module IGraph , U(..) , D(..) , Graph(..) 
+    -- , encodeC
+    -- , decodeC
     , empty
     , mkGraph
     , fromLabeledEdges
+    , fromLabeledEdges'
 
     , unsafeFreeze
     , freeze
@@ -27,14 +30,18 @@ module IGraph
     , emap
     ) where
 
+import Conduit
 import Control.Arrow ((***))
-import Control.Monad (forM, forM_, liftM)
+import Control.Monad (forM, forM_, liftM, unless, replicateM)
 import Control.Monad.Primitive
 import Control.Monad.ST (runST)
+import qualified Data.ByteString as B
 import Data.Hashable (Hashable)
 import qualified Data.HashMap.Strict as M
 import qualified Data.HashSet as S
+import Data.List (sortBy)
 import Data.Maybe
+import Data.Ord (comparing)
 import Data.Serialize
 import Foreign (with)
 import System.IO.Unsafe (unsafePerformIO)
@@ -105,12 +112,15 @@ class MGraph d => Graph d where
             else Nothing
     {-# INLINE edgeLabMaybe #-}
 
+    getEdgeByEid :: LGraph d v e -> Int -> Edge
+    getEdgeByEid (LGraph g _) i = unsafePerformIO $ igraphEdge g i
+    {-# INLINE getEdgeByEid #-}
+
     edgeLabByEid :: Serialize e => LGraph d v e -> Int -> e
     edgeLabByEid (LGraph g _) i = unsafePerformIO $
         igraphHaskellAttributeEAS g edgeAttr i >>= fromBS
     {-# INLINE edgeLabByEid #-}
 
-
 instance Graph U where
     isDirected = const False
     isD = const False
@@ -119,31 +129,48 @@ instance Graph D where
     isDirected = const True
     isD = const True
 
-instance (Graph d, Serialize v, Serialize e, Hashable v, Eq v) => Serialize (LGraph d v e) where
-    put gr = do
-        put nlabs
-        put es
-        put elabs
-      where
-        nlabs = map (nodeLab gr) $ nodes gr
-        es = edges gr
-        elabs = map (edgeLab gr) es
-    get = do
-        nlabs <- get
-        es <- get
-        elabs <- get
-        return $ mkGraph nlabs $ zip es elabs
+instance (Graph d, Serialize v, Serialize e, Hashable v, Eq v)
+    => Serialize (LGraph d v e) where
+    put gr = do
+        put $ nNodes gr
+        go (nodeLab gr) (nNodes gr) 0
+        put $ nEdges gr
+        go (\i -> (getEdgeByEid gr i, edgeLabByEid gr i)) (nEdges gr) 0
+      where
+        go f n i | i >= n = return ()
+                 | otherwise = put (f i) >> go f n (i+1)
+    get = do
+        nn <- get
+        nds <- replicateM nn get
+        ne <- get
+        es <- replicateM ne get
+        return $ mkGraph nds es
+
+    {-
+encodeC :: (Monad m, Graph d, Serialize v, Serialize e, Hashable v, Eq v)
+        => LGraph d v e -> ConduitT i B.ByteString m ()
+encodeC gr = do
+    sourcePut $ put (M.toList $ _labelToNode gr)
+    yieldMany (edges gr) .| mapC (\e -> (e, edgeLab gr e)) .| conduitPut put
+
+decodeC :: ( PrimMonad m, MonadThrow m, Graph d
+           , Serialize v, Serialize e, Hashable v, Eq v )
+        => ConduitT B.ByteString o m (LGraph d v e)
+decodeC = do
+    labelToId <- M.fromList <$> sinkGet get
+    conduitGet2 get .| deserializeGraphFromEdges 10000 labelToId
+    -}
 
 empty :: (Graph d, Hashable v, Serialize v, Eq v, Serialize e)
       => LGraph d v e
 empty = runST $ new 0 >>= unsafeFreeze
 
 mkGraph :: (Graph d, Hashable v, Serialize v, Eq v, Serialize e)
-        => [v] -> [(Edge, e)] -> LGraph d v e
+        => [v] -> [LEdge e] -> LGraph d v e
 mkGraph vattr es = runST $ do
     g <- new 0
-    addLNodes n vattr g
-    addLEdges (map (\((fr,to),x) -> (fr,to,x)) es) g
+    addLNodes vattr g
+    addLEdges es g
     unsafeFreeze g
   where
     n = length vattr
@@ -157,6 +184,48 @@ fromLabeledEdges es = mkGraph labels es'
     labels = S.toList $ S.fromList $ concat [ [a,b] | ((a,b),_) <- es ]
     labelToId = M.fromList $ zip labels [0..]
 
+-- | Build a graph from labeled edges streamed from an input (the stream is run twice).
+fromLabeledEdges' :: (PrimMonad m, Graph d, Hashable v, Serialize v, Eq v, Serialize e)
+                  => Int -- ^ number of edges added to the graph per batch
+                  -> a -- ^ Input, usually a file
+                  -> (a -> ConduitT () ((v, v), e) m ()) -- ^ turn the input into a stream of labeled edges; it is run twice
+                  -> m (LGraph d v e)
+fromLabeledEdges' bufferN input mkConduit = do
+    (labelToId, _) <- runConduit $ mkConduit input .| foldlC f (M.empty, 0::Int)
+    let getId x = M.lookupDefault (error "fromLabeledEdges': label missing from the first pass") x labelToId
+    runConduit $ mkConduit input .|
+        mapC (\((v1, v2), e) -> ((getId v1, getId v2), e)) .|
+        deserializeGraph bufferN
+        (map fst $ sortBy (comparing snd) $ M.toList labelToId)
+  where
+    f acc ((v1, v2), _) = add v1 $ add v2 acc
+      where
+        add v (m, i) = if v `M.member` m
+            then (m, i)
+            else (M.insert v i m, i + 1)
+
+deserializeGraph :: ( PrimMonad m, Graph d, Hashable v, Serialize v
+                    , Eq v, Serialize e )
+                 => Int -- ^ edges per batch; values below 1 are treated as 1
+                 -> [v] -- ^ labels of the nodes created before any edge is added
+                 -> ConduitT (LEdge e) o m (LGraph d v e)
+deserializeGraph bufferN nds = mkChunks (max 1 bufferN) .| buildGraph
+  where
+    buildGraph = do
+        gr <- new 0
+        addLNodes nds gr
+        mapM_C (\es -> addLEdges es gr)
+        unsafeFreeze gr
+    mkChunks n = do
+        isEmpty <- nullC
+        unless isEmpty $ do
+            go 0 >>= yield
+            mkChunks n
+      where
+        go i | i >= n = return []
+             | otherwise = await >>= maybe (return []) (\x -> (x :) <$> go (i + 1))
+{-# INLINE deserializeGraph #-}
+
 unsafeFreeze :: (Hashable v, Eq v, Serialize v, PrimMonad m)
              => MLGraph (PrimState m) d v e -> m (LGraph d v e)
 unsafeFreeze (MLGraph g) = unsafePrimToPrim $ do
diff --git a/src/IGraph/Mutable.hs b/src/IGraph/Mutable.hs
index 09a8532..15b5c44 100644
--- a/src/IGraph/Mutable.hs
+++ b/src/IGraph/Mutable.hs
@@ -43,15 +43,14 @@ class MGraph d where
     addNodes n (MLGraph g) = unsafePrimToPrim $ igraphAddVertices g n nullPtr
 
     addLNodes :: (Serialize v, PrimMonad m)
-              => Int -- ^ the number of new vertices add to the graph
-              -> [v] -- ^ vertices' labels
+              => [v] -- ^ vertices' labels
               -> MLGraph (PrimState m) d v e -> m ()
-    addLNodes n labels (MLGraph g)
-        | n /= length labels = error "addLVertices: incorrect number of labels"
-        | otherwise = unsafePrimToPrim $ withVertexAttr $ \vattr ->
-            asBSVector labels $ \bsvec -> with (mkStrRec vattr bsvec) $ \ptr -> do
-                vptr <- fromPtrs [castPtr ptr]
-                withVectorPtr vptr (igraphAddVertices g n . castPtr)
+    addLNodes labels (MLGraph g) = unsafePrimToPrim $ withVertexAttr $
+        \vattr -> asBSVector labels $ \bsvec -> with (mkStrRec vattr bsvec) $
+            \ptr -> do vptr <- fromPtrs [castPtr ptr]
+                       withVectorPtr vptr (igraphAddVertices g n . castPtr)
+      where
+        n = length labels
 
     delNodes :: PrimMonad m => [Int] -> MLGraph (PrimState m) d v e -> m ()
     delNodes ns (MLGraph g) = unsafePrimToPrim $ do
@@ -74,7 +73,7 @@ class MGraph d where
                 vptr <- fromPtrs [castPtr ptr]
                 withVectorPtr vptr (igraphAddEdges g vec . castPtr)
       where
-        (xs, vs) = unzip $ map ( \(a,b,v) -> ([fromIntegral a, fromIntegral b], v) ) es
+        (xs, vs) = unzip $ map ( \((a,b),v) -> ([fromIntegral a, fromIntegral b], v) ) es
 
     delEdges :: PrimMonad m
              => [(Int, Int)] -> MLGraph (PrimState m) d v e -> m ()
diff --git a/src/IGraph/Types.hs b/src/IGraph/Types.hs
index 24d0147..8092dbb 100644
--- a/src/IGraph/Types.hs
+++ b/src/IGraph/Types.hs
@@ -7,7 +7,7 @@ import IGraph.Internal.Graph
 type Node = Int
 type Edge = (Node, Node)
 
-type LEdge a = (Int, Int, a)
+type LEdge a = (Edge, a)
 
 data U = U
 data D = D
diff --git a/stack.yaml b/stack.yaml
index e0e5b65..93eba36 100644
--- a/stack.yaml
+++ b/stack.yaml
@@ -1,10 +1,10 @@
 flags:
   haskell-igraph:
-    graphics: true
+    graphics: false
 
 packages:
 - '.'
 
 extra-deps: []
 
-resolver: lts-10.10
+resolver: nightly-2018-04-19
diff --git a/tests/Test/Attributes.hs b/tests/Test/Attributes.hs
index d56febb..289314a 100644
--- a/tests/Test/Attributes.hs
+++ b/tests/Test/Attributes.hs
@@ -2,6 +2,7 @@ module Test.Attributes
     ( tests
     ) where
 
+import Conduit
 import Control.Monad
 import Control.Monad.ST
 import Data.List
@@ -53,4 +54,7 @@ serializeTest = testCase "serialize test" $ do
             Left msg -> error msg
             Right r -> r
         es' = map (\(a,b) -> ((nodeLab gr' a, nodeLab gr' b), edgeLab gr' (a,b))) $ edges gr'
+    -- FIXME: the conduit round-trip check is disabled because 'encodeC' and
+    -- 'decodeC' are commented out in IGraph and not exported. Re-enable with them:
+    -- gr'' <- runConduit $ encodeC gr .| decodeC :: IO (LGraph D NodeAttr EdgeAttr)
     assertBool "" $ sort (map show es) == sort (map show es')
diff --git a/tests/Test/Basic.hs b/tests/Test/Basic.hs
index 8bb82d7..bde87cf 100644
--- a/tests/Test/Basic.hs
+++ b/tests/Test/Basic.hs
@@ -10,6 +10,7 @@ import System.IO.Unsafe
 import Test.Tasty
 import Test.Tasty.HUnit
 import Test.Utils
+import Conduit
 
 import IGraph
 import IGraph.Mutable
@@ -39,14 +40,17 @@ graphCreationLabeled :: TestTree
 graphCreationLabeled = testGroup "Graph creation -- with labels"
     [ testCase "" $ assertBool "" $ nNodes gr == n && nEdges gr == m
     , testCase "" $ edgeList @=? (sort $ map (\(fr,to) ->
-        (nodeLab gr fr, nodeLab gr to)) $ edges gr)
+        ((nodeLab gr fr, nodeLab gr to), edgeLab gr (fr, to))) $ edges gr)
+    , testCase "" $ edgeList @=? (sort $ map (\(fr,to) ->
+        ((nodeLab gr' fr, nodeLab gr' to), edgeLab gr' (fr, to))) $ edges gr')
     ]
   where
-    edgeList = sort $ map (\(a,b) -> (show a, show b)) $ unsafePerformIO $
-        randEdges 10000 1000
-    n = length $ nubSort $ concatMap (\(a,b) -> [a,b]) edgeList
+    edgeList = zip (sort $ map (\(a,b) -> (show a, show b)) $ unsafePerformIO $
+        randEdges 10000 1000) $ repeat 1
+    n = length $ nubSort $ concatMap (\((a,b),_) -> [a,b]) edgeList
     m = length edgeList
-    gr = fromLabeledEdges $ zip edgeList $ repeat () :: LGraph D String ()
+    gr = fromLabeledEdges edgeList :: LGraph D String Int
+    gr' = runST $ fromLabeledEdges' 10 edgeList yieldMany :: LGraph D String Int
 
 graphEdit :: TestTree
 graphEdit = testGroup "Graph editing"
-- 
2.21.0