Add range indexing support
This commit is contained in:
parent
76e16574f1
commit
5914652b7c
|
@ -1,5 +1,6 @@
|
||||||
module Data.NumIdr.Array.Array
|
module Data.NumIdr.Array.Array
|
||||||
|
|
||||||
|
import Data.List1
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.NumIdr.PrimArray
|
import Data.NumIdr.PrimArray
|
||||||
import Data.NumIdr.Array.Order
|
import Data.NumIdr.Array.Order
|
||||||
|
@ -38,6 +39,11 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
|
||||||
(s : Vect rk Nat) -> PrimArray a -> Array s a
|
(s : Vect rk Nat) -> PrimArray a -> Array s a
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Properties of arrays
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
||| Extract the primitive array value.
|
||| Extract the primitive array value.
|
||||||
export
|
export
|
||||||
getPrim : Array s a -> PrimArray a
|
getPrim : Array s a -> PrimArray a
|
||||||
|
@ -50,8 +56,8 @@ getOrder (MkArray ord _ _ _) = ord
|
||||||
|
|
||||||
||| The strides of the array, returned in the same axis order as in the shape.
|
||| The strides of the array, returned in the same axis order as in the shape.
|
||||||
export
|
export
|
||||||
getStrides : Array {rk} s a -> Vect rk Nat
|
strides : Array {rk} s a -> Vect rk Nat
|
||||||
getStrides (MkArray _ sts _ _) = sts
|
strides (MkArray _ sts _ _) = sts
|
||||||
|
|
||||||
||| The total number of elements of the array
|
||| The total number of elements of the array
|
||||||
||| This is equivalent to `product s`.
|
||| This is equivalent to `product s`.
|
||||||
|
@ -70,6 +76,25 @@ rank : Array s a -> Nat
|
||||||
rank = length . shape
|
rank = length . shape
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Array constructors
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
shapeEq : (arr : Array s a) -> s = shape arr
|
||||||
|
shapeEq (MkArray _ _ _ _) = Refl
|
||||||
|
|
||||||
|
|
||||||
|
-- Get a list of all coordinates
|
||||||
|
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||||
|
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
||||||
|
|
||||||
|
getAllCoords : (s : Vect rk Nat) -> List (Coords s)
|
||||||
|
getAllCoords [] = pure []
|
||||||
|
getAllCoords (Z :: s) = []
|
||||||
|
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||||
|
|
||||||
|
|
||||||
||| Create an array given a vector of its elements. The elements of the vector
|
||| Create an array given a vector of its elements. The elements of the vector
|
||||||
||| are arranged into the provided shape using the provided order.
|
||| are arranged into the provided shape using the provided order.
|
||||||
|||
|
|||
|
||||||
|
@ -82,12 +107,30 @@ fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
|
||||||
||| Create an array given a vector of its elements. The elements of the vector
|
||| Create an array given a vector of its elements. The elements of the vector
|
||||||
||| are assembled into the provided shape using row-major order (the last axis is the
|
||| are assembled into the provided shape using row-major order (the last axis is the
|
||||||
||| least significant).
|
||| least significant).
|
||||||
|
||| To specify the order of the array, see `fromVect'`.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
export
|
export
|
||||||
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
||||||
fromVect s = fromVect' s COrder
|
fromVect s = fromVect' s COrder
|
||||||
|
|
||||||
|
||| Create an array given a function to generate its elements.
|
||||||
|
|||
|
||||||
|
||| @ s The shape of the constructed array
|
||||||
|
||| @ ord The order to interpret the elements
|
||||||
|
export
|
||||||
|
fromFunction' : (s : Vect rk Nat) -> (ord : Order) -> (Coords s -> a) -> Array s a
|
||||||
|
fromFunction' s ord f = let sts = calcStrides ord s
|
||||||
|
in MkArray ord sts s (unsafeFromIns (product s) $
|
||||||
|
map (\is => (getLocation sts is, f is)) $ getAllCoords s)
|
||||||
|
|
||||||
|
||| Create an array given a function to generate its elements.
|
||||||
|
||| To specify the order of the array, use `fromFunction'`.
|
||||||
|
|||
|
||||||
|
||| @ s The shape of the constructed array
|
||||||
|
export
|
||||||
|
fromFunction : (s : Vect rk Nat) -> (Coords s -> a) -> Array s a
|
||||||
|
fromFunction s = fromFunction' s COrder
|
||||||
|
|
||||||
||| Construct an array using a structure of nested vectors. The elements are arranged
|
||| Construct an array using a structure of nested vectors. The elements are arranged
|
||||||
||| to the specified order before being written.
|
||| to the specified order before being written.
|
||||||
|
@ -102,14 +145,39 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
|
||||||
sts = calcStrides ord s
|
sts = calcStrides ord s
|
||||||
|
|
||||||
ins : List (Nat, a)
|
ins : List (Nat, a)
|
||||||
ins = collapse $ mapWithIndex (\is,x => (getLocation' sts is, x)) v
|
ins = collapse $ mapWithIndex (MkPair . getLocation' sts) v
|
||||||
|
|
||||||
||| Construct an array using a structure of nested vectors.
|
||| Construct an array using a structure of nested vectors.
|
||||||
|
||| To explicitly specify the shape and order of the array, use `array'`.
|
||||||
export
|
export
|
||||||
array : {s : _} -> Vects s a -> Array s a
|
array : {s : _} -> Vects s a -> Array s a
|
||||||
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Indexing
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
||| Index the array using the given `Coords` object.
|
||||||
|
export
|
||||||
|
index : Coords s -> Array s a -> a
|
||||||
|
index is arr = index (getLocation (strides arr) is) (getPrim arr)
|
||||||
|
|
||||||
|
||| Index the array using the given `CoordsRange` object.
|
||||||
|
export
|
||||||
|
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
||||||
|
indexRange rs arr = let ord = getOrder arr
|
||||||
|
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
|
||||||
|
sts = calcStrides ord s
|
||||||
|
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Operations on arrays
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
||| Reshape the array into the given shape and reinterpret it according to
|
||| Reshape the array into the given shape and reinterpret it according to
|
||||||
||| the given order.
|
||| the given order.
|
||||||
|||
|
|||
|
||||||
|
@ -129,7 +197,33 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
||||||
reshape s' arr = reshape' s' (getOrder arr) arr
|
reshape s' arr = reshape' s' (getOrder arr) arr
|
||||||
|
|
||||||
|
|
||||||
||| Index the array using the given `Coords` object.
|
|
||||||
|
|
||||||
export
|
export
|
||||||
index : Coords s -> Array s a -> a
|
enumerate' : Array {rk} s a -> List (Vect rk Nat, a)
|
||||||
index is arr = index (getLocation (getStrides arr) is) (getPrim arr)
|
enumerate' (MkArray _ sts sh p) =
|
||||||
|
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
enumerate : Array {rk} s a -> List (Coords s, a)
|
||||||
|
enumerate arr = map (\is => (is, index is arr))
|
||||||
|
(rewrite shapeEq arr in getAllCoords $ shape arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||||
|
Array (replaceAt axis (index axis s + d) s) a
|
||||||
|
stack axis a b = let sA = shape a
|
||||||
|
sB = shape b
|
||||||
|
dA = index axis sA
|
||||||
|
dB = index axis sB
|
||||||
|
s = replaceAt axis (dA + dB) sA
|
||||||
|
sts = calcStrides COrder s
|
||||||
|
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
|
||||||
|
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
|
||||||
|
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
display : Show a => Array s a -> IO ()
|
||||||
|
display = printLn . PrimArray.toList . getPrim
|
||||||
|
|
|
@ -5,6 +5,8 @@ import Data.Vect
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
|
namespace Coords
|
||||||
|
|
||||||
||| A type-safe coordinate system for an array. The coordinates are
|
||| A type-safe coordinate system for an array. The coordinates are
|
||||||
||| values of `Fin dim`, where `dim` is the dimension of each axis.
|
||| values of `Fin dim`, where `dim` is the dimension of each axis.
|
||||||
public export
|
public export
|
||||||
|
@ -55,3 +57,59 @@ getLocation' = sum .: zipWith (*)
|
||||||
export
|
export
|
||||||
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
|
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
|
||||||
getLocation sts is = getLocation' sts (toNats is)
|
getLocation sts is = getLocation' sts (toNats is)
|
||||||
|
|
||||||
|
|
||||||
|
namespace CoordsRange
|
||||||
|
|
||||||
|
public export
|
||||||
|
data CRange : Nat -> Type where
|
||||||
|
One : Fin n -> CRange n
|
||||||
|
All : CRange n
|
||||||
|
StartBound : Fin (S n) -> CRange n
|
||||||
|
EndBound : Fin (S n) -> CRange n
|
||||||
|
Bounds : Fin (S n) -> Fin (S n) -> CRange n
|
||||||
|
|
||||||
|
public export
|
||||||
|
data CoordsRange : Vect rk Nat -> Type where
|
||||||
|
Nil : CoordsRange []
|
||||||
|
(::) : CRange d -> CoordsRange s -> CoordsRange (d :: s)
|
||||||
|
|
||||||
|
|
||||||
|
cRangeToBounds : {n : Nat} -> CRange n -> Either Nat (Nat, Nat)
|
||||||
|
cRangeToBounds (One x) = Left (cast x)
|
||||||
|
cRangeToBounds All = Right (0, n)
|
||||||
|
cRangeToBounds (StartBound x) = Right (cast x, n)
|
||||||
|
cRangeToBounds (EndBound x) = Right (0, cast x)
|
||||||
|
cRangeToBounds (Bounds x y) = Right (cast x, cast y)
|
||||||
|
|
||||||
|
|
||||||
|
newRank : {s : _} -> CoordsRange s -> Nat
|
||||||
|
newRank [] = 0
|
||||||
|
newRank (r :: rs) = case cRangeToBounds r of
|
||||||
|
Left _ => newRank rs
|
||||||
|
Right _ => S (newRank rs)
|
||||||
|
|
||||||
|
||| Calculate the new shape given by a coordinate range.
|
||||||
|
export
|
||||||
|
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
|
||||||
|
newShape [] = []
|
||||||
|
newShape (r :: rs) with (cRangeToBounds r)
|
||||||
|
newShape (r :: rs) | Left _ = newShape rs
|
||||||
|
newShape (r :: rs) | Right (x,y) = minus y x :: newShape rs
|
||||||
|
|
||||||
|
|
||||||
|
getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat
|
||||||
|
getNewPos [] [] = []
|
||||||
|
getNewPos (r :: rs) (i :: is) with (cRangeToBounds r)
|
||||||
|
getNewPos (r :: rs) (i :: is) | Left _ = getNewPos rs is
|
||||||
|
getNewPos (r :: rs) (i :: is) | Right (x, _) = minus i x :: getNewPos rs is
|
||||||
|
|
||||||
|
export
|
||||||
|
getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat)
|
||||||
|
getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs
|
||||||
|
where
|
||||||
|
go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
|
||||||
|
go [] = pure []
|
||||||
|
go (r :: rs) = [| (case cRangeToBounds r of
|
||||||
|
Left x => pure x
|
||||||
|
Right (x,y) => [x..pred y]) :: go rs |]
|
||||||
|
|
Loading…
Reference in a new issue