Commit 5ee7761f authored by Alfredo Di Napoli's avatar Alfredo Di Napoli Committed by Alfredo Di Napoli

Improve the performance of sumRows in massiv implementation

parent e2d59228
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
{-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE BangPatterns #-} {-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Main where module Main where
...@@ -15,15 +16,16 @@ import Gargantext.Core.Viz.Phylo.PhyloTools ...@@ -15,15 +16,16 @@ import Gargantext.Core.Viz.Phylo.PhyloTools
import Gargantext.Prelude.Crypto.Auth (createPasswordHash) import Gargantext.Prelude.Crypto.Auth (createPasswordHash)
import Paths_gargantext import Paths_gargantext
import qualified Data.Array.Accelerate as A import qualified Data.Array.Accelerate as A
import qualified Data.Array.Accelerate as Accelerate
import qualified Data.Array.Accelerate.Interpreter as LLVM
import qualified Data.Array.Accelerate.Interpreter as Naive import qualified Data.Array.Accelerate.Interpreter as Naive
import qualified Data.List.Split as Split import qualified Data.List.Split as Split
import qualified Data.Massiv.Array as Massiv import qualified Data.Massiv.Array as Massiv
import qualified Gargantext.Core.LinearAlgebra as LA import qualified Gargantext.Core.LinearAlgebra as LA
import qualified Gargantext.Core.Methods.Similarities.Accelerate.Distributional as Accelerate
import qualified Gargantext.Core.Methods.Matrix.Accelerate.Utils as Accelerate import qualified Gargantext.Core.Methods.Matrix.Accelerate.Utils as Accelerate
import qualified Gargantext.Core.Methods.Similarities.Accelerate.Distributional as Accelerate
import Test.Tasty.Bench import Test.Tasty.Bench
import qualified Data.Array.Accelerate.Interpreter as LLVM import Data.Array.Accelerate ((:.))
import qualified Data.Array.Accelerate as Accelerate
phyloConfig :: PhyloConfig phyloConfig :: PhyloConfig
phyloConfig = PhyloConfig { phyloConfig = PhyloConfig {
...@@ -59,16 +61,26 @@ testMatrix :: A.Matrix Int ...@@ -59,16 +61,26 @@ testMatrix :: A.Matrix Int
testMatrix = A.fromList (A.Z A.:. matrixDim A.:. matrixDim) $ matrixValues testMatrix = A.fromList (A.Z A.:. matrixDim A.:. matrixDim) $ matrixValues
{-# INLINE testMatrix #-} {-# INLINE testMatrix #-}
testVector :: A.Array (A.Z :. Int :. Int :. Int) Int
testVector = A.fromList (A.Z A.:. 20 A.:. 20 A.:. 20) $ matrixValues
{-# INLINE testVector #-}
testMassivMatrix :: Massiv.Matrix Massiv.U Int testMassivMatrix :: Massiv.Matrix Massiv.U Int
testMassivMatrix = Massiv.fromLists' Massiv.Par $ Split.chunksOf matrixDim $ matrixValues testMassivMatrix = Massiv.fromLists' Massiv.Par $ Split.chunksOf matrixDim $ matrixValues
{-# INLINE testMassivMatrix #-} {-# INLINE testMassivMatrix #-}
testMassivVector :: Massiv.Array Massiv.U Massiv.Ix3 Int
testMassivVector = LA.accelerate2Massiv3DMatrix testVector
{-# INLINE testMassivVector #-}
main :: IO () main :: IO ()
main = do main = do
_issue290Phylo <- force . setConfig phyloConfig <$> (readPhylo =<< getDataFileName "bench-data/phylo/issue-290.json") _issue290Phylo <- force . setConfig phyloConfig <$> (readPhylo =<< getDataFileName "bench-data/phylo/issue-290.json")
issue290PhyloSmall <- force . setConfig phyloConfig <$> (readPhylo =<< getDataFileName "bench-data/phylo/issue-290-small.json") issue290PhyloSmall <- force . setConfig phyloConfig <$> (readPhylo =<< getDataFileName "bench-data/phylo/issue-290-small.json")
let !accInput = force testMatrix let !accInput = force testMatrix
let !accVector = force testVector
let !massivInput = force testMassivMatrix let !massivInput = force testMassivMatrix
let !massivVector = force testMassivVector
let !(accDoubleInput :: Accelerate.Matrix Double) = force $ Naive.run $ Accelerate.map Accelerate.fromIntegral (Accelerate.use testMatrix) let !(accDoubleInput :: Accelerate.Matrix Double) = force $ Naive.run $ Accelerate.map Accelerate.fromIntegral (Accelerate.use testMatrix)
let !massivInput = force testMassivMatrix let !massivInput = force testMassivMatrix
let !(massivDoubleInput :: Massiv.Matrix Massiv.U Double) = force $ Massiv.computeP $ Massiv.map fromIntegral testMassivMatrix let !(massivDoubleInput :: Massiv.Matrix Massiv.U Double) = force $ Massiv.computeP $ Massiv.map fromIntegral testMassivMatrix
...@@ -87,6 +99,16 @@ main = do ...@@ -87,6 +99,16 @@ main = do
, bench "Accelerate (LLVM)" $ nf (LLVM.run . Accelerate.diag . Accelerate.use) accInput , bench "Accelerate (LLVM)" $ nf (LLVM.run . Accelerate.diag . Accelerate.use) accInput
, bench "Massiv " $ nf (LA.diag @_) massivInput , bench "Massiv " $ nf (LA.diag @_) massivInput
] ]
, bgroup "(.*)" [
bench "Accelerate (Naive)" $ nf (\v -> Naive.run $ (Accelerate.use v) Accelerate..* (Accelerate.use v)) accDoubleInput
, bench "Accelerate (LLVM)" $ nf (\v -> LLVM.run $ (Accelerate.use v) Accelerate..* (Accelerate.use v)) accDoubleInput
, bench "Massiv " $ nf (\v -> (v LA..* v) :: Massiv.Matrix Massiv.U Double) massivDoubleInput
]
, bgroup "sumRows" [
bench "Accelerate (Naive)" $ nf (Naive.run . Accelerate.sum . Accelerate.use) accVector
, bench "Accelerate (LLVM)" $ nf (LLVM.run . Accelerate.sum . Accelerate.use) accVector
, bench "Massiv " $ nf LA.sumRows massivVector
]
, bgroup "termDivNan" [ , bgroup "termDivNan" [
bench "Accelerate (Naive)" $ bench "Accelerate (Naive)" $
nf (\m -> Naive.run $ Accelerate.termDivNan (Accelerate.use m) (Accelerate.use m)) accDoubleInput nf (\m -> Naive.run $ Accelerate.termDivNan (Accelerate.use m) (Accelerate.use m)) accDoubleInput
...@@ -97,7 +119,7 @@ main = do ...@@ -97,7 +119,7 @@ main = do
, bgroup "distributional" [ , bgroup "distributional" [
bench "Accelerate (Naive)" $ nf (Accelerate.distributionalWith @Double Naive.run) accInput bench "Accelerate (Naive)" $ nf (Accelerate.distributionalWith @Double Naive.run) accInput
, bench "Accelerate (LLVM)" $ nf Accelerate.distributional accInput , bench "Accelerate (LLVM)" $ nf Accelerate.distributional accInput
, bench "Massiv " $ nf (Massiv.computeP @Massiv.U . LA.distributional @_ @Double) massivInput , bench "Massiv " $ nf (LA.distributional @_ @Double) massivInput
] ]
] ]
] ]
{-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-| {-|
Module : Gargantext.Core.LinearAlgebra Module : Gargantext.Core.LinearAlgebra
Description : Linear Algebra utility functions Description : Linear Algebra utility functions
...@@ -22,6 +23,7 @@ module Gargantext.Core.LinearAlgebra ( ...@@ -22,6 +23,7 @@ module Gargantext.Core.LinearAlgebra (
-- * Convertion functions -- * Convertion functions
, accelerate2MassivMatrix , accelerate2MassivMatrix
, accelerate2Massiv3DMatrix
, massiv2AccelerateMatrix , massiv2AccelerateMatrix
, massiv2AccelerateVector , massiv2AccelerateVector
...@@ -30,6 +32,10 @@ module Gargantext.Core.LinearAlgebra ( ...@@ -30,6 +32,10 @@ module Gargantext.Core.LinearAlgebra (
, diag , diag
, termDivNan , termDivNan
, distributional , distributional
, sumRows
-- * Internals for testing
, sumRowsReferenceImplementation
) where ) where
import Data.Array.Accelerate qualified as Acc import Data.Array.Accelerate qualified as Acc
...@@ -44,6 +50,7 @@ import Data.Set qualified as S ...@@ -44,6 +50,7 @@ import Data.Set qualified as S
import Data.Set (Set) import Data.Set (Set)
import Prelude import Prelude
import Protolude.Safe (headMay) import Protolude.Safe (headMay)
import Data.Monoid
newtype Index = Index { _Index :: Int } newtype Index = Index { _Index :: Int }
deriving newtype (Eq, Show, Ord, Num, Enum) deriving newtype (Eq, Show, Ord, Num, Enum)
...@@ -78,6 +85,14 @@ massiv2AccelerateVector m = ...@@ -78,6 +85,14 @@ massiv2AccelerateVector m =
r = Prelude.length m' r = Prelude.length m'
in Acc.fromList (Acc.Z Acc.:. r) m' in Acc.fromList (Acc.Z Acc.:. r) m'
accelerate2Massiv3DMatrix :: (A.Unbox e, Acc.Elt e, A.Manifest r e)
=> Acc.Array (Acc.Z Acc.:. Int Acc.:. Int Acc.:. Int) e
-> A.Array r A.Ix3 e
accelerate2Massiv3DMatrix m =
let (Acc.Z Acc.:. _r Acc.:. _c Acc.:. _z) = Acc.arrayShape m
in A.fromLists' A.Par $ map (Split.chunksOf $ _z) $ Split.chunksOf (_c*_z) (Acc.toList m)
-- | Computes the diagnonal matrix of the input one. -- | Computes the diagnonal matrix of the input one.
diag :: (A.Unbox e, A.Manifest r e, A.Source r e, Num e) => Matrix r e -> Vector A.U e diag :: (A.Unbox e, A.Manifest r e, A.Source r e, Num e) => Matrix r e -> Vector A.U e
...@@ -128,7 +143,7 @@ distributional :: forall r e. (A.Manifest r e, A.Unbox e, A.Source r Int, A.Size ...@@ -128,7 +143,7 @@ distributional :: forall r e. (A.Manifest r e, A.Unbox e, A.Source r Int, A.Size
distributional m' = result distributional m' = result
where where
m :: Matrix A.U e m :: Matrix A.U e
m = A.compute$ A.map fromIntegral m' m = A.computeP $ A.map fromIntegral m'
n = dim m' n = dim m'
diag_m :: Vector A.U e diag_m :: Vector A.U e
...@@ -182,16 +197,28 @@ termDivNan :: (A.Manifest r3 a, A.Source r1 a, A.Source r2 a, Eq a, Fractional a ...@@ -182,16 +197,28 @@ termDivNan :: (A.Manifest r3 a, A.Source r1 a, A.Source r2 a, Eq a, Fractional a
=> Matrix r1 a => Matrix r1 a
-> Matrix r2 a -> Matrix r2 a
-> Matrix r3 a -> Matrix r3 a
termDivNan m1 m2 = A.compute $ A.zipWith (\i j -> if j == 0 then 0 else i / j) m1 m2 termDivNan m1 m2 = A.computeP $ A.zipWith (\i j -> if j == 0 then 0 else i / j) m1 m2
sumRows :: (A.Load r A.Ix2 e sumRows :: ( A.Load r A.Ix2 e
, A.Source r e , A.Source r e
, A.Manifest r e
, A.Strategy r , A.Strategy r
, A.Size r , A.Size r
, Num e , Num e
) => Array r A.Ix3 e ) => Array r A.Ix3 e
-> Array r A.Ix2 e -> Array r A.Ix2 e
sumRows matrix = sumRows matrix =
A.computeP $ A.map getSum $ A.foldlWithin' 1 (\(Sum s) n -> Sum $ s + n) mempty matrix
sumRowsReferenceImplementation :: ( A.Load r A.Ix2 e
, A.Source r e
, A.Manifest r e
, A.Strategy r
, A.Size r
, Num e
) => Array r A.Ix3 e
-> Array r A.Ix2 e
sumRowsReferenceImplementation matrix =
let A.Sz3 rows cols z = A.size matrix let A.Sz3 rows cols z = A.size matrix
in A.makeArray (A.getComp matrix) (A.Sz2 rows cols) $ \(i A.:. j) -> in A.makeArray (A.getComp matrix) (A.Sz2 rows cols) $ \(i A.:. j) ->
A.sum (A.backpermute' (A.Sz1 z) (\c -> i A.:> j A.:. c) matrix) A.sum (A.backpermute' (A.Sz1 z) (\c -> i A.:> j A.:. c) matrix)
...@@ -201,7 +228,7 @@ sumRows matrix = ...@@ -201,7 +228,7 @@ sumRows matrix =
=> Array r1 ix a => Array r1 ix a
-> Array r2 ix a -> Array r2 ix a
-> Array r3 ix a -> Array r3 ix a
(.*) m1 m2 = A.compute $ A.zipWith (*) m1 m2 (.*) m1 m2 = A.computeP $ A.zipWith (*) m1 m2
-- | Get the dimensions of a /square/ matrix. -- | Get the dimensions of a /square/ matrix.
dim :: A.Size r => Matrix r a -> Int dim :: A.Size r => Matrix r a -> Int
......
...@@ -7,10 +7,18 @@ module Gargantext.Orphans.Accelerate where ...@@ -7,10 +7,18 @@ module Gargantext.Orphans.Accelerate where
import Prelude import Prelude
import Test.QuickCheck import Test.QuickCheck
import Data.Scientific () import Data.Scientific ()
import Data.Array.Accelerate (DIM2, Z (..), (:.) (..), Array, Elt, fromList, arrayShape) import Data.Array.Accelerate (DIM2, Z (..), (:.) (..), Array, Elt, fromList, arrayShape, DIM3)
import Data.Array.Accelerate qualified as A import Data.Array.Accelerate qualified as A
import qualified Data.List.Split as Split import qualified Data.List.Split as Split
instance (Show e, Elt e, Arbitrary e, Num e, Ord e) => Arbitrary (Array DIM3 e) where
arbitrary = do
x <- choose (1,10)
y <- choose (1,10)
z <- choose (1,10)
let sh = Z :. x :. y :. z
fromList sh <$> vectorOf (x * y * z) (getPositive <$> arbitrary)
instance (Show e, Elt e, Arbitrary e) => Arbitrary (Array DIM2 e) where instance (Show e, Elt e, Arbitrary e) => Arbitrary (Array DIM2 e) where
arbitrary = do arbitrary = do
x <- choose (1,128) x <- choose (1,128)
......
...@@ -65,7 +65,7 @@ mapCreateIndices (_m1, m2) = Bimap.fromList $ map (first LA.Index) $ M.toList m2 ...@@ -65,7 +65,7 @@ mapCreateIndices (_m1, m2) = Bimap.fromList $ map (first LA.Index) $ M.toList m2
type TermDivNanShape = Z :. Int :. Int type TermDivNanShape = Z :. Int :. Int
twoByTwo :: SquareMatrix Int twoByTwo :: SquareMatrix Int
twoByTwo = SquareMatrix $ fromList (Z :. 2 :. 2) (Prelude.replicate 4 0) twoByTwo = SquareMatrix $ fromList (Z :. 2 :. 2) (Prelude.replicate 4 5)
testMatrix_01 :: SquareMatrix Int testMatrix_01 :: SquareMatrix Int
testMatrix_01 = SquareMatrix $ fromList (Z :. 14 :. 14) $ testMatrix_01 = SquareMatrix $ fromList (Z :. 14 :. 14) $
...@@ -106,6 +106,7 @@ tests = testGroup "LinearAlgebra" [ ...@@ -106,6 +106,7 @@ tests = testGroup "LinearAlgebra" [
testProperty "createIndices roundtrip" (compareImplementations (LA.createIndices @Int @Int) Legacy.createIndices mapCreateIndices) testProperty "createIndices roundtrip" (compareImplementations (LA.createIndices @Int @Int) Legacy.createIndices mapCreateIndices)
, testProperty "termDivNan" compareTermDivNan , testProperty "termDivNan" compareTermDivNan
, testProperty "diag" compareDiag , testProperty "diag" compareDiag
, testProperty "sumRows" compareSumRows
, testGroup "distributional" [ , testGroup "distributional" [
testProperty "2x2" (compareDistributional (Proxy @Double) twoByTwo) testProperty "2x2" (compareDistributional (Proxy @Double) twoByTwo)
, testProperty "7x7" (compareDistributional (Proxy @Double) testMatrix_02) , testProperty "7x7" (compareDistributional (Proxy @Double) testMatrix_02)
...@@ -132,6 +133,14 @@ compareDiag (SquareMatrix i1) ...@@ -132,6 +133,14 @@ compareDiag (SquareMatrix i1)
accelerate = Naive.run (Legacy.diag (use i1)) accelerate = Naive.run (Legacy.diag (use i1))
in accelerate === LA.massiv2AccelerateVector massiv in accelerate === LA.massiv2AccelerateVector massiv
compareSumRows :: Array (Z :. Int :. Int :. Int) Int -> Property
compareSumRows i1
= let massiv = LA.sumRows @Massiv.U (LA.accelerate2Massiv3DMatrix i1)
massiv' = LA.sumRowsReferenceImplementation @Massiv.U (LA.accelerate2Massiv3DMatrix i1)
accelerate = Naive.run (A.sum (use i1))
in counterexample "sumRows and reference implementation do not agree" (massiv === massiv') .&&.
accelerate === LA.massiv2AccelerateMatrix massiv
compareDistributional :: forall e. compareDistributional :: forall e.
( Eq e ( Eq e
, Show e , Show e
...@@ -142,6 +151,7 @@ compareDistributional :: forall e. ...@@ -142,6 +151,7 @@ compareDistributional :: forall e.
, Ord e , Ord e
, Prelude.Fractional (Exp e) , Prelude.Fractional (Exp e)
, Prelude.Fractional e , Prelude.Fractional e
, Monoid e
) => Proxy e ) => Proxy e
-> SquareMatrix Int -> SquareMatrix Int
-> Property -> Property
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment