Commit 96e57e41 authored by Alfredo Di Napoli's avatar Alfredo Di Napoli Committed by Alfredo Di Napoli

Fix division by 0 bug in distributional

The previous code was sometimes yielding a matrix of NaN numbers as
it was attempting the division of the input matrix with the diagonal,
which would be 0 in case of an input matrix of 0, resulting in a
division by 0 error.
parent 690629a4
...@@ -188,9 +188,12 @@ library ...@@ -188,9 +188,12 @@ library
Gargantext.Core.Config.Types Gargantext.Core.Config.Types
Gargantext.Core.Config.Utils Gargantext.Core.Config.Utils
Gargantext.Core.Config.Worker Gargantext.Core.Config.Worker
Gargantext.Core.LinearAlgebra
Gargantext.Core.Mail Gargantext.Core.Mail
Gargantext.Core.Mail.Types Gargantext.Core.Mail.Types
Gargantext.Core.Methods.Matrix.Accelerate.Utils
Gargantext.Core.Methods.Similarities Gargantext.Core.Methods.Similarities
Gargantext.Core.Methods.Similarities.Accelerate.Distributional
Gargantext.Core.Methods.Similarities.Conditional Gargantext.Core.Methods.Similarities.Conditional
Gargantext.Core.NLP Gargantext.Core.NLP
Gargantext.Core.NodeStory Gargantext.Core.NodeStory
...@@ -204,7 +207,6 @@ library ...@@ -204,7 +207,6 @@ library
Gargantext.Core.Notifications.Dispatcher.Types Gargantext.Core.Notifications.Dispatcher.Types
Gargantext.Core.Notifications.Dispatcher.WebSocket Gargantext.Core.Notifications.Dispatcher.WebSocket
Gargantext.Core.Notifications.Nanomsg Gargantext.Core.Notifications.Nanomsg
Gargantext.Core.LinearAlgebra
Gargantext.Core.Text Gargantext.Core.Text
Gargantext.Core.Text.Context Gargantext.Core.Text.Context
Gargantext.Core.Text.Corpus.API Gargantext.Core.Text.Corpus.API
...@@ -296,6 +298,9 @@ library ...@@ -296,6 +298,9 @@ library
Gargantext.Database.Schema.User Gargantext.Database.Schema.User
Gargantext.Defaults Gargantext.Defaults
Gargantext.MicroServices.ReverseProxy Gargantext.MicroServices.ReverseProxy
Gargantext.Orphans
Gargantext.Orphans.Accelerate
Gargantext.Orphans.OpenAPI
Gargantext.System.Logging Gargantext.System.Logging
Gargantext.Utils.Dict Gargantext.Utils.Dict
Gargantext.Utils.Jobs.Error Gargantext.Utils.Jobs.Error
...@@ -357,9 +362,7 @@ library ...@@ -357,9 +362,7 @@ library
Gargantext.Core.Flow.Ngrams Gargantext.Core.Flow.Ngrams
Gargantext.Core.Flow.Types Gargantext.Core.Flow.Types
Gargantext.Core.Methods.Graph.MaxClique Gargantext.Core.Methods.Graph.MaxClique
Gargantext.Core.Methods.Matrix.Accelerate.Utils
Gargantext.Core.Methods.Similarities.Accelerate.Conditional Gargantext.Core.Methods.Similarities.Accelerate.Conditional
Gargantext.Core.Methods.Similarities.Accelerate.Distributional
Gargantext.Core.Methods.Similarities.Accelerate.SpeGen Gargantext.Core.Methods.Similarities.Accelerate.SpeGen
Gargantext.Core.Statistics Gargantext.Core.Statistics
Gargantext.Core.Text.Corpus Gargantext.Core.Text.Corpus
...@@ -471,8 +474,6 @@ library ...@@ -471,8 +474,6 @@ library
Gargantext.Database.Schema.NodeNode Gargantext.Database.Schema.NodeNode
Gargantext.Database.Schema.Prelude Gargantext.Database.Schema.Prelude
Gargantext.Database.Types Gargantext.Database.Types
Gargantext.Orphans
Gargantext.Orphans.OpenAPI
Gargantext.Utils.Aeson Gargantext.Utils.Aeson
Gargantext.Utils.Servant Gargantext.Utils.Servant
Gargantext.Utils.UTCTime Gargantext.Utils.UTCTime
...@@ -706,6 +707,8 @@ common testDependencies ...@@ -706,6 +707,8 @@ common testDependencies
build-depends: build-depends:
base >=4.7 && <5 base >=4.7 && <5
, QuickCheck ^>= 2.14.2 , QuickCheck ^>= 2.14.2
, accelerate >= 1.3.0.0
, accelerate-llvm-native
, aeson ^>= 2.1.2.1 , aeson ^>= 2.1.2.1
, aeson-qq , aeson-qq
, async ^>= 2.2.4 , async ^>= 2.2.4
......
...@@ -38,8 +38,6 @@ import Data.Array.Accelerate ...@@ -38,8 +38,6 @@ import Data.Array.Accelerate
import Data.Array.Accelerate.Interpreter (run) import Data.Array.Accelerate.Interpreter (run)
import qualified Gargantext.Prelude as P import qualified Gargantext.Prelude as P
import Debug.Trace (trace)
-- | Matrix cell by cell multiplication -- | Matrix cell by cell multiplication
(.*) :: ( Shape ix (.*) :: ( Shape ix
, Slice ix , Slice ix
...@@ -74,7 +72,7 @@ termDivNan :: ( Shape ix ...@@ -74,7 +72,7 @@ termDivNan :: ( Shape ix
=> Acc (Array ((ix :. Int) :. Int) a) => Acc (Array ((ix :. Int) :. Int) a)
-> Acc (Array ((ix :. Int) :. Int) a) -> Acc (Array ((ix :. Int) :. Int) a)
-> Acc (Array ((ix :. Int) :. Int) a) -> Acc (Array ((ix :. Int) :. Int) a)
termDivNan = trace "termDivNan" $ zipWith (\i j -> cond ((==) j 0) 0 ((/) i j)) termDivNan = zipWith (\i j -> cond ((==) j 0) 0 ((/) i j))
(.-) :: ( Shape ix (.-) :: ( Shape ix
, Slice ix , Slice ix
......
...@@ -95,7 +95,7 @@ module Gargantext.Core.Methods.Similarities.Accelerate.Distributional ...@@ -95,7 +95,7 @@ module Gargantext.Core.Methods.Similarities.Accelerate.Distributional
-- import Debug.Trace (trace) -- import Debug.Trace (trace)
import Data.Array.Accelerate as A import Data.Array.Accelerate as A
-- import Data.Array.Accelerate.Interpreter (run) -- import Data.Array.Accelerate.Interpreter (run)
import Data.Array.Accelerate.LLVM.Native (run) -- TODO: try runQ? import Data.Array.Accelerate.LLVM.Native qualified as LLVM -- TODO: try runQ?
import Gargantext.Core.Methods.Matrix.Accelerate.Utils import Gargantext.Core.Methods.Matrix.Accelerate.Utils
import qualified Gargantext.Prelude as P import qualified Gargantext.Prelude as P
...@@ -139,7 +139,10 @@ import qualified Prelude ...@@ -139,7 +139,10 @@ import qualified Prelude
-- 0.3333333333333333, 5.7692307692307696e-2, 1.0, 1.0] -- 0.3333333333333333, 5.7692307692307696e-2, 1.0, 1.0]
-- --
distributional :: Matrix Int -> Matrix Double distributional :: Matrix Int -> Matrix Double
distributional m' = run $ result distributional = distributionalWith LLVM.run
distributionalWith :: (forall a. Arrays a => Acc a -> a) -> Matrix Int -> Matrix Double
distributionalWith interpret m' = interpret $ result
where where
m = map A.fromIntegral $ use m' m = map A.fromIntegral $ use m'
n = dim m' n = dim m'
...@@ -149,7 +152,7 @@ distributional m' = run $ result ...@@ -149,7 +152,7 @@ distributional m' = run $ result
d_1 = replicate (constant (Z :. n :. All)) diag_m d_1 = replicate (constant (Z :. n :. All)) diag_m
d_2 = replicate (constant (Z :. All :. n)) diag_m d_2 = replicate (constant (Z :. All :. n)) diag_m
mi = (.*) ((./) m d_1) ((./) m d_2) mi = (.*) (termDivNan m d_1) (termDivNan m d_2)
-- w = (.-) mi d_mi -- w = (.-) mi d_mi
...@@ -170,7 +173,7 @@ distributional m' = run $ result ...@@ -170,7 +173,7 @@ distributional m' = run $ result
result = termDivNan z_1 z_2 result = termDivNan z_1 z_2
logDistributional2 :: Matrix Int -> Matrix Double logDistributional2 :: Matrix Int -> Matrix Double
logDistributional2 m = trace ("logDistributional2, dim=" `mappend` show n) . run logDistributional2 m = trace ("logDistributional2, dim=" `mappend` show n) . LLVM.run
$ diagNull n $ diagNull n
$ matMaxMini $ matMaxMini
$ logDistributional' n m $ logDistributional' n m
...@@ -265,7 +268,7 @@ logDistributional' n m' = trace ("logDistributional'") result ...@@ -265,7 +268,7 @@ logDistributional' n m' = trace ("logDistributional'") result
-- --
logDistributional :: Matrix Int -> Matrix Double logDistributional :: Matrix Int -> Matrix Double
logDistributional m' = run $ diagNull n $ result logDistributional m' = LLVM.run $ diagNull n $ result
where where
m = map fromIntegral $ use m' m = map fromIntegral $ use m'
n = dim m' n = dim m'
...@@ -319,7 +322,7 @@ logDistributional m' = run $ diagNull n $ result ...@@ -319,7 +322,7 @@ logDistributional m' = run $ diagNull n $ result
distributional'' :: Matrix Int -> Matrix Double distributional'' :: Matrix Int -> Matrix Double
distributional'' m = -- run {- $ matMaxMini -} distributional'' m = -- run {- $ matMaxMini -}
run $ diagNull n LLVM.run $ diagNull n
$ rIJ n $ rIJ n
$ filterWith 0 100 $ filterWith 0 100
$ filter' 0 $ filter' 0
......
{-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -Wno-orphans #-}
module Gargantext.Orphans ( module Gargantext.Orphans (
module Gargantext.Orphans.OpenAPI module Gargantext.Orphans.OpenAPI
) where ) where
import Data.Aeson qualified as JSON import Data.Aeson qualified as JSON
import Gargantext.Database.Admin.Types.Hyperdata (Hyperdata) import Gargantext.Database.Admin.Types.Hyperdata (Hyperdata)
import Gargantext.Orphans.Accelerate ()
import Gargantext.Orphans.OpenAPI import Gargantext.Orphans.OpenAPI
instance Hyperdata JSON.Value instance Hyperdata JSON.Value
{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE MultiWayIf #-}
module Gargantext.Orphans.Accelerate () where
import Prelude
import Test.QuickCheck
import Data.Array.Accelerate hiding ((<=))
import qualified Data.Array.Accelerate.Sugar.Shape as A
instance Arbitrary DIM0 where
arbitrary = return Z
instance Arbitrary DIM1 where
arbitrary = (Z :.) <$> choose (0,1024)
shrink = \(Z :. i) -> if i <= 0 then [] else [Z :. i - 1 ]
instance Arbitrary DIM2 where
arbitrary = do
x <- choose (0,128)
y <- choose (0,48)
return (Z :. y :. x)
shrink = \(Z :. r :. c) ->
if | r <= 0 -> []
| c <= 0 -> []
| otherwise -> [Z :. (r - 1)
:. (c - 1)
]
instance Arbitrary DIM3 where
arbitrary = do
x <- choose (0,64)
y <- choose (0,32)
z <- choose (0,16)
return (Z :. z :. y :. x)
instance (Arbitrary sh, Shape sh, Elt e, Arbitrary e) => Arbitrary (Array sh e) where
arbitrary = do
sh <- arbitrary
fromList sh <$> vectorOf (A.size sh) arbitrary
{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Test.Core.LinearAlgebra where module Test.Core.LinearAlgebra where
import Prelude import Data.Bifunctor (first)
import Test.Tasty.QuickCheck import Data.Bimap (Bimap)
import Gargantext.Core.LinearAlgebra qualified as LA import Data.Bimap qualified as Bimap
import Test.Tasty
import Gargantext.Core.Viz.Graph.Index qualified as Legacy
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.Map.Strict qualified as M import Data.Map.Strict qualified as M
import Data.Bimap (Bimap) import Gargantext.Core.LinearAlgebra qualified as LA
import qualified Data.Bimap as Bimap import Gargantext.Core.Methods.Matrix.Accelerate.Utils qualified as Legacy
import Data.Bifunctor (first) import Gargantext.Core.Methods.Similarities.Accelerate.Distributional qualified as Legacy
import Gargantext.Core.Viz.Graph.Index qualified as Legacy
import Gargantext.Orphans.Accelerate ()
import Prelude
import Test.Tasty
import Test.Tasty.QuickCheck
import Data.Array.Accelerate hiding (Ord, Eq, map)
import Data.Array.Accelerate.Interpreter qualified as Naive
import Data.Array.Accelerate.LLVM.Native qualified as LLVM
compareImplementations :: (Arbitrary a, Eq b, Show b) compareImplementations :: (Arbitrary a, Eq b, Show b)
=> (a -> b) => (a -> b)
...@@ -22,10 +29,34 @@ compareImplementations :: (Arbitrary a, Eq b, Show b) ...@@ -22,10 +29,34 @@ compareImplementations :: (Arbitrary a, Eq b, Show b)
compareImplementations implementation1 implementation2 mapResults inputData compareImplementations implementation1 implementation2 mapResults inputData
= implementation1 inputData === mapResults (implementation2 inputData) = implementation1 inputData === mapResults (implementation2 inputData)
compareTermDivNan :: (Array TermDivNanShape Double)
-> (Array TermDivNanShape Double)
-> Property
compareTermDivNan i1 i2
= Naive.run (Legacy.termDivNan (use i1) (use i2)) === Naive.run (Legacy.termDivNan (use i1) (use i2))
compareDistributional :: Matrix Int
-> Property
compareDistributional i1
= Legacy.distributionalWith Naive.run i1 === Legacy.distributionalWith LLVM.run i1
mapCreateIndices :: Ord t => (Map t Legacy.Index, Map Legacy.Index t) -> Bimap LA.Index t mapCreateIndices :: Ord t => (Map t Legacy.Index, Map Legacy.Index t) -> Bimap LA.Index t
mapCreateIndices (_m1, m2) = Bimap.fromList $ map (first LA.Index) $ M.toList m2 mapCreateIndices (_m1, m2) = Bimap.fromList $ map (first LA.Index) $ M.toList m2
type TermDivNanShape = Z :. Int :. Int
twoByTwo :: Matrix Int
twoByTwo = fromList (Z :. 2 :. 2) (Prelude.replicate 4 0)
tests :: TestTree tests :: TestTree
tests = testGroup "LinearAlgebra" [ tests = testGroup "LinearAlgebra" [
testProperty "createIndices roundtrip" (compareImplementations (LA.createIndices @Int @Int) Legacy.createIndices mapCreateIndices) testProperty "createIndices roundtrip" (compareImplementations (LA.createIndices @Int @Int) Legacy.createIndices mapCreateIndices)
, testProperty "termDivNan" compareTermDivNan
, testGroup "distributional" [
testProperty "2x2" (compareDistributional twoByTwo)
, testProperty "3x2" (compareDistributional $ fromList (Z :. 3 :. 2) (Prelude.replicate 6 0))
, testProperty "roundtrips" (compareImplementations (Legacy.distributionalWith Naive.run)
(Legacy.distributionalWith LLVM.run)
id)
]
] ]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment