Add range indexing support

This commit is contained in:
Kiana Sheibani 2022-05-19 08:49:57 -04:00
parent 76e16574f1
commit 5914652b7c
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
2 changed files with 197 additions and 45 deletions

View file

@ -1,5 +1,6 @@
module Data.NumIdr.Array.Array module Data.NumIdr.Array.Array
import Data.List1
import Data.Vect import Data.Vect
import Data.NumIdr.PrimArray import Data.NumIdr.PrimArray
import Data.NumIdr.Array.Order import Data.NumIdr.Array.Order
@ -38,6 +39,11 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
(s : Vect rk Nat) -> PrimArray a -> Array s a (s : Vect rk Nat) -> PrimArray a -> Array s a
--------------------------------------------------------------------------------
-- Properties of arrays
--------------------------------------------------------------------------------
||| Extract the primitive array value. ||| Extract the primitive array value.
export export
getPrim : Array s a -> PrimArray a getPrim : Array s a -> PrimArray a
@ -50,8 +56,8 @@ getOrder (MkArray ord _ _ _) = ord
||| The strides of the array, returned in the same axis order as in the shape. ||| The strides of the array, returned in the same axis order as in the shape.
export export
getStrides : Array {rk} s a -> Vect rk Nat strides : Array {rk} s a -> Vect rk Nat
getStrides (MkArray _ sts _ _) = sts strides (MkArray _ sts _ _) = sts
||| The total number of elements of the array ||| The total number of elements of the array
||| This is equivalent to `product s`. ||| This is equivalent to `product s`.
@ -70,6 +76,25 @@ rank : Array s a -> Nat
rank = length . shape rank = length . shape
--------------------------------------------------------------------------------
-- Array constructors
--------------------------------------------------------------------------------
shapeEq : (arr : Array s a) -> s = shape arr
shapeEq (MkArray _ _ _ _) = Refl
-- Get a list of all coordinates
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
getAllCoords' = traverse (\case Z => []; S n => [0..n])
getAllCoords : (s : Vect rk Nat) -> List (Coords s)
getAllCoords [] = pure []
getAllCoords (Z :: s) = []
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
||| Create an array given a vector of its elements. The elements of the vector ||| Create an array given a vector of its elements. The elements of the vector
||| are arranged into the provided shape using the provided order. ||| are arranged into the provided shape using the provided order.
||| |||
@ -82,12 +107,30 @@ fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
||| Create an array given a vector of its elements. The elements of the vector ||| Create an array given a vector of its elements. The elements of the vector
||| are assembled into the provided shape using row-major order (the last axis is the ||| are assembled into the provided shape using row-major order (the last axis is the
||| least significant). ||| least significant).
||| To specify the order of the array, see `fromVect'`.
||| |||
||| @ s The shape of the constructed array ||| @ s The shape of the constructed array
export export
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
fromVect s = fromVect' s COrder fromVect s = fromVect' s COrder
||| Create an array given a function to generate its elements.
|||
||| @ s The shape of the constructed array
||| @ ord The order to interpret the elements
export
fromFunction' : (s : Vect rk Nat) -> (ord : Order) -> (Coords s -> a) -> Array s a
fromFunction' s ord f = let sts = calcStrides ord s
in MkArray ord sts s (unsafeFromIns (product s) $
map (\is => (getLocation sts is, f is)) $ getAllCoords s)
||| Create an array given a function to generate its elements.
||| To specify the order of the array, use `fromFunction'`.
|||
||| @ s The shape of the constructed array
export
fromFunction : (s : Vect rk Nat) -> (Coords s -> a) -> Array s a
fromFunction s = fromFunction' s COrder
||| Construct an array using a structure of nested vectors. The elements are arranged ||| Construct an array using a structure of nested vectors. The elements are arranged
||| to the specified order before being written. ||| to the specified order before being written.
@ -102,14 +145,39 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
sts = calcStrides ord s sts = calcStrides ord s
ins : List (Nat, a) ins : List (Nat, a)
ins = collapse $ mapWithIndex (\is,x => (getLocation' sts is, x)) v ins = collapse $ mapWithIndex (MkPair . getLocation' sts) v
||| Construct an array using a structure of nested vectors. ||| Construct an array using a structure of nested vectors.
||| To explicitly specify the shape and order of the array, use `array'`.
export export
array : {s : _} -> Vects s a -> Array s a array : {s : _} -> Vects s a -> Array s a
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v) array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
--------------------------------------------------------------------------------
-- Indexing
--------------------------------------------------------------------------------
||| Index the array using the given `Coords` object.
export
index : Coords s -> Array s a -> a
index is arr = index (getLocation (strides arr) is) (getPrim arr)
||| Index the array using the given `CoordsRange` object.
export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
indexRange rs arr = let ord = getOrder arr
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
sts = calcStrides ord s
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
--------------------------------------------------------------------------------
-- Operations on arrays
--------------------------------------------------------------------------------
||| Reshape the array into the given shape and reinterpret it according to ||| Reshape the array into the given shape and reinterpret it according to
||| the given order. ||| the given order.
||| |||
@ -129,7 +197,33 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
reshape s' arr = reshape' s' (getOrder arr) arr reshape s' arr = reshape' s' (getOrder arr) arr
||| Index the array using the given `Coords` object.
export export
index : Coords s -> Array s a -> a enumerate' : Array {rk} s a -> List (Vect rk Nat, a)
index is arr = index (getLocation (getStrides arr) is) (getPrim arr) enumerate' (MkArray _ sts sh p) =
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
export
enumerate : Array {rk} s a -> List (Coords s, a)
enumerate arr = map (\is => (is, index is arr))
(rewrite shapeEq arr in getAllCoords $ shape arr)
export
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
Array (replaceAt axis (index axis s + d) s) a
stack axis a b = let sA = shape a
sB = shape b
dA = index axis sA
dB = index axis sB
s = replaceAt axis (dA + dB) sA
sts = calcStrides COrder s
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
display : Show a => Array s a -> IO ()
display = printLn . PrimArray.toList . getPrim

View file

@ -5,6 +5,8 @@ import Data.Vect
%default total %default total
namespace Coords
||| A type-safe coordinate system for an array. The coordinates are ||| A type-safe coordinate system for an array. The coordinates are
||| values of `Fin dim`, where `dim` is the dimension of each axis. ||| values of `Fin dim`, where `dim` is the dimension of each axis.
public export public export
@ -55,3 +57,59 @@ getLocation' = sum .: zipWith (*)
export export
getLocation : Vect rk Nat -> Coords {rk} s -> Nat getLocation : Vect rk Nat -> Coords {rk} s -> Nat
getLocation sts is = getLocation' sts (toNats is) getLocation sts is = getLocation' sts (toNats is)
namespace CoordsRange
public export
data CRange : Nat -> Type where
One : Fin n -> CRange n
All : CRange n
StartBound : Fin (S n) -> CRange n
EndBound : Fin (S n) -> CRange n
Bounds : Fin (S n) -> Fin (S n) -> CRange n
public export
data CoordsRange : Vect rk Nat -> Type where
Nil : CoordsRange []
(::) : CRange d -> CoordsRange s -> CoordsRange (d :: s)
cRangeToBounds : {n : Nat} -> CRange n -> Either Nat (Nat, Nat)
cRangeToBounds (One x) = Left (cast x)
cRangeToBounds All = Right (0, n)
cRangeToBounds (StartBound x) = Right (cast x, n)
cRangeToBounds (EndBound x) = Right (0, cast x)
cRangeToBounds (Bounds x y) = Right (cast x, cast y)
newRank : {s : _} -> CoordsRange s -> Nat
newRank [] = 0
newRank (r :: rs) = case cRangeToBounds r of
Left _ => newRank rs
Right _ => S (newRank rs)
||| Calculate the new shape given by a coordinate range.
export
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
newShape [] = []
newShape (r :: rs) with (cRangeToBounds r)
newShape (r :: rs) | Left _ = newShape rs
newShape (r :: rs) | Right (x,y) = minus y x :: newShape rs
getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat
getNewPos [] [] = []
getNewPos (r :: rs) (i :: is) with (cRangeToBounds r)
getNewPos (r :: rs) (i :: is) | Left _ = getNewPos rs is
getNewPos (r :: rs) (i :: is) | Right (x, _) = minus i x :: getNewPos rs is
export
getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat)
getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs
where
go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
go [] = pure []
go (r :: rs) = [| (case cRangeToBounds r of
Left x => pure x
Right (x,y) => [x..pred y]) :: go rs |]