Add new indexing functions

This commit is contained in:
Kiana Sheibani 2022-08-29 13:55:55 -04:00
parent 9ece7e6963
commit 3246e0ed94
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
3 changed files with 81 additions and 37 deletions

View file

@ -241,11 +241,14 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
-- Indexing -- Indexing
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
infixl 10 !! infixl 10 !!
infixl 10 !? infixl 10 !?
infixl 10 !# infixl 10 !#
infixl 11 !!.. infixl 11 !!..
infixl 11 !?.. infixl 11 !?..
infixl 11 !#..
||| Index the array using the given coordinates. ||| Index the array using the given coordinates.
export export
@ -255,9 +258,9 @@ index is arr = index (getLocation (strides arr) is) (getPrim arr)
||| Index the array using the given coordinates. ||| Index the array using the given coordinates.
||| |||
||| This is the operator form of `index`. ||| This is the operator form of `index`.
export export %inline
(!!) : Array s a -> Coords s -> a (!!) : Array s a -> Coords s -> a
(!!) = flip index arr !! is = index is arr
-- TODO: Create set/update at index functions -- TODO: Create set/update at index functions
@ -273,36 +276,6 @@ indexSet : Coords s -> a -> Array s a -> Array s a
indexSet is = indexUpdate is . const indexSet is = indexUpdate is . const
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is arr = if and $ map delay $ zipWith (<) is (shape arr)
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
else Nothing
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
|||
||| This is the operator form of `indexNB`.
export
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
(!?) = flip indexNB
||| Update the value at the given coordinates using the function. The array is
||| returned unchanged if the coordinate is out of bounds.
export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
indexUpdateNB is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation' sts is) f arr)
||| Set the value at the given coordinates to the given value. The array is
||| returned unchanged if the coordinate is out of bounds.
export
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
indexSetNB is = indexUpdateNB is . const
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
export export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
@ -321,18 +294,85 @@ indexRange rs arr with (viewShape arr)
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
||| |||
||| This is the operator form of `indexRange`. ||| This is the operator form of `indexRange`.
export export %inline
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a (!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
arr !!.. rs = indexRange rs arr arr !!.. rs = indexRange rs arr
export
indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
Array s a -> Array s a
indexSetRange rs rpl (MkArray ord sts s arr) =
MkArray ord sts s (unsafePerformIO $ do
let arr' = copy arr
unsafeDoIns (map (\(is,is') =>
(getLocation' sts is, index (getLocation' (strides rpl) is') (getPrim rpl)))
$ getCoordsList rs) arr'
pure arr')
export
indexUpdateRange : (rs : CoordsRange s) ->
(Array (newShape rs) a -> Array (newShape rs) a) ->
Array s a -> Array s a
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is arr = if and $ map delay $ zipWith (<) is (shape arr)
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
else Nothing
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
|||
||| This is the operator form of `indexNB`.
export %inline
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
arr !? is = indexNB is arr
||| Update the value at the given coordinates using the function. The array is
||| returned unchanged if the coordinate is out of bounds.
export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
indexUpdateNB is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation' sts is) f arr)
||| Set the value at the given coordinates to the given value. The array is
||| returned unchanged if the coordinate is out of bounds.
export
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
indexSetNB is = indexUpdateNB is . const
export
indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
indexRangeNB rs arr with (viewShape arr)
_ | Shape s =
let ord = getOrder arr
sts = calcStrides ord s'
in MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList s rs)
where s' : Vect ? Nat
s' = newShape s rs
export %inline
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
arr !?.. rs = indexRangeNB rs arr
export export
indexUnsafe : Vect rk Nat -> Array {rk} s a -> a indexUnsafe : Vect rk Nat -> Array {rk} s a -> a
indexUnsafe is arr = index (getLocation' (strides arr) is) (getPrim arr) indexUnsafe is arr = index (getLocation' (strides arr) is) (getPrim arr)
export export %inline
(!#) : Array {rk} s a -> Vect rk Nat -> a (!#) : Array {rk} s a -> Vect rk Nat -> a
(!#) = flip indexUnsafe arr !# is = indexUnsafe is arr
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View file

@ -65,6 +65,6 @@ power (S n@(S _)) x = x *. power n x
||| number power. ||| number power.
||| |||
||| This is the operator form of `power`. ||| This is the operator form of `power`.
public export public export %inline
(^) : MultGroup a => a -> Nat -> a (^) : MultGroup a => a -> Nat -> a
(^) = flip power a ^ n = power n a

View file

@ -44,6 +44,10 @@ unsafeFromIns size ins = unsafePerformIO $ do
for_ ins $ \(i,x) => arrayDataSet i x arr for_ ins $ \(i,x) => arrayDataSet i x arr
pure $ MkPrimArray size arr pure $ MkPrimArray size arr
export
unsafeDoIns : List (Nat, a) -> PrimArray a -> IO ()
unsafeDoIns ins arr = for_ ins $ \(i,x) => arrayDataSet i x arr.content
||| Create an array given its size and a function to generate its elements by ||| Create an array given its size and a function to generate its elements by
||| its index. ||| its index.
export export