diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 4da2a90..8e2f4b7 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -2,6 +2,7 @@ module Data.NumIdr.Array.Array import Data.List1 import Data.Vect +import Data.Zippable import Data.NumIdr.PrimArray import Data.NumIdr.Array.Order import Data.NumIdr.Array.Coords @@ -95,6 +96,12 @@ getAllCoords (Z :: s) = [] getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |] +constant' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a +constant' s ord x = MkArray ord (calcStrides ord s) s (create (product s) (const x)) + +constant : (s : Vect rk Nat) -> a -> Array s a +constant s = constant' s COrder + ||| Create an array given a vector of its elements. The elements of the vector ||| are arranged into the provided shape using the provided order. ||| @@ -150,7 +157,7 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins) ||| Construct an array using a structure of nested vectors. ||| To explicitly specify the shape and order of the array, use `array'`. export -array : {s : _} -> Vects s a -> Array s a +array : {s : Vect rk Nat} -> Vects s a -> Array s a array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v) @@ -170,7 +177,10 @@ indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a indexRange rs arr = let ord = getOrder arr s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs sts = calcStrides ord s - in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs) + -- TODO: Make this actually typecheck without resorting to this mess + in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ + map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ + getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs) -------------------------------------------------------------------------------- @@ -197,6 +207,15 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a -> reshape s' arr = reshape' s' (getOrder arr) arr +||| Change the internal order of the array's elements. +export +reorder : Order -> Array s a -> Array s a +reorder ord' arr = let s = shape arr + sts = calcStrides ord' s + in rewrite shapeEq arr + in MkArray ord' sts _ (unsafeFromIns (product s) $ + map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $ + getAllCoords' s) export @@ -221,9 +240,101 @@ stack axis a b = let sA = shape a sts = calcStrides COrder s ins = map (mapFst $ getLocation' sts . toNats) (enumerate a) ++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b) + -- TODO: prove that the type-level shape and `s` are equivalent in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins) display : Show a => Array s a -> IO () display = printLn . PrimArray.toList . getPrim + + +-------------------------------------------------------------------------------- +-- Implementations +-------------------------------------------------------------------------------- + + +-- Most of these implementations apply the operation pointwise. If there are +-- multiple arrays involved with different orders, then all of the arrays are +-- reordered to match one. + + +export +Functor (Array s) where + map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr) + +export +Zippable (Array s) where + zipWith f a b = rewrite shapeEq a + in MkArray (getOrder a) (strides a) _ $ + if getOrder a == getOrder b + then unsafeZipWith f (getPrim a) (getPrim b) + else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b) + + zipWith3 f a b c = rewrite shapeEq a + in MkArray (getOrder a) (strides a) _ $ + if (getOrder a == getOrder b) && (getOrder b == getOrder c) + then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c) + else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b) + (getPrim $ reorder (getOrder a) c) + + unzipWith f arr = rewrite shapeEq arr + in case unzipWith f (getPrim arr) of + (a, b) => (MkArray (getOrder arr) (strides arr) _ a, + MkArray (getOrder arr) (strides arr) _ b) + + unzipWith3 f arr = rewrite shapeEq arr + in case unzipWith3 f (getPrim arr) of + (a, b, c) => (MkArray (getOrder arr) (strides arr) _ a, + MkArray (getOrder arr) (strides arr) _ b, + MkArray (getOrder arr) (strides arr) _ c) + +export +{s : _} -> Applicative (Array s) where + pure = constant s + (<*>) = zipWith apply + +export +{s : _} -> Monad (Array s) where + join arr = fromFunction s (\is => index is $ index is arr) + + +-- Foldable and Traversable operate on the primitive array directly. This means +-- that their operation is dependent on the order of the array. + +export +Foldable (Array s) where + foldl f z = foldl f z . getPrim + foldr f z = foldr f z . getPrim + null arr = size arr == Z + toList = toList . getPrim + +export +Traversable (Array s) where + traverse f (MkArray ord sts s arr) = + map (MkArray ord sts s) (traverse f arr) + + +export +Eq a => Eq (Array s a) where + a == b = if getOrder a == getOrder b + then unsafeEq (getPrim a) (getPrim b) + else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b) + +export +Semigroup a => Semigroup (Array s a) where + (<+>) = zipWith (<+>) + +export +{s : _} -> Monoid a => Monoid (Array s a) where + neutral = constant s neutral + + +-- the shape must be known at runtime due to `fromInteger`. If `fromInteger` +-- were moved into its own interface, this constraint could be removed. +{s : _} -> Num a => Num (Array s a) where + (+) = zipWith (+) + (*) = zipWith (*) + + fromInteger = constant s . fromInteger + diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr index 8e0b108..2882269 100644 --- a/src/Data/NumIdr/Array/Coords.idr +++ b/src/Data/NumIdr/Array/Coords.idr @@ -42,7 +42,6 @@ namespace Coords mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs - export index : Coords s -> Vects s a -> a index [] x = x index (i::is) v = index is $ index i v diff --git a/src/Data/NumIdr/Array/Order.idr b/src/Data/NumIdr/Array/Order.idr index 827ca78..c277771 100644 --- a/src/Data/NumIdr/Array/Order.idr +++ b/src/Data/NumIdr/Array/Order.idr @@ -21,6 +21,15 @@ data Order : Type where FOrder : Order +public export +Eq Order where + COrder == COrder = True + FOrder == FOrder = True + COrder == FOrder = False + FOrder == COrder = False + + + scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res scanr _ q0 [] = [q0] scanr f q0 (x::xs) = f x (head qs) :: qs diff --git a/src/Data/NumIdr/PrimArray.idr b/src/Data/NumIdr/PrimArray.idr index 3477d4a..a033aff 100644 --- a/src/Data/NumIdr/PrimArray.idr +++ b/src/Data/NumIdr/PrimArray.idr @@ -1,5 +1,7 @@ module Data.NumIdr.PrimArray +import Data.Nat +import Data.IORef import Data.IOArray.Prims %default total @@ -92,3 +94,56 @@ export map : (a -> b) -> PrimArray a -> PrimArray b map f arr = create (length arr) (\n => f $ index n arr) + +export +unsafeZipWith : (a -> b -> c) -> PrimArray a -> PrimArray b -> PrimArray c +unsafeZipWith f a b = create (length a) (\n => f (index n a) (index n b)) + +export +unsafeZipWith3 : (a -> b -> c -> d) -> + PrimArray a -> PrimArray b -> PrimArray c -> PrimArray d +unsafeZipWith3 f a b c = create (length a) (\n => f (index n a) (index n b) (index n c)) + +export +unzipWith : (a -> (b, c)) -> PrimArray a -> (PrimArray b, PrimArray c) +unzipWith f arr = (map (fst . f) arr, map (snd . f) arr) + +export +unzipWith3 : (a -> (b, c, d)) -> PrimArray a -> (PrimArray b, PrimArray c, PrimArray d) +unzipWith3 f arr = (map ((\(x,_,_) => x) . f) arr, + map ((\(_,y,_) => y) . f) arr, + map ((\(_,_,z) => z) . f) arr) + + +export +foldl : (b -> a -> b) -> b -> PrimArray a -> b +foldl f z (MkPrimArray size arr) = unsafePerformIO $ do + ref <- newIORef z + for_ [0..pred size] $ \n => do + x <- readIORef ref + y <- arrayDataGet n arr + writeIORef ref (f x y) + readIORef ref + +export +foldr : (a -> b -> b) -> b -> PrimArray a -> b +foldr f z (MkPrimArray size arr) = unsafePerformIO $ do + ref <- newIORef z + for_ [pred size..0] $ \n => do + x <- arrayDataGet n arr + y <- readIORef ref + writeIORef ref (f x y) + readIORef ref + +export +traverse : Applicative f => (a -> f b) -> PrimArray a -> f (PrimArray b) +traverse f = map fromList . traverse f . toList + + +||| Compares two primitive arrays for equal elements. This function assumes +||| the arrays same the same length; it must not be used in any other case. +export +unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool +unsafeEq a b = unsafePerformIO $ + map (concat @{All}) $ for [0..pred (arraySize a)] $ + \n => (==) <$> arrayDataGet n (content a) <*> arrayDataGet n (content b)