Commit 756aea9e authored by Alfredo Di Napoli's avatar Alfredo Di Napoli Committed by Alfredo Di Napoli

Even more performance tuning for distributional

parent f6c42d01
...@@ -49,7 +49,7 @@ import Data.Bimap qualified as Bimap ...@@ -49,7 +49,7 @@ import Data.Bimap qualified as Bimap
import Data.List.Split qualified as Split import Data.List.Split qualified as Split
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.Map.Strict qualified as M import Data.Map.Strict qualified as M
import Data.Massiv.Array (D, Matrix, Vector, Array, Ix3) import Data.Massiv.Array (D, Matrix, Vector, Array, Ix3, U)
import Data.Massiv.Array qualified as A import Data.Massiv.Array qualified as A
import Data.Set qualified as S import Data.Set qualified as S
import Data.Set (Set) import Data.Set (Set)
...@@ -155,8 +155,13 @@ distributional :: forall r e. ...@@ -155,8 +155,13 @@ distributional :: forall r e.
-> Matrix r e -> Matrix r e
distributional m' = result distributional m' = result
where where
mD :: Matrix D e
mD = A.map fromIntegral m'
m :: Matrix A.U e m :: Matrix A.U e
m = A.compute $ A.map fromIntegral m' m = A.compute mD
n :: Int
n = dim m' n = dim m'
-- Computes the diagonal matrix of the input .. -- Computes the diagonal matrix of the input ..
...@@ -169,27 +174,36 @@ distributional m' = result ...@@ -169,27 +174,36 @@ distributional m' = result
-- Then we create a matrix that contains the same elements of diag_m -- Then we create a matrix that contains the same elements of diag_m
-- for the rows and columns, to make it square again. -- for the rows and columns, to make it square again.
d_1 :: Matrix A.U e d_1 :: Matrix A.D e
d_1 = A.makeArrayR A.U A.Seq (A.Sz2 n diag_m_size) $ \(_ A.:. i) -> diag_m A.! i d_1 = A.backpermute' (A.Sz2 n diag_m_size) (\(_ A.:. i) -> i) diag_m
d_2 :: Matrix A.D e
d_2 = A.backpermute' (A.Sz2 diag_m_size n) (\(i A.:. _) -> i) diag_m
d_2 :: Matrix A.U e a :: Matrix D e
d_2 = A.makeArrayR A.U A.Seq (A.Sz2 n diag_m_size) $ \(i A.:. _) -> diag_m A.! i a = termDivNanD mD d_1
mi :: Matrix A.U e b :: Matrix D e
mi = (.*) (termDivNan @A.U m d_1) (termDivNan @A.U m d_2) b = termDivNanD mD d_2
miDelayed :: Matrix D e
miDelayed = a `mulD` b
miMemo :: Matrix D e
miMemo = A.delay (A.compute @U miDelayed)
mi_r, mi_c :: Int mi_r, mi_c :: Int
(A.Sz2 mi_r mi_c) = A.size mi (A.Sz2 mi_r mi_c) = A.size miMemo
-- The matrix permutations is taken care of below by directly replicating -- The matrix permutations is taken care of below by directly replicating
-- the matrix mi, making the matrix w unneccessary and saving one step. -- the matrix mi, making the matrix w unneccessary and saving one step.
-- replicate (constant (Z :. All :. n :. All)) mi -- replicate (constant (Z :. All :. n :. All)) mi
w_1 :: Array D Ix3 e w_1 :: Array D Ix3 e
w_1 = A.backpermute' (A.Sz3 mi_r n mi_c) (\(x A.:> _y A.:. z) -> x A.:. z) mi w_1 = A.backpermute' (A.Sz3 mi_r n mi_c) (\(x A.:> _y A.:. z) -> x A.:. z) miMemo
-- replicate (constant (Z :. n :. All :. All)) mi -- replicate (constant (Z :. n :. All :. All)) mi
w_2 :: Array D Ix3 e w_2 :: Array D Ix3 e
w_2 = A.backpermute' (A.Sz3 n mi_r mi_c) (\(_x A.:> y A.:. z) -> y A.:. z) mi w_2 = A.backpermute' (A.Sz3 n mi_r mi_c) (\(_x A.:> y A.:. z) -> y A.:. z) miMemo
w' :: Array D Ix3 e w' :: Array D Ix3 e
w' = A.zipWith min w_1 w_2 w' = A.zipWith min w_1 w_2
...@@ -197,8 +211,8 @@ distributional m' = result ...@@ -197,8 +211,8 @@ distributional m' = result
-- The matrix ii = [r_{i,j,k}]_{i,j,k} has r_(i,j,k) = 0 if k = i OR k = j -- The matrix ii = [r_{i,j,k}]_{i,j,k} has r_(i,j,k) = 0 if k = i OR k = j
-- and r_(i,j,k) = 1 otherwise (i.e. k /= i AND k /= j). -- and r_(i,j,k) = 1 otherwise (i.e. k /= i AND k /= j).
-- generate (constant (Z :. n :. n :. n)) (lift1 (\( i A.:. j A.:. k) -> cond ((&&) ((/=) k i) ((/=) k j)) 1 0)) -- generate (constant (Z :. n :. n :. n)) (lift1 (\( i A.:. j A.:. k) -> cond ((&&) ((/=) k i) ((/=) k j)) 1 0))
ii :: Array A.U Ix3 e ii :: Array A.D Ix3 e
ii = A.makeArrayR A.U A.Seq (A.Sz3 n n n) $ \(i A.:> j A.:. k) -> if k /= i && k /= j then 1 else 0 ii = A.makeArrayR A.D A.Seq (A.Sz3 n n n) $ \(i A.:> j A.:. k) -> if k /= i && k /= j then 1 else 0
z_1 :: Matrix A.D e z_1 :: Matrix A.D e
z_1 = sumRowsD (w' `mulD` ii) z_1 = sumRowsD (w' `mulD` ii)
...@@ -206,14 +220,20 @@ distributional m' = result ...@@ -206,14 +220,20 @@ distributional m' = result
z_2 :: Matrix A.D e z_2 :: Matrix A.D e
z_2 = sumRowsD (w_1 `mulD` ii) z_2 = sumRowsD (w_1 `mulD` ii)
result = termDivNan z_1 z_2 result = A.computeP (termDivNanD z_1 z_2)
-- | Term by term division where divisions by 0 produce 0 rather than NaN. -- | Term by term division where divisions by 0 produce 0 rather than NaN.
termDivNan :: (A.Manifest r3 a, A.Source r1 a, A.Source r2 a, Eq a, Fractional a) termDivNan :: (A.Manifest r3 a, A.Source r1 a, A.Source r2 a, Eq a, Fractional a)
=> Matrix r1 a => Matrix r1 a
-> Matrix r2 a -> Matrix r2 a
-> Matrix r3 a -> Matrix r3 a
termDivNan m1 m2 = A.computeP $ A.zipWith (\i j -> if j == 0 then 0 else i / j) m1 m2 termDivNan m1 = A.compute . termDivNanD m1
termDivNanD :: (A.Source r1 a, A.Source r2 a, Eq a, Fractional a)
=> Matrix r1 a
-> Matrix r2 a
-> Matrix D a
termDivNanD m1 m2 = A.zipWith (\i j -> if j == 0 then 0 else i / j) m1 m2
sumRows :: ( A.Load r A.Ix2 e sumRows :: ( A.Load r A.Ix2 e
, A.Source r e , A.Source r e
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment