Commit 756aea9e authored by Alfredo Di Napoli's avatar Alfredo Di Napoli Committed by Alfredo Di Napoli

Even more performance tuning for distributional

parent f6c42d01
......@@ -49,7 +49,7 @@ import Data.Bimap qualified as Bimap
import Data.List.Split qualified as Split
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as M
import Data.Massiv.Array (D, Matrix, Vector, Array, Ix3)
import Data.Massiv.Array (D, Matrix, Vector, Array, Ix3, U)
import Data.Massiv.Array qualified as A
import Data.Set qualified as S
import Data.Set (Set)
......@@ -155,8 +155,13 @@ distributional :: forall r e.
-> Matrix r e
distributional m' = result
where
mD :: Matrix D e
mD = A.map fromIntegral m'
m :: Matrix A.U e
m = A.compute $ A.map fromIntegral m'
m = A.compute mD
n :: Int
n = dim m'
-- Computes the diagonal matrix of the input ..
......@@ -169,27 +174,36 @@ distributional m' = result
-- Then we create a matrix that contains the same elements of diag_m
-- for the rows and columns, to make it square again.
d_1 :: Matrix A.U e
d_1 = A.makeArrayR A.U A.Seq (A.Sz2 n diag_m_size) $ \(_ A.:. i) -> diag_m A.! i
d_1 :: Matrix A.D e
d_1 = A.backpermute' (A.Sz2 n diag_m_size) (\(_ A.:. i) -> i) diag_m
d_2 :: Matrix A.D e
d_2 = A.backpermute' (A.Sz2 diag_m_size n) (\(i A.:. _) -> i) diag_m
d_2 :: Matrix A.U e
d_2 = A.makeArrayR A.U A.Seq (A.Sz2 n diag_m_size) $ \(i A.:. _) -> diag_m A.! i
a :: Matrix D e
a = termDivNanD mD d_1
mi :: Matrix A.U e
mi = (.*) (termDivNan @A.U m d_1) (termDivNan @A.U m d_2)
b :: Matrix D e
b = termDivNanD mD d_2
miDelayed :: Matrix D e
miDelayed = a `mulD` b
miMemo :: Matrix D e
miMemo = A.delay (A.compute @U miDelayed)
mi_r, mi_c :: Int
(A.Sz2 mi_r mi_c) = A.size mi
(A.Sz2 mi_r mi_c) = A.size miMemo
-- The matrix permutations is taken care of below by directly replicating
-- the matrix mi, making the matrix w unneccessary and saving one step.
-- replicate (constant (Z :. All :. n :. All)) mi
w_1 :: Array D Ix3 e
w_1 = A.backpermute' (A.Sz3 mi_r n mi_c) (\(x A.:> _y A.:. z) -> x A.:. z) mi
w_1 = A.backpermute' (A.Sz3 mi_r n mi_c) (\(x A.:> _y A.:. z) -> x A.:. z) miMemo
-- replicate (constant (Z :. n :. All :. All)) mi
w_2 :: Array D Ix3 e
w_2 = A.backpermute' (A.Sz3 n mi_r mi_c) (\(_x A.:> y A.:. z) -> y A.:. z) mi
w_2 = A.backpermute' (A.Sz3 n mi_r mi_c) (\(_x A.:> y A.:. z) -> y A.:. z) miMemo
w' :: Array D Ix3 e
w' = A.zipWith min w_1 w_2
......@@ -197,8 +211,8 @@ distributional m' = result
-- The matrix ii = [r_{i,j,k}]_{i,j,k} has r_(i,j,k) = 0 if k = i OR k = j
-- and r_(i,j,k) = 1 otherwise (i.e. k /= i AND k /= j).
-- generate (constant (Z :. n :. n :. n)) (lift1 (\( i A.:. j A.:. k) -> cond ((&&) ((/=) k i) ((/=) k j)) 1 0))
ii :: Array A.U Ix3 e
ii = A.makeArrayR A.U A.Seq (A.Sz3 n n n) $ \(i A.:> j A.:. k) -> if k /= i && k /= j then 1 else 0
ii :: Array A.D Ix3 e
ii = A.makeArrayR A.D A.Seq (A.Sz3 n n n) $ \(i A.:> j A.:. k) -> if k /= i && k /= j then 1 else 0
z_1 :: Matrix A.D e
z_1 = sumRowsD (w' `mulD` ii)
......@@ -206,14 +220,20 @@ distributional m' = result
z_2 :: Matrix A.D e
z_2 = sumRowsD (w_1 `mulD` ii)
result = termDivNan z_1 z_2
result = A.computeP (termDivNanD z_1 z_2)
-- | Term by term division where divisions by 0 produce 0 rather than NaN.
termDivNan :: (A.Manifest r3 a, A.Source r1 a, A.Source r2 a, Eq a, Fractional a)
=> Matrix r1 a
-> Matrix r2 a
-> Matrix r3 a
termDivNan m1 m2 = A.computeP $ A.zipWith (\i j -> if j == 0 then 0 else i / j) m1 m2
termDivNan m1 = A.compute . termDivNanD m1
termDivNanD :: (A.Source r1 a, A.Source r2 a, Eq a, Fractional a)
=> Matrix r1 a
-> Matrix r2 a
-> Matrix D a
termDivNanD m1 m2 = A.zipWith (\i j -> if j == 0 then 0 else i / j) m1 m2
sumRows :: ( A.Load r A.Ix2 e
, A.Source r e
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment