Commit 4371781e authored by Alfredo Di Napoli's avatar Alfredo Di Napoli Committed by Alfredo Di Napoli

Initial work on the linear algebra tests

parent 96e57e41
......@@ -99,7 +99,7 @@ import Data.Array.Accelerate.LLVM.Native qualified as LLVM -- TODO: try runQ?
import Gargantext.Core.Methods.Matrix.Accelerate.Utils
import qualified Gargantext.Prelude as P
import Debug.Trace
import Debug.Trace (trace)
import Prelude (show, mappend{- , String, (<>), fromIntegral, flip -})
import qualified Prelude
......@@ -138,6 +138,8 @@ import qualified Prelude
-- 8.333333333333333e-2, 4.6875e-2, 1.0, 0.25,
-- 0.3333333333333333, 5.7692307692307696e-2, 1.0, 1.0]
--
-- /IMPORTANT/: As this function computes the diagonal matrix in order to carry on the computation
-- the input has to be a square matrix, or this function will fail at runtime.
distributional :: Matrix Int -> Matrix Double
distributional = distributionalWith LLVM.run
......
......@@ -2,12 +2,14 @@
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE MultiWayIf #-}
module Gargantext.Orphans.Accelerate () where
module Gargantext.Orphans.Accelerate where
import Prelude
import Test.QuickCheck
import Data.Array.Accelerate hiding ((<=))
import qualified Data.Array.Accelerate.Sugar.Shape as A
import Data.Array.Accelerate (DIM0, DIM1, DIM2, DIM3, Z (..), (:.) (..), Array, Elt, fromList, use, arrayShape)
import Data.Array.Accelerate qualified as A
import qualified Data.Array.Accelerate.Sugar.Shape as AS
import qualified Data.Array.Accelerate.Interpreter as Naive
instance Arbitrary DIM0 where
arbitrary = return Z
......@@ -35,8 +37,16 @@ instance Arbitrary DIM3 where
z <- choose (0,16)
return (Z :. z :. y :. x)
instance (Arbitrary sh, Shape sh, Elt e, Arbitrary e) => Arbitrary (Array sh e) where
instance (Elt e, Arbitrary e) => Arbitrary (Array DIM2 e) where
arbitrary = do
sh <- arbitrary
fromList sh <$> vectorOf (A.size sh) arbitrary
fromList sh <$> vectorOf (AS.size sh) arbitrary
shrink arr = sliceArray arr
-- Slice the array to the new shape, keeping the square dimensions.
sliceArray :: Elt e => Array DIM2 e -> [Array DIM2 e]
sliceArray arr =
case arrayShape arr of
(Z :. x :. y) -> case (x, y) of
(0,0) -> []
_ -> [ Naive.run $ A.init $ A.transpose $ A.init $ use arr ]
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE DerivingStrategies #-}
module Test.Core.LinearAlgebra where
import Data.Array.Accelerate hiding (Ord, Eq, map, (<=))
import Data.Array.Accelerate.Interpreter qualified as Naive
import Data.Bifunctor (first)
import Data.Bimap (Bimap)
import Data.Bimap qualified as Bimap
......@@ -12,13 +16,20 @@ import Gargantext.Core.LinearAlgebra qualified as LA
import Gargantext.Core.Methods.Matrix.Accelerate.Utils qualified as Legacy
import Gargantext.Core.Methods.Similarities.Accelerate.Distributional qualified as Legacy
import Gargantext.Core.Viz.Graph.Index qualified as Legacy
import Gargantext.Orphans.Accelerate ()
import Prelude
import Gargantext.Orphans.Accelerate (sliceArray)
import Prelude hiding ((^))
import Test.Tasty
import Test.Tasty.QuickCheck
import Data.Array.Accelerate hiding (Ord, Eq, map)
import Data.Array.Accelerate.Interpreter qualified as Naive
import Data.Array.Accelerate.LLVM.Native qualified as LLVM
newtype SquareMatrix a = SquareMatrix { _SquareMatrix :: Matrix a }
deriving newtype (Show, Eq)
instance (Elt a, Show a, Arbitrary a) => Arbitrary (SquareMatrix a) where
arbitrary = do
x <- choose (0,30)
let sh = Z :. x :. x
SquareMatrix . fromList sh <$> vectorOf (x*x) arbitrary
shrink = map (SquareMatrix) . sliceArray . _SquareMatrix
compareImplementations :: (Arbitrary a, Eq b, Show b)
=> (a -> b)
......@@ -29,6 +40,15 @@ compareImplementations :: (Arbitrary a, Eq b, Show b)
compareImplementations implementation1 implementation2 mapResults inputData
= implementation1 inputData === mapResults (implementation2 inputData)
compareImplementations' :: (Arbitrary a, Eq c, Show c)
=> (a -> b)
-> (a -> b)
-> (b -> c)
-> a
-> Property
compareImplementations' implementation1 implementation2 mapResults inputData
= mapResults (implementation1 inputData) === mapResults (implementation2 inputData)
compareTermDivNan :: (Array TermDivNanShape Double)
-> (Array TermDivNanShape Double)
-> Property
......@@ -38,7 +58,7 @@ compareTermDivNan i1 i2
compareDistributional :: Matrix Int
-> Property
compareDistributional i1
= Legacy.distributionalWith Naive.run i1 === Legacy.distributionalWith LLVM.run i1
= Legacy.distributionalWith Naive.run i1 === Legacy.distributionalWith Naive.run i1
mapCreateIndices :: Ord t => (Map t Legacy.Index, Map Legacy.Index t) -> Bimap LA.Index t
mapCreateIndices (_m1, m2) = Bimap.fromList $ map (first LA.Index) $ M.toList m2
......@@ -48,15 +68,16 @@ type TermDivNanShape = Z :. Int :. Int
twoByTwo :: Matrix Int
twoByTwo = fromList (Z :. 2 :. 2) (Prelude.replicate 4 0)
-- | Needed as the LLVM and Naive backend generates some double with a long exponent which
-- won't compare verbatim.
tests :: TestTree
tests = testGroup "LinearAlgebra" [
testProperty "createIndices roundtrip" (compareImplementations (LA.createIndices @Int @Int) Legacy.createIndices mapCreateIndices)
, testProperty "termDivNan" compareTermDivNan
, testGroup "distributional" [
testProperty "2x2" (compareDistributional twoByTwo)
, testProperty "3x2" (compareDistributional $ fromList (Z :. 3 :. 2) (Prelude.replicate 6 0))
, testProperty "roundtrips" (compareImplementations (Legacy.distributionalWith Naive.run)
(Legacy.distributionalWith LLVM.run)
id)
, testProperty "roundtrips" (compareImplementations' @(SquareMatrix Int) (Legacy.distributionalWith Naive.run . _SquareMatrix)
(Legacy.distributionalWith Naive.run . _SquareMatrix)
id)
]
]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment