From 5914652b7c16e80f854ec5c0a79f5f5515f4a104 Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Thu, 19 May 2022 08:49:57 -0400 Subject: [PATCH] Add range indexing support --- src/Data/NumIdr/Array/Array.idr | 106 ++++++++++++++++++++++-- src/Data/NumIdr/Array/Coords.idr | 136 ++++++++++++++++++++++--------- 2 files changed, 197 insertions(+), 45 deletions(-) diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 1aa37cb..4da2a90 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -1,5 +1,6 @@ module Data.NumIdr.Array.Array +import Data.List1 import Data.Vect import Data.NumIdr.PrimArray import Data.NumIdr.Array.Order @@ -38,6 +39,11 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where (s : Vect rk Nat) -> PrimArray a -> Array s a +-------------------------------------------------------------------------------- +-- Properties of arrays +-------------------------------------------------------------------------------- + + ||| Extract the primitive array value. export getPrim : Array s a -> PrimArray a @@ -50,8 +56,8 @@ getOrder (MkArray ord _ _ _) = ord ||| The strides of the array, returned in the same axis order as in the shape. export -getStrides : Array {rk} s a -> Vect rk Nat -getStrides (MkArray _ sts _ _) = sts +strides : Array {rk} s a -> Vect rk Nat +strides (MkArray _ sts _ _) = sts ||| The total number of elements of the array ||| This is equivalent to `product s`. @@ -70,6 +76,25 @@ rank : Array s a -> Nat rank = length . shape +-------------------------------------------------------------------------------- +-- Array constructors +-------------------------------------------------------------------------------- + + +shapeEq : (arr : Array s a) -> s = shape arr +shapeEq (MkArray _ _ _ _) = Refl + + +-- Get a list of all coordinates +getAllCoords' : Vect rk Nat -> List (Vect rk Nat) +getAllCoords' = traverse (\case Z => []; S n => [0..n]) + +getAllCoords : (s : Vect rk Nat) -> List (Coords s) +getAllCoords [] = pure [] +getAllCoords (Z :: s) = [] +getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |] + + ||| Create an array given a vector of its elements. The elements of the vector ||| are arranged into the provided shape using the provided order. ||| @@ -82,12 +107,30 @@ fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v) ||| Create an array given a vector of its elements. The elements of the vector ||| are assembled into the provided shape using row-major order (the last axis is the ||| least significant). +||| To specify the order of the array, see `fromVect'`. ||| ||| @ s The shape of the constructed array export fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a fromVect s = fromVect' s COrder +||| Create an array given a function to generate its elements. +||| +||| @ s The shape of the constructed array +||| @ ord The order to interpret the elements +export +fromFunction' : (s : Vect rk Nat) -> (ord : Order) -> (Coords s -> a) -> Array s a +fromFunction' s ord f = let sts = calcStrides ord s + in MkArray ord sts s (unsafeFromIns (product s) $ + map (\is => (getLocation sts is, f is)) $ getAllCoords s) + +||| Create an array given a function to generate its elements. +||| To specify the order of the array, use `fromFunction'`. +||| +||| @ s The shape of the constructed array +export +fromFunction : (s : Vect rk Nat) -> (Coords s -> a) -> Array s a +fromFunction s = fromFunction' s COrder ||| Construct an array using a structure of nested vectors. The elements are arranged ||| to the specified order before being written. @@ -102,14 +145,39 @@ array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins) sts = calcStrides ord s ins : List (Nat, a) - ins = collapse $ mapWithIndex (\is,x => (getLocation' sts is, x)) v + ins = collapse $ mapWithIndex (MkPair . getLocation' sts) v ||| Construct an array using a structure of nested vectors. +||| To explicitly specify the shape and order of the array, use `array'`. export array : {s : _} -> Vects s a -> Array s a array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v) +-------------------------------------------------------------------------------- +-- Indexing +-------------------------------------------------------------------------------- + + +||| Index the array using the given `Coords` object. +export +index : Coords s -> Array s a -> a +index is arr = index (getLocation (strides arr) is) (getPrim arr) + +||| Index the array using the given `CoordsRange` object. +export +indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a +indexRange rs arr = let ord = getOrder arr + s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs + sts = calcStrides ord s + in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs) + + +-------------------------------------------------------------------------------- +-- Operations on arrays +-------------------------------------------------------------------------------- + + ||| Reshape the array into the given shape and reinterpret it according to ||| the given order. ||| @@ -129,7 +197,33 @@ reshape : (s' : Vect rk' Nat) -> Array {rk} s a -> reshape s' arr = reshape' s' (getOrder arr) arr -||| Index the array using the given `Coords` object. + + export -index : Coords s -> Array s a -> a -index is arr = index (getLocation (getStrides arr) is) (getPrim arr) +enumerate' : Array {rk} s a -> List (Vect rk Nat, a) +enumerate' (MkArray _ sts sh p) = + map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh) + + +export +enumerate : Array {rk} s a -> List (Coords s, a) +enumerate arr = map (\is => (is, index is arr)) + (rewrite shapeEq arr in getAllCoords $ shape arr) + +export +stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a -> + Array (replaceAt axis (index axis s + d) s) a +stack axis a b = let sA = shape a + sB = shape b + dA = index axis sA + dB = index axis sB + s = replaceAt axis (dA + dB) sA + sts = calcStrides COrder s + ins = map (mapFst $ getLocation' sts . toNats) (enumerate a) + ++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b) + in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins) + + + +display : Show a => Array s a -> IO () +display = printLn . PrimArray.toList . getPrim diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr index 695f3f5..8e0b108 100644 --- a/src/Data/NumIdr/Array/Coords.idr +++ b/src/Data/NumIdr/Array/Coords.idr @@ -5,53 +5,111 @@ import Data.Vect %default total -||| A type-safe coordinate system for an array. The coordinates are -||| values of `Fin dim`, where `dim` is the dimension of each axis. -public export -data Coords : (s : Vect rk Nat) -> Type where - Nil : Coords Nil - (::) : Fin dim -> Coords s -> Coords (dim :: s) +namespace Coords -||| Forget the shape of the array by converting each index to type `Nat`. -export -toNats : Coords {rk} s -> Vect rk Nat -toNats [] = [] -toNats (i :: is) = finToNat i :: toNats is + ||| A type-safe coordinate system for an array. The coordinates are + ||| values of `Fin dim`, where `dim` is the dimension of each axis. + public export + data Coords : (s : Vect rk Nat) -> Type where + Nil : Coords Nil + (::) : Fin dim -> Coords s -> Coords (dim :: s) + + ||| Forget the shape of the array by converting each index to type `Nat`. + export + toNats : Coords {rk} s -> Vect rk Nat + toNats [] = [] + toNats (i :: is) = finToNat i :: toNats is -public export -Vects : Vect rk Nat -> Type -> Type -Vects [] a = a -Vects (d::s) a = Vect d (Vects s a) + public export + Vects : Vect rk Nat -> Type -> Type + Vects [] a = a + Vects (d::s) a = Vect d (Vects s a) -export -collapse : {s : _} -> Vects s a -> List a -collapse {s=[]} = (::[]) -collapse {s=_::_} = concat . map collapse + export + collapse : {s : _} -> Vects s a -> List a + collapse {s=[]} = (::[]) + collapse {s=_::_} = concat . map collapse -export -mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b -mapWithIndex {s=[]} f x = f [] x -mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v - where - mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b - mapWithIndex' f [] = [] - mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs + export + mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b + mapWithIndex {s=[]} f x = f [] x + mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v + where + mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b + mapWithIndex' f [] = [] + mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs -export -index : Coords s -> Vects s a -> a -index [] x = x -index (i::is) v = index is $ index i v + export + index : Coords s -> Vects s a -> a + index [] x = x + index (i::is) v = index is $ index i v -export -getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat -getLocation' = sum .: zipWith (*) + export + getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat + getLocation' = sum .: zipWith (*) -||| Compute the memory location of an array element -||| given its coordinate and the strides of the array. -export -getLocation : Vect rk Nat -> Coords {rk} s -> Nat -getLocation sts is = getLocation' sts (toNats is) + ||| Compute the memory location of an array element + ||| given its coordinate and the strides of the array. + export + getLocation : Vect rk Nat -> Coords {rk} s -> Nat + getLocation sts is = getLocation' sts (toNats is) + + +namespace CoordsRange + + public export + data CRange : Nat -> Type where + One : Fin n -> CRange n + All : CRange n + StartBound : Fin (S n) -> CRange n + EndBound : Fin (S n) -> CRange n + Bounds : Fin (S n) -> Fin (S n) -> CRange n + + public export + data CoordsRange : Vect rk Nat -> Type where + Nil : CoordsRange [] + (::) : CRange d -> CoordsRange s -> CoordsRange (d :: s) + + + cRangeToBounds : {n : Nat} -> CRange n -> Either Nat (Nat, Nat) + cRangeToBounds (One x) = Left (cast x) + cRangeToBounds All = Right (0, n) + cRangeToBounds (StartBound x) = Right (cast x, n) + cRangeToBounds (EndBound x) = Right (0, cast x) + cRangeToBounds (Bounds x y) = Right (cast x, cast y) + + + newRank : {s : _} -> CoordsRange s -> Nat + newRank [] = 0 + newRank (r :: rs) = case cRangeToBounds r of + Left _ => newRank rs + Right _ => S (newRank rs) + + ||| Calculate the new shape given by a coordinate range. + export + newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat + newShape [] = [] + newShape (r :: rs) with (cRangeToBounds r) + newShape (r :: rs) | Left _ = newShape rs + newShape (r :: rs) | Right (x,y) = minus y x :: newShape rs + + + getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat + getNewPos [] [] = [] + getNewPos (r :: rs) (i :: is) with (cRangeToBounds r) + getNewPos (r :: rs) (i :: is) | Left _ = getNewPos rs is + getNewPos (r :: rs) (i :: is) | Right (x, _) = minus i x :: getNewPos rs is + + export + getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat) + getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs + where + go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat) + go [] = pure [] + go (r :: rs) = [| (case cRangeToBounds r of + Left x => pure x + Right (x,y) => [x..pred y]) :: go rs |]