Add operators for indexing

This commit is contained in:
Kiana Sheibani 2022-05-24 09:44:13 -04:00
parent 861f1e29f2
commit 97d1bdb538
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
2 changed files with 33 additions and 5 deletions

View file

@ -201,12 +201,31 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
-- Indexing -- Indexing
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
infix 2 !!
infix 2 !?
infixl 3 !!..
infix 3 !?..
||| Index the array using the given `Coords` object. ||| Index the array using the given `Coords` object.
export export
index : Coords s -> Array s a -> a index : Coords s -> Array s a -> a
index is arr = index (getLocation (strides arr) is) (getPrim arr) index is arr = index (getLocation (strides arr) is) (getPrim arr)
export
(!!) : Array s a -> Coords s -> a
(!!) = flip index
export
indexMaybe : Vect rk Nat -> Array {rk} s a -> Maybe a
indexMaybe is arr = safeIndex (getLocation' (strides arr) is) (getPrim arr)
export
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
(!?) = flip indexMaybe
||| Index the array using the given `CoordsRange` object. ||| Index the array using the given `CoordsRange` object.
export export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
@ -218,6 +237,10 @@ indexRange rs arr = let ord = getOrder arr
map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $
getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs) getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
export
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
arr !!.. rs = indexRange rs arr
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Operations on arrays -- Operations on arrays

View file

@ -64,12 +64,20 @@ namespace CoordsRange
EndBound : Fin (S n) -> CRange n EndBound : Fin (S n) -> CRange n
Bounds : Fin (S n) -> Fin (S n) -> CRange n Bounds : Fin (S n) -> Fin (S n) -> CRange n
infix 0 ...
public export
(...) : Fin (S n) -> Fin (S n) -> CRange n
(...) = Bounds
public export public export
data CoordsRange : Vect rk Nat -> Type where data CoordsRange : Vect rk Nat -> Type where
Nil : CoordsRange [] Nil : CoordsRange []
(::) : CRange d -> CoordsRange s -> CoordsRange (d :: s) (::) : CRange d -> CoordsRange s -> CoordsRange (d :: s)
public export
cRangeToBounds : {n : Nat} -> CRange n -> Either Nat (Nat, Nat) cRangeToBounds : {n : Nat} -> CRange n -> Either Nat (Nat, Nat)
cRangeToBounds (One x) = Left (cast x) cRangeToBounds (One x) = Left (cast x)
cRangeToBounds All = Right (0, n) cRangeToBounds All = Right (0, n)
@ -78,18 +86,15 @@ namespace CoordsRange
cRangeToBounds (Bounds x y) = Right (cast x, cast y) cRangeToBounds (Bounds x y) = Right (cast x, cast y)
public export
newRank : {s : _} -> CoordsRange s -> Nat newRank : {s : _} -> CoordsRange s -> Nat
newRank [] = 0 newRank [] = 0
newRank (r :: rs) = case cRangeToBounds r of newRank (r :: rs) = case cRangeToBounds r of
Left _ => newRank rs Left _ => newRank rs
Right _ => S (newRank rs) Right _ => S (newRank rs)
export
newDim : {d : _} -> (r : CRange d) -> Maybe Nat
newDim = map (uncurry $ flip minus) . eitherToMaybe . cRangeToBounds
||| Calculate the new shape given by a coordinate range. ||| Calculate the new shape given by a coordinate range.
export public export
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
newShape [] = [] newShape [] = []
newShape (r :: rs) with (cRangeToBounds r) newShape (r :: rs) with (cRangeToBounds r)