Add implementations for Array

This commit is contained in:
Kiana Sheibani 2022-05-20 08:40:24 -04:00
parent 5914652b7c
commit ff48a18478
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
4 changed files with 177 additions and 3 deletions

View file

@ -2,6 +2,7 @@ module Data.NumIdr.Array.Array
import Data.List1 import Data.List1
import Data.Vect import Data.Vect
import Data.Zippable
import Data.NumIdr.PrimArray import Data.NumIdr.PrimArray
import Data.NumIdr.Array.Order import Data.NumIdr.Array.Order
import Data.NumIdr.Array.Coords import Data.NumIdr.Array.Coords
@ -95,6 +96,12 @@ getAllCoords (Z :: s) = []
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |] getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
constant' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a
constant' s ord x = MkArray ord (calcStrides ord s) s (create (product s) (const x))
constant : (s : Vect rk Nat) -> a -> Array s a
constant s = constant' s COrder
||| Create an array given a vector of its elements. The elements of the vector ||| Create an array given a vector of its elements. The elements of the vector
||| are arranged into the provided shape using the provided order. ||| are arranged into the provided shape using the provided order.
||| |||
@ -150,7 +157,7 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
||| Construct an array using a structure of nested vectors. ||| Construct an array using a structure of nested vectors.
||| To explicitly specify the shape and order of the array, use `array'`. ||| To explicitly specify the shape and order of the array, use `array'`.
export export
array : {s : _} -> Vects s a -> Array s a array : {s : Vect rk Nat} -> Vects s a -> Array s a
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v) array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
@ -170,7 +177,10 @@ indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
indexRange rs arr = let ord = getOrder arr indexRange rs arr = let ord = getOrder arr
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
sts = calcStrides ord s sts = calcStrides ord s
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs) -- TODO: Make this actually typecheck without resorting to this mess
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $
map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $
getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -197,6 +207,15 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
reshape s' arr = reshape' s' (getOrder arr) arr reshape s' arr = reshape' s' (getOrder arr) arr
||| Change the internal order of the array's elements.
export
reorder : Order -> Array s a -> Array s a
reorder ord' arr = let s = shape arr
sts = calcStrides ord' s
in rewrite shapeEq arr
in MkArray ord' sts _ (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
export export
@ -221,9 +240,101 @@ stack axis a b = let sA = shape a
sts = calcStrides COrder s sts = calcStrides COrder s
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a) ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b) ++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
-- TODO: prove that the type-level shape and `s` are equivalent
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins) in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
display : Show a => Array s a -> IO () display : Show a => Array s a -> IO ()
display = printLn . PrimArray.toList . getPrim display = printLn . PrimArray.toList . getPrim
--------------------------------------------------------------------------------
-- Implementations
--------------------------------------------------------------------------------
-- Most of these implementations apply the operation pointwise. If there are
-- multiple arrays involved with different orders, then all of the arrays are
-- reordered to match one.
export
Functor (Array s) where
map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr)
export
Zippable (Array s) where
zipWith f a b = rewrite shapeEq a
in MkArray (getOrder a) (strides a) _ $
if getOrder a == getOrder b
then unsafeZipWith f (getPrim a) (getPrim b)
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
zipWith3 f a b c = rewrite shapeEq a
in MkArray (getOrder a) (strides a) _ $
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
(getPrim $ reorder (getOrder a) c)
unzipWith f arr = rewrite shapeEq arr
in case unzipWith f (getPrim arr) of
(a, b) => (MkArray (getOrder arr) (strides arr) _ a,
MkArray (getOrder arr) (strides arr) _ b)
unzipWith3 f arr = rewrite shapeEq arr
in case unzipWith3 f (getPrim arr) of
(a, b, c) => (MkArray (getOrder arr) (strides arr) _ a,
MkArray (getOrder arr) (strides arr) _ b,
MkArray (getOrder arr) (strides arr) _ c)
export
{s : _} -> Applicative (Array s) where
pure = constant s
(<*>) = zipWith apply
export
{s : _} -> Monad (Array s) where
join arr = fromFunction s (\is => index is $ index is arr)
-- Foldable and Traversable operate on the primitive array directly. This means
-- that their operation is dependent on the order of the array.
export
Foldable (Array s) where
foldl f z = foldl f z . getPrim
foldr f z = foldr f z . getPrim
null arr = size arr == Z
toList = toList . getPrim
export
Traversable (Array s) where
traverse f (MkArray ord sts s arr) =
map (MkArray ord sts s) (traverse f arr)
export
Eq a => Eq (Array s a) where
a == b = if getOrder a == getOrder b
then unsafeEq (getPrim a) (getPrim b)
else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b)
export
Semigroup a => Semigroup (Array s a) where
(<+>) = zipWith (<+>)
export
{s : _} -> Monoid a => Monoid (Array s a) where
neutral = constant s neutral
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
-- were moved into its own interface, this constraint could be removed.
{s : _} -> Num a => Num (Array s a) where
(+) = zipWith (+)
(*) = zipWith (*)
fromInteger = constant s . fromInteger

