Commit 13131f51 authored by Alp Mestanogullari's avatar Alp Mestanogullari

update for accelerate 1.3

parent f68f9e78
...@@ -110,7 +110,7 @@ matrixIdentity n' = ...@@ -110,7 +110,7 @@ matrixIdentity n' =
ones = fill (index1 n) 1 ones = fill (index1 n) 1
n = constant n' n = constant n'
in in
permute const zeros (\(unindex1 -> i) -> index2 i i) ones permute const zeros (\(unindex1 -> i) -> Just_ $ index2 i i) ones
matrixEye :: Num a => Dim -> Acc (Matrix a) matrixEye :: Num a => Dim -> Acc (Matrix a)
...@@ -119,7 +119,7 @@ matrixEye n' = ...@@ -119,7 +119,7 @@ matrixEye n' =
zeros = fill (index1 n) 0 zeros = fill (index1 n) 0
n = constant n' n = constant n'
in in
permute const ones (\(unindex1 -> i) -> index2 i i) zeros permute const ones (\(unindex1 -> i) -> Just_ $ index2 i i) zeros
diagNull :: Num a => Dim -> Acc (Matrix a) -> Acc (Matrix a) diagNull :: Num a => Dim -> Acc (Matrix a) -> Acc (Matrix a)
...@@ -134,7 +134,7 @@ condOrDefault ...@@ -134,7 +134,7 @@ condOrDefault
condOrDefault theCond def x = permute const zeros filterInd x condOrDefault theCond def x = permute const zeros filterInd x
where where
zeros = fill (shape x) (def) zeros = fill (shape x) (def)
filterInd ix = (cond (theCond ix)) ix ignore filterInd ix = (cond (theCond ix)) (Just_ ix) Nothing_
----------------------------------------------------------------------- -----------------------------------------------------------------------
_runExp :: Elt e => Exp e -> e _runExp :: Elt e => Exp e -> e
...@@ -163,7 +163,7 @@ matrix n l = fromList (Z :. n :. n) l ...@@ -163,7 +163,7 @@ matrix n l = fromList (Z :. n :. n) l
-- >>> rank (matrix 3 ([1..] :: [Int])) -- >>> rank (matrix 3 ([1..] :: [Int]))
-- 2 -- 2
rank :: (Matrix a) -> Int rank :: (Matrix a) -> Int
rank m = arrayRank $ arrayShape m rank m = arrayRank m
----------------------------------------------------------------------- -----------------------------------------------------------------------
-- | Dimension of a square Matrix -- | Dimension of a square Matrix
...@@ -278,7 +278,7 @@ nullOf n' dir = ...@@ -278,7 +278,7 @@ nullOf n' dir =
zeros = fill (index2 n n) 0 zeros = fill (index2 n n) 0
n = constant n' n = constant n'
in in
permute const ones ( lift1 ( \(Z :. (i :: Exp Int) :. (_j:: Exp Int)) permute const ones ( Just_ . lift1 ( \(Z :. (i :: Exp Int) :. (_j:: Exp Int))
-> case dir of -> case dir of
MatCol m -> (Z :. i :. m) MatCol m -> (Z :. i :. m)
MatRow m -> (Z :. m :. i) MatRow m -> (Z :. m :. i)
...@@ -308,7 +308,7 @@ sumRowMin n m = {-trace (P.show $ run m') $-} m' ...@@ -308,7 +308,7 @@ sumRowMin n m = {-trace (P.show $ run m') $-} m'
$ P.map (\z -> sumRowMin1 n (constant z) m) [0..n-1] $ P.map (\z -> sumRowMin1 n (constant z) m) [0..n-1]
sumRowMin1 :: (Num a, Ord a) => Dim -> Exp Int -> Acc (Matrix a) -> Acc (Vector a) sumRowMin1 :: (Num a, Ord a) => Dim -> Exp Int -> Acc (Matrix a) -> Acc (Vector a)
sumRowMin1 n x m = trace (P.show (run m,run $ transpose m)) $ m'' sumRowMin1 n x m = {- trace (P.show (run m,run $ transpose m)) $ -} m''
where where
m'' = sum $ zipWith min (transpose m) m m'' = sum $ zipWith min (transpose m) m
_m' = zipWith (*) (zipWith (*) (nullOf n (MatCol x)) $ nullOfWithDiag n (MatRow x)) m _m' = zipWith (*) (zipWith (*) (nullOf n (MatCol x)) $ nullOfWithDiag n (MatRow x)) m
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment