{-# OPTIONS_GHC -Wno-orphans #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE MultiWayIf #-}

module Gargantext.Orphans.Accelerate where

import Prelude
import Test.QuickCheck
import Data.Scientific ()
import Data.Array.Accelerate (DIM2, Z (..), (:.) (..), Array, Elt, fromList, arrayShape, DIM3)
import Data.Array.Accelerate qualified as A
import qualified Data.List.Split as Split

-- | Random rank-3 arrays for property tests.
--
-- Every extent is drawn from @[1, 10]@ and every element is strictly positive
-- (drawn through 'Positive', hence the 'Num' and 'Ord' constraints). No
-- 'shrink' is defined, so QuickCheck reports failing cases unshrunk.
instance (Show e, Elt e, Arbitrary e, Num e, Ord e) => Arbitrary (Array DIM3 e) where
  arbitrary = do
      n0 <- extent
      n1 <- extent
      n2 <- extent
      let size = n0 * n1 * n2
      fromList (Z :. n0 :. n1 :. n2) . map getPositive <$> vectorOf size arbitrary
    where
      -- Each extent is sampled independently from the same range.
      extent :: Gen Int
      extent = choose (1, 10)

-- | Random rank-2 arrays for property tests.
--
-- The array has between 1 and 128 rows and between 1 and 48 columns. Its
-- elements come from the element type's own 'Arbitrary' instance. Shrinking
-- is delegated to 'sliceArray', which reduces the shape only (elements are
-- never shrunk).
instance (Show e, Elt e, Arbitrary e) => Arbitrary (Array DIM2 e) where
  arbitrary = do
    nRows <- choose (1, 128)
    nCols <- choose (1, 48)
    fromList (Z :. nRows :. nCols) <$> vectorOf (nRows * nCols) arbitrary
  shrink = sliceArray

-- | Shrink candidates for a rank-2 array, used by its 'Arbitrary' instance.
--
-- Each candidate is the sub-array made of the leading rows and columns of the
-- input. It is obtained by dropping the last row and column together, only
-- the last row, or only the last column. The most aggressive reduction comes
-- first, so QuickCheck tries it before the smaller ones.
--
-- A candidate is kept only when both of its extents are at least 1. As a
-- result:
--
-- * a @1 × n@ or @n × 1@ array still shrinks along its other dimension;
-- * a @1 × 1@ array, or one with an empty dimension, has no candidates.
--
-- Every candidate is strictly smaller than its input, so shrinking always
-- terminates. Elements themselves are not shrunk.
sliceArray :: (Elt e, Show e) => Array DIM2 e -> [Array DIM2 e]
sliceArray arr =
  [ resizeArray arr rows cols
  | (rows, cols) <- [(x - 1, y - 1), (x - 1, y), (x, y - 1)]
    -- Never request an empty (or negative) extent. This also guarantees
    -- that resizeArray is only asked for a sub-shape of the input.
  , rows >= 1
  , cols >= 1
  ]
  where
    (Z :. x :. y) = arrayShape arr

-- | Keep only the leading @rows@ rows and the leading @cols@ columns of a
-- rank-2 array.
--
-- The flat element list from 'A.toList' is cut into consecutive chunks of the
-- input's column count, one chunk per row. The first @rows@ chunks are kept,
-- and each is truncated to @cols@ elements.
--
-- The requested extents are expected not to exceed the input's. Otherwise
-- fewer than @rows * cols@ elements reach 'A.fromList'.
resizeArray :: (Show e, Elt e) => Array DIM2 e -> Int -> Int -> Array DIM2 e
resizeArray arr rows cols =
  A.fromList (Z :. rows :. cols) (concatMap (take cols) keptRows)
  where
    (Z :. _ :. rowLength) = arrayShape arr
    keptRows = take rows (Split.chunksOf rowLength (A.toList arr))
