Commit 286fc44d authored by Alfredo Di Napoli's avatar Alfredo Di Napoli Committed by Alfredo Di Napoli

Fix bug in unsafe division in both implementations of logDistributional2

parent 681eb942
...@@ -212,7 +212,8 @@ distributionalReferenceImplementation m' = result ...@@ -212,7 +212,8 @@ distributionalReferenceImplementation m' = result
logDistributional2 :: (A.Manifest r e logDistributional2 :: (A.Manifest r e
, A.Unbox e , A.Unbox e
, A.Source r Int , A.Source r Int
, A.Shape r Ix2, Num e , A.Shape r Ix2
, Num e
, Ord e , Ord e
, A.Source r e , A.Source r e
, Fractional e , Fractional e
...@@ -272,7 +273,7 @@ logDistributional' n m' = result ...@@ -272,7 +273,7 @@ logDistributional' n m' = result
mi_divvy :: Matrix A.D e mi_divvy :: Matrix A.D e
mi_divvy = A.zipWith (\m_val ss_val -> mi_divvy = A.zipWith (\m_val ss_val ->
let x = m_val / ss_val let x = m_val `safeDiv` ss_val
x' = x * to x' = x * to
in if (x' < 1) then 0 else log x') m ss in if (x' < 1) then 0 else log x') m ss
......
...@@ -53,13 +53,14 @@ import qualified Gargantext.Prelude as P ...@@ -53,13 +53,14 @@ import qualified Gargantext.Prelude as P
(./) :: ( Shape ix (./) :: ( Shape ix
, Slice ix , Slice ix
, Elt a , Elt a
, Eq a
, P.Num (Exp a) , P.Num (Exp a)
, P.Fractional (Exp a) , P.Fractional (Exp a)
) )
=> Acc (Array ((ix :. Int) :. Int) a) => Acc (Array ((ix :. Int) :. Int) a)
-> Acc (Array ((ix :. Int) :. Int) a) -> Acc (Array ((ix :. Int) :. Int) a)
-> Acc (Array ((ix :. Int) :. Int) a) -> Acc (Array ((ix :. Int) :. Int) a)
(./) = zipWith (/) (./) = zipWith safeDivCond
-- | Term by term division where divisions by 0 produce 0 rather than NaN. -- | Term by term division where divisions by 0 produce 0 rather than NaN.
termDivNan :: ( Elt a termDivNan :: ( Elt a
...@@ -70,7 +71,10 @@ termDivNan :: ( Elt a ...@@ -70,7 +71,10 @@ termDivNan :: ( Elt a
=> Acc (Matrix a) => Acc (Matrix a)
-> Acc (Matrix a) -> Acc (Matrix a)
-> Acc (Matrix a) -> Acc (Matrix a)
termDivNan = zipWith (\i j -> cond ((==) j 0) 0 ((/) i j)) termDivNan = zipWith safeDivCond
safeDivCond :: (Eq a, P.Num (Exp a), P.Fractional (Exp a)) => Exp a -> Exp a -> Exp a
safeDivCond i j = cond ((==) j 0) 0 ((/) i j)
(.-) :: ( Shape ix (.-) :: ( Shape ix
, Slice ix , Slice ix
......
...@@ -96,7 +96,30 @@ testMatrix_02 = SquareMatrix $ fromList (Z :. 7 :. 7) $ ...@@ -96,7 +96,30 @@ testMatrix_02 = SquareMatrix $ fromList (Z :. 7 :. 7) $
7, -12, 12, -2, 36, 10, 34, 7, -12, 12, -2, 36, 10, 34,
13, -37, -16, 2, 7, -13, 21] 13, -37, -16, 2, 7, -13, 21]
testMatrix_03 :: SquareMatrix Int
testMatrix_03 = SquareMatrix $ fromList (Z :. 11 :. 11) $
[ 1, -1, 1, 0, 1, -1, 0, 1, 1, 0, 0,
1, 1, 1, 1, 1, 0, 1, -1, 1, 0, 0,
-1, 1, 0, -1, 0, -1, 0, 1, 0, -1, 0,
1, 1, 1, -1, -1, 0, 1, -1, 0, 0, -1,
-1, 1, -1, -1, 0, 1, 1, 1, -1, -1, -1,
1, 1, 0, -1, -1, -1, 1, 0, 1, -1, -1,
-1, 1, 0, -1, 1, -1, 0, 1, -1, -1, -1,
1, 1, -1, 1, 1, 0, 1, -1, 1, -1, 1,
-1, -1, 0, 1, 1, 0, 1, 1, -1, 1, 0,
1, 1, 0, -1, 1, -1, 1, 0, 1, 0, -1,
1, 1, -1, 0, -1, -1, 1, 0, 1, 0, -1]
testMatrix_04 :: SquareMatrix Int
testMatrix_04 = SquareMatrix $ fromList (Z :. 8 :. 8) $
[ 3, -1, 0, 1, -1, 1, 1, -3,
-2, -2, 2, 1, 1, -2, 1, -1,
-2, -3, -1, 1, 1, -3, -2, -1,
1, -2, 2, 0, 1, 0, 2, 0,
-1, -3, -1, 3, -3, 0, -1, 2,
0, 0, -3, 3, -1, -2, -1, 1,
-2, 1, -1, 2, 1, -1, -2, 0,
-2, 2, 1, 1, 1, 0, 2, -3]
-- --
-- Main test runner -- Main test runner
-- --
...@@ -124,6 +147,8 @@ tests = testGroup "LinearAlgebra" [ ...@@ -124,6 +147,8 @@ tests = testGroup "LinearAlgebra" [
, testGroup "logDistributional2" [ , testGroup "logDistributional2" [
testProperty "2x2" (compareLogDistributional2 (Proxy @Double) twoByTwo) testProperty "2x2" (compareLogDistributional2 (Proxy @Double) twoByTwo)
, testProperty "7x7" (compareLogDistributional2 (Proxy @Double) testMatrix_02) , testProperty "7x7" (compareLogDistributional2 (Proxy @Double) testMatrix_02)
, testProperty "8x8" (compareLogDistributional2 (Proxy @Double) testMatrix_04)
, testProperty "11x11" (compareLogDistributional2 (Proxy @Double) testMatrix_03)
, testProperty "14x14" (compareLogDistributional2 (Proxy @Double) testMatrix_01) , testProperty "14x14" (compareLogDistributional2 (Proxy @Double) testMatrix_01)
,testProperty "roundtrips" (compareLogDistributional2 (Proxy @Double)) ,testProperty "roundtrips" (compareLogDistributional2 (Proxy @Double))
] ]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment