Add range indexing support
This commit is contained in:
parent
76e16574f1
commit
5914652b7c
2 changed files with 197 additions and 45 deletions
|
|
@ -1,5 +1,6 @@
|
|||
module Data.NumIdr.Array.Array
|
||||
|
||||
import Data.List1
|
||||
import Data.Vect
|
||||
import Data.NumIdr.PrimArray
|
||||
import Data.NumIdr.Array.Order
|
||||
|
|
@ -38,6 +39,11 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
|
|||
(s : Vect rk Nat) -> PrimArray a -> Array s a
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Properties of arrays
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
||| Extract the primitive array value.
|
||||
export
|
||||
getPrim : Array s a -> PrimArray a
|
||||
|
|
@ -50,8 +56,8 @@ getOrder (MkArray ord _ _ _) = ord
|
|||
|
||||
||| The strides of the array, returned in the same axis order as in the shape.
|
||||
export
|
||||
getStrides : Array {rk} s a -> Vect rk Nat
|
||||
getStrides (MkArray _ sts _ _) = sts
|
||||
strides : Array {rk} s a -> Vect rk Nat
|
||||
strides (MkArray _ sts _ _) = sts
|
||||
|
||||
||| The total number of elements of the array
|
||||
||| This is equivalent to `product s`.
|
||||
|
|
@ -70,6 +76,25 @@ rank : Array s a -> Nat
|
|||
rank = length . shape
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Array constructors
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
shapeEq : (arr : Array s a) -> s = shape arr
|
||||
shapeEq (MkArray _ _ _ _) = Refl
|
||||
|
||||
|
||||
-- Get a list of all coordinates
|
||||
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
||||
|
||||
getAllCoords : (s : Vect rk Nat) -> List (Coords s)
|
||||
getAllCoords [] = pure []
|
||||
getAllCoords (Z :: s) = []
|
||||
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||
|
||||
|
||||
||| Create an array given a vector of its elements. The elements of the vector
|
||||
||| are arranged into the provided shape using the provided order.
|
||||
|||
|
||||
|
|
@ -82,12 +107,30 @@ fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
|
|||
||| Create an array given a vector of its elements. The elements of the vector
|
||||
||| are assembled into the provided shape using row-major order (the last axis is the
|
||||
||| least significant).
|
||||
||| To specify the order of the array, see `fromVect'`.
|
||||
|||
|
||||
||| @ s The shape of the constructed array
|
||||
export
|
||||
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
||||
fromVect s = fromVect' s COrder
|
||||
|
||||
||| Create an array given a function to generate its elements.
|
||||
|||
|
||||
||| @ s The shape of the constructed array
|
||||
||| @ ord The order to interpret the elements
|
||||
export
|
||||
fromFunction' : (s : Vect rk Nat) -> (ord : Order) -> (Coords s -> a) -> Array s a
|
||||
fromFunction' s ord f = let sts = calcStrides ord s
|
||||
in MkArray ord sts s (unsafeFromIns (product s) $
|
||||
map (\is => (getLocation sts is, f is)) $ getAllCoords s)
|
||||
|
||||
||| Create an array given a function to generate its elements.
|
||||
||| To specify the order of the array, use `fromFunction'`.
|
||||
|||
|
||||
||| @ s The shape of the constructed array
|
||||
export
|
||||
fromFunction : (s : Vect rk Nat) -> (Coords s -> a) -> Array s a
|
||||
fromFunction s = fromFunction' s COrder
|
||||
|
||||
||| Construct an array using a structure of nested vectors. The elements are arranged
|
||||
||| to the specified order before being written.
|
||||
|
|
@ -102,14 +145,39 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
|
|||
sts = calcStrides ord s
|
||||
|
||||
ins : List (Nat, a)
|
||||
ins = collapse $ mapWithIndex (\is,x => (getLocation' sts is, x)) v
|
||||
ins = collapse $ mapWithIndex (MkPair . getLocation' sts) v
|
||||
|
||||
||| Construct an array using a structure of nested vectors.
|
||||
||| To explicitly specify the shape and order of the array, use `array'`.
|
||||
export
|
||||
array : {s : _} -> Vects s a -> Array s a
|
||||
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Indexing
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
||| Index the array using the given `Coords` object.
|
||||
export
|
||||
index : Coords s -> Array s a -> a
|
||||
index is arr = index (getLocation (strides arr) is) (getPrim arr)
|
||||
|
||||
||| Index the array using the given `CoordsRange` object.
|
||||
export
|
||||
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
||||
indexRange rs arr = let ord = getOrder arr
|
||||
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
|
||||
sts = calcStrides ord s
|
||||
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Operations on arrays
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
||| Reshape the array into the given shape and reinterpret it according to
|
||||
||| the given order.
|
||||
|||
|
||||
|
|
@ -129,7 +197,33 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
|||
reshape s' arr = reshape' s' (getOrder arr) arr
|
||||
|
||||
|
||||
||| Index the array using the given `Coords` object.
|
||||
|
||||
|
||||
export
|
||||
index : Coords s -> Array s a -> a
|
||||
index is arr = index (getLocation (getStrides arr) is) (getPrim arr)
|
||||
enumerate' : Array {rk} s a -> List (Vect rk Nat, a)
|
||||
enumerate' (MkArray _ sts sh p) =
|
||||
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
|
||||
|
||||
|
||||
export
|
||||
enumerate : Array {rk} s a -> List (Coords s, a)
|
||||
enumerate arr = map (\is => (is, index is arr))
|
||||
(rewrite shapeEq arr in getAllCoords $ shape arr)
|
||||
|
||||
export
|
||||
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||
Array (replaceAt axis (index axis s + d) s) a
|
||||
stack axis a b = let sA = shape a
|
||||
sB = shape b
|
||||
dA = index axis sA
|
||||
dB = index axis sB
|
||||
s = replaceAt axis (dA + dB) sA
|
||||
sts = calcStrides COrder s
|
||||
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
|
||||
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
|
||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||
|
||||
|
||||
|
||||
display : Show a => Array s a -> IO ()
|
||||
display = printLn . PrimArray.toList . getPrim
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue