Add implementations for Array
This commit is contained in:
parent
5914652b7c
commit
ff48a18478
|
@ -2,6 +2,7 @@ module Data.NumIdr.Array.Array
|
||||||
|
|
||||||
import Data.List1
|
import Data.List1
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
|
import Data.Zippable
|
||||||
import Data.NumIdr.PrimArray
|
import Data.NumIdr.PrimArray
|
||||||
import Data.NumIdr.Array.Order
|
import Data.NumIdr.Array.Order
|
||||||
import Data.NumIdr.Array.Coords
|
import Data.NumIdr.Array.Coords
|
||||||
|
@ -95,6 +96,12 @@ getAllCoords (Z :: s) = []
|
||||||
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||||
|
|
||||||
|
|
||||||
|
constant' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a
|
||||||
|
constant' s ord x = MkArray ord (calcStrides ord s) s (create (product s) (const x))
|
||||||
|
|
||||||
|
constant : (s : Vect rk Nat) -> a -> Array s a
|
||||||
|
constant s = constant' s COrder
|
||||||
|
|
||||||
||| Create an array given a vector of its elements. The elements of the vector
|
||| Create an array given a vector of its elements. The elements of the vector
|
||||||
||| are arranged into the provided shape using the provided order.
|
||| are arranged into the provided shape using the provided order.
|
||||||
|||
|
|||
|
||||||
|
@ -150,7 +157,7 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
|
||||||
||| Construct an array using a structure of nested vectors.
|
||| Construct an array using a structure of nested vectors.
|
||||||
||| To explicitly specify the shape and order of the array, use `array'`.
|
||| To explicitly specify the shape and order of the array, use `array'`.
|
||||||
export
|
export
|
||||||
array : {s : _} -> Vects s a -> Array s a
|
array : {s : Vect rk Nat} -> Vects s a -> Array s a
|
||||||
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
||||||
|
|
||||||
|
|
||||||
|
@ -170,7 +177,10 @@ indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
||||||
indexRange rs arr = let ord = getOrder arr
|
indexRange rs arr = let ord = getOrder arr
|
||||||
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
|
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
|
||||||
sts = calcStrides ord s
|
sts = calcStrides ord s
|
||||||
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
-- TODO: Make this actually typecheck without resorting to this mess
|
||||||
|
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $
|
||||||
|
map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||||
|
getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -197,6 +207,15 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
||||||
reshape s' arr = reshape' s' (getOrder arr) arr
|
reshape s' arr = reshape' s' (getOrder arr) arr
|
||||||
|
|
||||||
|
|
||||||
|
||| Change the internal order of the array's elements.
|
||||||
|
export
|
||||||
|
reorder : Order -> Array s a -> Array s a
|
||||||
|
reorder ord' arr = let s = shape arr
|
||||||
|
sts = calcStrides ord' s
|
||||||
|
in rewrite shapeEq arr
|
||||||
|
in MkArray ord' sts _ (unsafeFromIns (product s) $
|
||||||
|
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||||
|
getAllCoords' s)
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
|
@ -221,9 +240,101 @@ stack axis a b = let sA = shape a
|
||||||
sts = calcStrides COrder s
|
sts = calcStrides COrder s
|
||||||
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
|
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
|
||||||
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
|
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
|
||||||
|
-- TODO: prove that the type-level shape and `s` are equivalent
|
||||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
display : Show a => Array s a -> IO ()
|
display : Show a => Array s a -> IO ()
|
||||||
display = printLn . PrimArray.toList . getPrim
|
display = printLn . PrimArray.toList . getPrim
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Implementations
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
-- Most of these implementations apply the operation pointwise. If there are
|
||||||
|
-- multiple arrays involved with different orders, then all of the arrays are
|
||||||
|
-- reordered to match one.
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
Functor (Array s) where
|
||||||
|
map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
Zippable (Array s) where
|
||||||
|
zipWith f a b = rewrite shapeEq a
|
||||||
|
in MkArray (getOrder a) (strides a) _ $
|
||||||
|
if getOrder a == getOrder b
|
||||||
|
then unsafeZipWith f (getPrim a) (getPrim b)
|
||||||
|
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||||
|
|
||||||
|
zipWith3 f a b c = rewrite shapeEq a
|
||||||
|
in MkArray (getOrder a) (strides a) _ $
|
||||||
|
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
|
||||||
|
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
|
||||||
|
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||||
|
(getPrim $ reorder (getOrder a) c)
|
||||||
|
|
||||||
|
unzipWith f arr = rewrite shapeEq arr
|
||||||
|
in case unzipWith f (getPrim arr) of
|
||||||
|
(a, b) => (MkArray (getOrder arr) (strides arr) _ a,
|
||||||
|
MkArray (getOrder arr) (strides arr) _ b)
|
||||||
|
|
||||||
|
unzipWith3 f arr = rewrite shapeEq arr
|
||||||
|
in case unzipWith3 f (getPrim arr) of
|
||||||
|
(a, b, c) => (MkArray (getOrder arr) (strides arr) _ a,
|
||||||
|
MkArray (getOrder arr) (strides arr) _ b,
|
||||||
|
MkArray (getOrder arr) (strides arr) _ c)
|
||||||
|
|
||||||
|
export
|
||||||
|
{s : _} -> Applicative (Array s) where
|
||||||
|
pure = constant s
|
||||||
|
(<*>) = zipWith apply
|
||||||
|
|
||||||
|
export
|
||||||
|
{s : _} -> Monad (Array s) where
|
||||||
|
join arr = fromFunction s (\is => index is $ index is arr)
|
||||||
|
|
||||||
|
|
||||||
|
-- Foldable and Traversable operate on the primitive array directly. This means
|
||||||
|
-- that their operation is dependent on the order of the array.
|
||||||
|
|
||||||
|
export
|
||||||
|
Foldable (Array s) where
|
||||||
|
foldl f z = foldl f z . getPrim
|
||||||
|
foldr f z = foldr f z . getPrim
|
||||||
|
null arr = size arr == Z
|
||||||
|
toList = toList . getPrim
|
||||||
|
|
||||||
|
export
|
||||||
|
Traversable (Array s) where
|
||||||
|
traverse f (MkArray ord sts s arr) =
|
||||||
|
map (MkArray ord sts s) (traverse f arr)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
Eq a => Eq (Array s a) where
|
||||||
|
a == b = if getOrder a == getOrder b
|
||||||
|
then unsafeEq (getPrim a) (getPrim b)
|
||||||
|
else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||||
|
|
||||||
|
export
|
||||||
|
Semigroup a => Semigroup (Array s a) where
|
||||||
|
(<+>) = zipWith (<+>)
|
||||||
|
|
||||||
|
export
|
||||||
|
{s : _} -> Monoid a => Monoid (Array s a) where
|
||||||
|
neutral = constant s neutral
|
||||||
|
|
||||||
|
|
||||||
|
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
|
||||||
|
-- were moved into its own interface, this constraint could be removed.
|
||||||
|
{s : _} -> Num a => Num (Array s a) where
|
||||||
|
(+) = zipWith (+)
|
||||||
|
(*) = zipWith (*)
|
||||||
|
|
||||||
|
fromInteger = constant s . fromInteger
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,6 @@ namespace Coords
|
||||||
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
||||||
|
|
||||||
|
|
||||||
export
|
|
||||||
index : Coords s -> Vects s a -> a
|
index : Coords s -> Vects s a -> a
|
||||||
index [] x = x
|
index [] x = x
|
||||||
index (i::is) v = index is $ index i v
|
index (i::is) v = index is $ index i v
|
||||||
|
|
|
@ -21,6 +21,15 @@ data Order : Type where
|
||||||
FOrder : Order
|
FOrder : Order
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
Eq Order where
|
||||||
|
COrder == COrder = True
|
||||||
|
FOrder == FOrder = True
|
||||||
|
COrder == FOrder = False
|
||||||
|
FOrder == COrder = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
||||||
scanr _ q0 [] = [q0]
|
scanr _ q0 [] = [q0]
|
||||||
scanr f q0 (x::xs) = f x (head qs) :: qs
|
scanr f q0 (x::xs) = f x (head qs) :: qs
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
module Data.NumIdr.PrimArray
|
module Data.NumIdr.PrimArray
|
||||||
|
|
||||||
|
import Data.Nat
|
||||||
|
import Data.IORef
|
||||||
import Data.IOArray.Prims
|
import Data.IOArray.Prims
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
@ -92,3 +94,56 @@ export
|
||||||
map : (a -> b) -> PrimArray a -> PrimArray b
|
map : (a -> b) -> PrimArray a -> PrimArray b
|
||||||
map f arr = create (length arr) (\n => f $ index n arr)
|
map f arr = create (length arr) (\n => f $ index n arr)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
unsafeZipWith : (a -> b -> c) -> PrimArray a -> PrimArray b -> PrimArray c
|
||||||
|
unsafeZipWith f a b = create (length a) (\n => f (index n a) (index n b))
|
||||||
|
|
||||||
|
export
|
||||||
|
unsafeZipWith3 : (a -> b -> c -> d) ->
|
||||||
|
PrimArray a -> PrimArray b -> PrimArray c -> PrimArray d
|
||||||
|
unsafeZipWith3 f a b c = create (length a) (\n => f (index n a) (index n b) (index n c))
|
||||||
|
|
||||||
|
export
|
||||||
|
unzipWith : (a -> (b, c)) -> PrimArray a -> (PrimArray b, PrimArray c)
|
||||||
|
unzipWith f arr = (map (fst . f) arr, map (snd . f) arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
unzipWith3 : (a -> (b, c, d)) -> PrimArray a -> (PrimArray b, PrimArray c, PrimArray d)
|
||||||
|
unzipWith3 f arr = (map ((\(x,_,_) => x) . f) arr,
|
||||||
|
map ((\(_,y,_) => y) . f) arr,
|
||||||
|
map ((\(_,_,z) => z) . f) arr)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
foldl : (b -> a -> b) -> b -> PrimArray a -> b
|
||||||
|
foldl f z (MkPrimArray size arr) = unsafePerformIO $ do
|
||||||
|
ref <- newIORef z
|
||||||
|
for_ [0..pred size] $ \n => do
|
||||||
|
x <- readIORef ref
|
||||||
|
y <- arrayDataGet n arr
|
||||||
|
writeIORef ref (f x y)
|
||||||
|
readIORef ref
|
||||||
|
|
||||||
|
export
|
||||||
|
foldr : (a -> b -> b) -> b -> PrimArray a -> b
|
||||||
|
foldr f z (MkPrimArray size arr) = unsafePerformIO $ do
|
||||||
|
ref <- newIORef z
|
||||||
|
for_ [pred size..0] $ \n => do
|
||||||
|
x <- arrayDataGet n arr
|
||||||
|
y <- readIORef ref
|
||||||
|
writeIORef ref (f x y)
|
||||||
|
readIORef ref
|
||||||
|
|
||||||
|
export
|
||||||
|
traverse : Applicative f => (a -> f b) -> PrimArray a -> f (PrimArray b)
|
||||||
|
traverse f = map fromList . traverse f . toList
|
||||||
|
|
||||||
|
|
||||||
|
||| Compares two primitive arrays for equal elements. This function assumes
|
||||||
|
||| the arrays same the same length; it must not be used in any other case.
|
||||||
|
export
|
||||||
|
unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool
|
||||||
|
unsafeEq a b = unsafePerformIO $
|
||||||
|
map (concat @{All}) $ for [0..pred (arraySize a)] $
|
||||||
|
\n => (==) <$> arrayDataGet n (content a) <*> arrayDataGet n (content b)
|
||||||
|
|
Loading…
Reference in a new issue