Add implementations for Array
This commit is contained in:
parent
5914652b7c
commit
ff48a18478
|
@ -2,6 +2,7 @@ module Data.NumIdr.Array.Array
|
|||
|
||||
import Data.List1
|
||||
import Data.Vect
|
||||
import Data.Zippable
|
||||
import Data.NumIdr.PrimArray
|
||||
import Data.NumIdr.Array.Order
|
||||
import Data.NumIdr.Array.Coords
|
||||
|
@ -95,6 +96,12 @@ getAllCoords (Z :: s) = []
|
|||
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||
|
||||
|
||||
constant' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a
|
||||
constant' s ord x = MkArray ord (calcStrides ord s) s (create (product s) (const x))
|
||||
|
||||
constant : (s : Vect rk Nat) -> a -> Array s a
|
||||
constant s = constant' s COrder
|
||||
|
||||
||| Create an array given a vector of its elements. The elements of the vector
|
||||
||| are arranged into the provided shape using the provided order.
|
||||
|||
|
||||
|
@ -150,7 +157,7 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
|
|||
||| Construct an array using a structure of nested vectors.
|
||||
||| To explicitly specify the shape and order of the array, use `array'`.
|
||||
export
|
||||
array : {s : _} -> Vects s a -> Array s a
|
||||
array : {s : Vect rk Nat} -> Vects s a -> Array s a
|
||||
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
||||
|
||||
|
||||
|
@ -170,7 +177,10 @@ indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
|||
indexRange rs arr = let ord = getOrder arr
|
||||
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
|
||||
sts = calcStrides ord s
|
||||
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
||||
-- TODO: Make this actually typecheck without resorting to this mess
|
||||
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $
|
||||
map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||
getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
@ -197,6 +207,15 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
|||
reshape s' arr = reshape' s' (getOrder arr) arr
|
||||
|
||||
|
||||
||| Change the internal order of the array's elements.
|
||||
export
|
||||
reorder : Order -> Array s a -> Array s a
|
||||
reorder ord' arr = let s = shape arr
|
||||
sts = calcStrides ord' s
|
||||
in rewrite shapeEq arr
|
||||
in MkArray ord' sts _ (unsafeFromIns (product s) $
|
||||
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||
getAllCoords' s)
|
||||
|
||||
|
||||
export
|
||||
|
@ -221,9 +240,101 @@ stack axis a b = let sA = shape a
|
|||
sts = calcStrides COrder s
|
||||
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
|
||||
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
|
||||
-- TODO: prove that the type-level shape and `s` are equivalent
|
||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||
|
||||
|
||||
|
||||
display : Show a => Array s a -> IO ()
|
||||
display = printLn . PrimArray.toList . getPrim
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Implementations
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
-- Most of these implementations apply the operation pointwise. If there are
|
||||
-- multiple arrays involved with different orders, then all of the arrays are
|
||||
-- reordered to match one.
|
||||
|
||||
|
||||
export
|
||||
Functor (Array s) where
|
||||
map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr)
|
||||
|
||||
export
|
||||
Zippable (Array s) where
|
||||
zipWith f a b = rewrite shapeEq a
|
||||
in MkArray (getOrder a) (strides a) _ $
|
||||
if getOrder a == getOrder b
|
||||
then unsafeZipWith f (getPrim a) (getPrim b)
|
||||
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||
|
||||
zipWith3 f a b c = rewrite shapeEq a
|
||||
in MkArray (getOrder a) (strides a) _ $
|
||||
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
|
||||
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
|
||||
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||
(getPrim $ reorder (getOrder a) c)
|
||||
|
||||
unzipWith f arr = rewrite shapeEq arr
|
||||
in case unzipWith f (getPrim arr) of
|
||||
(a, b) => (MkArray (getOrder arr) (strides arr) _ a,
|
||||
MkArray (getOrder arr) (strides arr) _ b)
|
||||
|
||||
unzipWith3 f arr = rewrite shapeEq arr
|
||||
in case unzipWith3 f (getPrim arr) of
|
||||
(a, b, c) => (MkArray (getOrder arr) (strides arr) _ a,
|
||||
MkArray (getOrder arr) (strides arr) _ b,
|
||||
MkArray (getOrder arr) (strides arr) _ c)
|
||||
|
||||
export
|
||||
{s : _} -> Applicative (Array s) where
|
||||
pure = constant s
|
||||
(<*>) = zipWith apply
|
||||
|
||||
export
|
||||
{s : _} -> Monad (Array s) where
|
||||
join arr = fromFunction s (\is => index is $ index is arr)
|
||||
|
||||
|
||||
-- Foldable and Traversable operate on the primitive array directly. This means
|
||||
-- that their operation is dependent on the order of the array.
|
||||
|
||||
export
|
||||
Foldable (Array s) where
|
||||
foldl f z = foldl f z . getPrim
|
||||
foldr f z = foldr f z . getPrim
|
||||
null arr = size arr == Z
|
||||
toList = toList . getPrim
|
||||
|
||||
export
|
||||
Traversable (Array s) where
|
||||
traverse f (MkArray ord sts s arr) =
|
||||
map (MkArray ord sts s) (traverse f arr)
|
||||
|
||||
|
||||
export
|
||||
Eq a => Eq (Array s a) where
|
||||
a == b = if getOrder a == getOrder b
|
||||
then unsafeEq (getPrim a) (getPrim b)
|
||||
else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||
|
||||
export
|
||||
Semigroup a => Semigroup (Array s a) where
|
||||
(<+>) = zipWith (<+>)
|
||||
|
||||
export
|
||||
{s : _} -> Monoid a => Monoid (Array s a) where
|
||||
neutral = constant s neutral
|
||||
|
||||
|
||||
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
|
||||
-- were moved into its own interface, this constraint could be removed.
|
||||
{s : _} -> Num a => Num (Array s a) where
|
||||
(+) = zipWith (+)
|
||||
(*) = zipWith (*)
|
||||
|
||||
fromInteger = constant s . fromInteger
|
||||
|
||||
|
|
|
@ -42,7 +42,6 @@ namespace Coords
|
|||
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
||||
|
||||
|
||||
export
|
||||
index : Coords s -> Vects s a -> a
|
||||
index [] x = x
|
||||
index (i::is) v = index is $ index i v
|
||||
|
|
|
@ -21,6 +21,15 @@ data Order : Type where
|
|||
FOrder : Order
|
||||
|
||||
|
||||
public export
|
||||
Eq Order where
|
||||
COrder == COrder = True
|
||||
FOrder == FOrder = True
|
||||
COrder == FOrder = False
|
||||
FOrder == COrder = False
|
||||
|
||||
|
||||
|
||||
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
||||
scanr _ q0 [] = [q0]
|
||||
scanr f q0 (x::xs) = f x (head qs) :: qs
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
module Data.NumIdr.PrimArray
|
||||
|
||||
import Data.Nat
|
||||
import Data.IORef
|
||||
import Data.IOArray.Prims
|
||||
|
||||
%default total
|
||||
|
@ -92,3 +94,56 @@ export
|
|||
map : (a -> b) -> PrimArray a -> PrimArray b
|
||||
map f arr = create (length arr) (\n => f $ index n arr)
|
||||
|
||||
|
||||
export
|
||||
unsafeZipWith : (a -> b -> c) -> PrimArray a -> PrimArray b -> PrimArray c
|
||||
unsafeZipWith f a b = create (length a) (\n => f (index n a) (index n b))
|
||||
|
||||
export
|
||||
unsafeZipWith3 : (a -> b -> c -> d) ->
|
||||
PrimArray a -> PrimArray b -> PrimArray c -> PrimArray d
|
||||
unsafeZipWith3 f a b c = create (length a) (\n => f (index n a) (index n b) (index n c))
|
||||
|
||||
export
|
||||
unzipWith : (a -> (b, c)) -> PrimArray a -> (PrimArray b, PrimArray c)
|
||||
unzipWith f arr = (map (fst . f) arr, map (snd . f) arr)
|
||||
|
||||
export
|
||||
unzipWith3 : (a -> (b, c, d)) -> PrimArray a -> (PrimArray b, PrimArray c, PrimArray d)
|
||||
unzipWith3 f arr = (map ((\(x,_,_) => x) . f) arr,
|
||||
map ((\(_,y,_) => y) . f) arr,
|
||||
map ((\(_,_,z) => z) . f) arr)
|
||||
|
||||
|
||||
export
|
||||
foldl : (b -> a -> b) -> b -> PrimArray a -> b
|
||||
foldl f z (MkPrimArray size arr) = unsafePerformIO $ do
|
||||
ref <- newIORef z
|
||||
for_ [0..pred size] $ \n => do
|
||||
x <- readIORef ref
|
||||
y <- arrayDataGet n arr
|
||||
writeIORef ref (f x y)
|
||||
readIORef ref
|
||||
|
||||
export
|
||||
foldr : (a -> b -> b) -> b -> PrimArray a -> b
|
||||
foldr f z (MkPrimArray size arr) = unsafePerformIO $ do
|
||||
ref <- newIORef z
|
||||
for_ [pred size..0] $ \n => do
|
||||
x <- arrayDataGet n arr
|
||||
y <- readIORef ref
|
||||
writeIORef ref (f x y)
|
||||
readIORef ref
|
||||
|
||||
export
|
||||
traverse : Applicative f => (a -> f b) -> PrimArray a -> f (PrimArray b)
|
||||
traverse f = map fromList . traverse f . toList
|
||||
|
||||
|
||||
||| Compares two primitive arrays for equal elements. This function assumes
|
||||
||| the arrays same the same length; it must not be used in any other case.
|
||||
export
|
||||
unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool
|
||||
unsafeEq a b = unsafePerformIO $
|
||||
map (concat @{All}) $ for [0..pred (arraySize a)] $
|
||||
\n => (==) <$> arrayDataGet n (content a) <*> arrayDataGet n (content b)
|
||||
|
|
Loading…
Reference in a new issue