Add range indexing support
This commit is contained in:
parent
76e16574f1
commit
5914652b7c
|
@ -1,5 +1,6 @@
|
|||
module Data.NumIdr.Array.Array
|
||||
|
||||
import Data.List1
|
||||
import Data.Vect
|
||||
import Data.NumIdr.PrimArray
|
||||
import Data.NumIdr.Array.Order
|
||||
|
@ -38,6 +39,11 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
|
|||
(s : Vect rk Nat) -> PrimArray a -> Array s a
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Properties of arrays
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
||| Extract the primitive array value.
|
||||
export
|
||||
getPrim : Array s a -> PrimArray a
|
||||
|
@ -50,8 +56,8 @@ getOrder (MkArray ord _ _ _) = ord
|
|||
|
||||
||| The strides of the array, returned in the same axis order as in the shape.
|
||||
export
|
||||
getStrides : Array {rk} s a -> Vect rk Nat
|
||||
getStrides (MkArray _ sts _ _) = sts
|
||||
strides : Array {rk} s a -> Vect rk Nat
|
||||
strides (MkArray _ sts _ _) = sts
|
||||
|
||||
||| The total number of elements of the array
|
||||
||| This is equivalent to `product s`.
|
||||
|
@ -70,6 +76,25 @@ rank : Array s a -> Nat
|
|||
rank = length . shape
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Array constructors
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
shapeEq : (arr : Array s a) -> s = shape arr
|
||||
shapeEq (MkArray _ _ _ _) = Refl
|
||||
|
||||
|
||||
-- Get a list of all coordinates
|
||||
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
||||
|
||||
getAllCoords : (s : Vect rk Nat) -> List (Coords s)
|
||||
getAllCoords [] = pure []
|
||||
getAllCoords (Z :: s) = []
|
||||
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||
|
||||
|
||||
||| Create an array given a vector of its elements. The elements of the vector
|
||||
||| are arranged into the provided shape using the provided order.
|
||||
|||
|
||||
|
@ -82,12 +107,30 @@ fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
|
|||
||| Create an array given a vector of its elements. The elements of the vector
|
||||
||| are assembled into the provided shape using row-major order (the last axis is the
|
||||
||| least significant).
|
||||
||| To specify the order of the array, see `fromVect'`.
|
||||
|||
|
||||
||| @ s The shape of the constructed array
|
||||
export
|
||||
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
||||
fromVect s = fromVect' s COrder
|
||||
|
||||
||| Create an array given a function to generate its elements.
|
||||
|||
|
||||
||| @ s The shape of the constructed array
|
||||
||| @ ord The order to interpret the elements
|
||||
export
|
||||
fromFunction' : (s : Vect rk Nat) -> (ord : Order) -> (Coords s -> a) -> Array s a
|
||||
fromFunction' s ord f = let sts = calcStrides ord s
|
||||
in MkArray ord sts s (unsafeFromIns (product s) $
|
||||
map (\is => (getLocation sts is, f is)) $ getAllCoords s)
|
||||
|
||||
||| Create an array given a function to generate its elements.
|
||||
||| To specify the order of the array, use `fromFunction'`.
|
||||
|||
|
||||
||| @ s The shape of the constructed array
|
||||
export
|
||||
fromFunction : (s : Vect rk Nat) -> (Coords s -> a) -> Array s a
|
||||
fromFunction s = fromFunction' s COrder
|
||||
|
||||
||| Construct an array using a structure of nested vectors. The elements are arranged
|
||||
||| to the specified order before being written.
|
||||
|
@ -102,14 +145,39 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
|
|||
sts = calcStrides ord s
|
||||
|
||||
ins : List (Nat, a)
|
||||
ins = collapse $ mapWithIndex (\is,x => (getLocation' sts is, x)) v
|
||||
ins = collapse $ mapWithIndex (MkPair . getLocation' sts) v
|
||||
|
||||
||| Construct an array using a structure of nested vectors.
|
||||
||| To explicitly specify the shape and order of the array, use `array'`.
|
||||
export
|
||||
array : {s : _} -> Vects s a -> Array s a
|
||||
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Indexing
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
||| Index the array using the given `Coords` object.
|
||||
export
|
||||
index : Coords s -> Array s a -> a
|
||||
index is arr = index (getLocation (strides arr) is) (getPrim arr)
|
||||
|
||||
||| Index the array using the given `CoordsRange` object.
|
||||
export
|
||||
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
||||
indexRange rs arr = let ord = getOrder arr
|
||||
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
|
||||
sts = calcStrides ord s
|
||||
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Operations on arrays
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
||| Reshape the array into the given shape and reinterpret it according to
|
||||
||| the given order.
|
||||
|||
|
||||
|
@ -129,7 +197,33 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
|||
reshape s' arr = reshape' s' (getOrder arr) arr
|
||||
|
||||
|
||||
||| Index the array using the given `Coords` object.
|
||||
|
||||
|
||||
export
|
||||
index : Coords s -> Array s a -> a
|
||||
index is arr = index (getLocation (getStrides arr) is) (getPrim arr)
|
||||
enumerate' : Array {rk} s a -> List (Vect rk Nat, a)
|
||||
enumerate' (MkArray _ sts sh p) =
|
||||
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
|
||||
|
||||
|
||||
export
|
||||
enumerate : Array {rk} s a -> List (Coords s, a)
|
||||
enumerate arr = map (\is => (is, index is arr))
|
||||
(rewrite shapeEq arr in getAllCoords $ shape arr)
|
||||
|
||||
export
|
||||
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||
Array (replaceAt axis (index axis s + d) s) a
|
||||
stack axis a b = let sA = shape a
|
||||
sB = shape b
|
||||
dA = index axis sA
|
||||
dB = index axis sB
|
||||
s = replaceAt axis (dA + dB) sA
|
||||
sts = calcStrides COrder s
|
||||
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
|
||||
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
|
||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||
|
||||
|
||||
|
||||
display : Show a => Array s a -> IO ()
|
||||
display = printLn . PrimArray.toList . getPrim
|
||||
|
|
|
@ -5,53 +5,111 @@ import Data.Vect
|
|||
%default total
|
||||
|
||||
|
||||
||| A type-safe coordinate system for an array. The coordinates are
|
||||
||| values of `Fin dim`, where `dim` is the dimension of each axis.
|
||||
public export
|
||||
data Coords : (s : Vect rk Nat) -> Type where
|
||||
namespace Coords
|
||||
|
||||
||| A type-safe coordinate system for an array. The coordinates are
|
||||
||| values of `Fin dim`, where `dim` is the dimension of each axis.
|
||||
public export
|
||||
data Coords : (s : Vect rk Nat) -> Type where
|
||||
Nil : Coords Nil
|
||||
(::) : Fin dim -> Coords s -> Coords (dim :: s)
|
||||
|
||||
||| Forget the shape of the array by converting each index to type `Nat`.
|
||||
export
|
||||
toNats : Coords {rk} s -> Vect rk Nat
|
||||
toNats [] = []
|
||||
toNats (i :: is) = finToNat i :: toNats is
|
||||
||| Forget the shape of the array by converting each index to type `Nat`.
|
||||
export
|
||||
toNats : Coords {rk} s -> Vect rk Nat
|
||||
toNats [] = []
|
||||
toNats (i :: is) = finToNat i :: toNats is
|
||||
|
||||
|
||||
public export
|
||||
Vects : Vect rk Nat -> Type -> Type
|
||||
Vects [] a = a
|
||||
Vects (d::s) a = Vect d (Vects s a)
|
||||
public export
|
||||
Vects : Vect rk Nat -> Type -> Type
|
||||
Vects [] a = a
|
||||
Vects (d::s) a = Vect d (Vects s a)
|
||||
|
||||
export
|
||||
collapse : {s : _} -> Vects s a -> List a
|
||||
collapse {s=[]} = (::[])
|
||||
collapse {s=_::_} = concat . map collapse
|
||||
export
|
||||
collapse : {s : _} -> Vects s a -> List a
|
||||
collapse {s=[]} = (::[])
|
||||
collapse {s=_::_} = concat . map collapse
|
||||
|
||||
|
||||
export
|
||||
mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b
|
||||
mapWithIndex {s=[]} f x = f [] x
|
||||
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
|
||||
export
|
||||
mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b
|
||||
mapWithIndex {s=[]} f x = f [] x
|
||||
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
|
||||
where
|
||||
mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b
|
||||
mapWithIndex' f [] = []
|
||||
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
||||
|
||||
|
||||
export
|
||||
index : Coords s -> Vects s a -> a
|
||||
index [] x = x
|
||||
index (i::is) v = index is $ index i v
|
||||
export
|
||||
index : Coords s -> Vects s a -> a
|
||||
index [] x = x
|
||||
index (i::is) v = index is $ index i v
|
||||
|
||||
|
||||
export
|
||||
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
|
||||
getLocation' = sum .: zipWith (*)
|
||||
export
|
||||
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
|
||||
getLocation' = sum .: zipWith (*)
|
||||
|
||||
||| Compute the memory location of an array element
|
||||
||| given its coordinate and the strides of the array.
|
||||
export
|
||||
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
|
||||
getLocation sts is = getLocation' sts (toNats is)
|
||||
||| Compute the memory location of an array element
|
||||
||| given its coordinate and the strides of the array.
|
||||
export
|
||||
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
|
||||
getLocation sts is = getLocation' sts (toNats is)
|
||||
|
||||
|
||||
namespace CoordsRange
|
||||
|
||||
public export
|
||||
data CRange : Nat -> Type where
|
||||
One : Fin n -> CRange n
|
||||
All : CRange n
|
||||
StartBound : Fin (S n) -> CRange n
|
||||
EndBound : Fin (S n) -> CRange n
|
||||
Bounds : Fin (S n) -> Fin (S n) -> CRange n
|
||||
|
||||
public export
|
||||
data CoordsRange : Vect rk Nat -> Type where
|
||||
Nil : CoordsRange []
|
||||
(::) : CRange d -> CoordsRange s -> CoordsRange (d :: s)
|
||||
|
||||
|
||||
cRangeToBounds : {n : Nat} -> CRange n -> Either Nat (Nat, Nat)
|
||||
cRangeToBounds (One x) = Left (cast x)
|
||||
cRangeToBounds All = Right (0, n)
|
||||
cRangeToBounds (StartBound x) = Right (cast x, n)
|
||||
cRangeToBounds (EndBound x) = Right (0, cast x)
|
||||
cRangeToBounds (Bounds x y) = Right (cast x, cast y)
|
||||
|
||||
|
||||
newRank : {s : _} -> CoordsRange s -> Nat
|
||||
newRank [] = 0
|
||||
newRank (r :: rs) = case cRangeToBounds r of
|
||||
Left _ => newRank rs
|
||||
Right _ => S (newRank rs)
|
||||
|
||||
||| Calculate the new shape given by a coordinate range.
|
||||
export
|
||||
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
|
||||
newShape [] = []
|
||||
newShape (r :: rs) with (cRangeToBounds r)
|
||||
newShape (r :: rs) | Left _ = newShape rs
|
||||
newShape (r :: rs) | Right (x,y) = minus y x :: newShape rs
|
||||
|
||||
|
||||
getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat
|
||||
getNewPos [] [] = []
|
||||
getNewPos (r :: rs) (i :: is) with (cRangeToBounds r)
|
||||
getNewPos (r :: rs) (i :: is) | Left _ = getNewPos rs is
|
||||
getNewPos (r :: rs) (i :: is) | Right (x, _) = minus i x :: getNewPos rs is
|
||||
|
||||
export
|
||||
getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat)
|
||||
getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs
|
||||
where
|
||||
go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
|
||||
go [] = pure []
|
||||
go (r :: rs) = [| (case cRangeToBounds r of
|
||||
Left x => pure x
|
||||
Right (x,y) => [x..pred y]) :: go rs |]
|
||||
|
|
Loading…
Reference in a new issue