View file

@ -42,7 +42,6 @@ namespace Coords
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
export
index : Coords s -> Vects s a -> a index : Coords s -> Vects s a -> a
index [] x = x index [] x = x
index (i::is) v = index is $ index i v index (i::is) v = index is $ index i v

View file

@ -21,6 +21,15 @@ data Order : Type where
FOrder : Order FOrder : Order
public export
Eq Order where
COrder == COrder = True
FOrder == FOrder = True
COrder == FOrder = False
FOrder == COrder = False
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
scanr _ q0 [] = [q0] scanr _ q0 [] = [q0]
scanr f q0 (x::xs) = f x (head qs) :: qs scanr f q0 (x::xs) = f x (head qs) :: qs

View file

@ -1,5 +1,7 @@
module Data.NumIdr.PrimArray module Data.NumIdr.PrimArray
import Data.Nat
import Data.IORef
import Data.IOArray.Prims import Data.IOArray.Prims
%default total %default total
@ -92,3 +94,56 @@ export
map : (a -> b) -> PrimArray a -> PrimArray b map : (a -> b) -> PrimArray a -> PrimArray b
map f arr = create (length arr) (\n => f $ index n arr) map f arr = create (length arr) (\n => f $ index n arr)
export
unsafeZipWith : (a -> b -> c) -> PrimArray a -> PrimArray b -> PrimArray c
unsafeZipWith f a b = create (length a) (\n => f (index n a) (index n b))
export
unsafeZipWith3 : (a -> b -> c -> d) ->
PrimArray a -> PrimArray b -> PrimArray c -> PrimArray d
unsafeZipWith3 f a b c = create (length a) (\n => f (index n a) (index n b) (index n c))
export
unzipWith : (a -> (b, c)) -> PrimArray a -> (PrimArray b, PrimArray c)
unzipWith f arr = (map (fst . f) arr, map (snd . f) arr)
export
unzipWith3 : (a -> (b, c, d)) -> PrimArray a -> (PrimArray b, PrimArray c, PrimArray d)
unzipWith3 f arr = (map ((\(x,_,_) => x) . f) arr,
map ((\(_,y,_) => y) . f) arr,
map ((\(_,_,z) => z) . f) arr)
export
foldl : (b -> a -> b) -> b -> PrimArray a -> b
foldl f z (MkPrimArray size arr) = unsafePerformIO $ do
ref <- newIORef z
for_ [0..pred size] $ \n => do
x <- readIORef ref
y <- arrayDataGet n arr
writeIORef ref (f x y)
readIORef ref
export
foldr : (a -> b -> b) -> b -> PrimArray a -> b
foldr f z (MkPrimArray size arr) = unsafePerformIO $ do
ref <- newIORef z
for_ [pred size..0] $ \n => do
x <- arrayDataGet n arr
y <- readIORef ref
writeIORef ref (f x y)
readIORef ref
export
traverse : Applicative f => (a -> f b) -> PrimArray a -> f (PrimArray b)
traverse f = map fromList . traverse f . toList
||| Compares two primitive arrays for equal elements. This function assumes
||| the arrays same the same length; it must not be used in any other case.
export
unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool
unsafeEq a b = unsafePerformIO $
map (concat @{All}) $ for [0..pred (arraySize a)] $
\n => (==) <$> arrayDataGet n (content a) <*> arrayDataGet n (content b)