Tweak how runtime-checked indexing works

This commit is contained in:
Kiana Sheibani 2022-10-21 17:49:01 -04:00
parent 077b393bd1
commit 3c2ebdc048
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
2 changed files with 40 additions and 22 deletions

View file

@ -333,7 +333,7 @@ indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
||| coordinates are out of bounds. ||| coordinates are out of bounds.
export export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is arr = if and $ map delay $ zipWith (<) is (shape arr) indexNB is arr = if all id $ zipWith (<) is (shape arr)
then Just $ index (getLocation' (strides arr) is) (getPrim arr) then Just $ index (getLocation' (strides arr) is) (getPrim arr)
else Nothing else Nothing
@ -345,42 +345,47 @@ export %inline
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
arr !? is = indexNB is arr arr !? is = indexNB is arr
||| Update the entry at the given coordinates using the function. The array is ||| Update the entry at the given coordinates using the function. `Nothing` is
||| returned unchanged if the coordinate is out of bounds. ||| returned if the coordinates are out of bounds.
export export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a)
indexUpdateNB is f (MkArray ord sts s arr) = indexUpdateNB is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation' sts is) f arr) if all id $ zipWith (<) is s
then Just $ MkArray ord sts s (updateAt (getLocation' sts is) f arr)
else Nothing
||| Set the entry at the given coordinates to the given value. The array is ||| Set the entry at the given coordinates using the function. `Nothing` is
||| returned unchanged if the coordinate is out of bounds. ||| returned if the coordinates are out of bounds.
export export
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Maybe (Array s a)
indexSetNB is = indexUpdateNB is . const indexSetNB is = indexUpdateNB is . const
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
||| Entries outside of the array's bounds are not included. ||| If any of the given indices are out of bounds, then `Nothing` is returned.
export export
indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a)
indexRangeNB {s} rs arr with (viewShape arr) indexRangeNB {s} rs arr with (viewShape arr)
_ | Shape s = _ | Shape s =
let ord = getOrder arr if validateShape s rs
sts = calcStrides ord s' then
in MkArray ord sts s' let ord = getOrder arr
(unsafeFromIns (product s') $ sts = calcStrides ord s'
map (\(is,is') => (getLocation' sts is', in Just $ MkArray ord sts s'
index (getLocation' (strides arr) is) (getPrim arr))) (unsafeFromIns (product s') $
$ getCoordsList s rs) map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList s rs)
else Nothing
where s' : Vect ? Nat where s' : Vect ? Nat
s' = newShape s rs s' = newShape s rs
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
||| Entries outside of the array's bounds are not included. ||| If any of the given indices are out of bounds, then `Nothing` is returned.
||| |||
||| This is the operator form of `indexRangeNB`. ||| This is the operator form of `indexRangeNB`.
export %inline export %inline
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a (!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
arr !?.. rs = indexRangeNB rs arr arr !?.. rs = indexRangeNB rs arr

View file

@ -121,7 +121,7 @@ namespace Strict
cRangeToList (StartBound x) = Right $ range (cast x) n cRangeToList (StartBound x) = Right $ range (cast x) n
cRangeToList (EndBound x) = Right $ range 0 (cast x) cRangeToList (EndBound x) = Right $ range 0 (cast x)
cRangeToList (Bounds x y) = Right $ range (cast x) (cast y) cRangeToList (Bounds x y) = Right $ range (cast x) (cast y)
cRangeToList (Indices xs) = Right $ map cast xs cRangeToList (Indices xs) = Right $ map cast $ nub xs
cRangeToList (Filter p) = Right $ map cast $ filter p $ toList Fin.range cRangeToList (Filter p) = Right $ map cast $ filter p $ toList Fin.range
@ -163,15 +163,28 @@ namespace NB
cRangeNBToList s All = range 0 s cRangeNBToList s All = range 0 s
cRangeNBToList s (StartBound x) = range x s cRangeNBToList s (StartBound x) = range x s
cRangeNBToList s (EndBound x) = range 0 x cRangeNBToList s (EndBound x) = range 0 x
cRangeNBToList s (Bounds x y) = range x (minimum y s) cRangeNBToList s (Bounds x y) = range x y
cRangeNBToList s (Indices xs) = filter (<s) xs cRangeNBToList s (Indices xs) = nub xs
cRangeNBToList s (Filter p) = filter p $ range 0 s cRangeNBToList s (Filter p) = filter p $ range 0 s
export
validateCRangeNB : Nat -> CRangeNB -> Bool
validateCRangeNB s All = True
validateCRangeNB s (StartBound x) = x < s
validateCRangeNB s (EndBound x) = x <= s
validateCRangeNB s (Bounds x y) = x < s && y <= s
validateCRangeNB s (Indices xs) = all (<s) xs
validateCRangeNB s (Filter p) = True
||| Calculate the new shape given by a coordinate range. ||| Calculate the new shape given by a coordinate range.
export export
newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat
newShape = zipWith (length .: cRangeNBToList) newShape = zipWith (length .: cRangeNBToList)
export
validateShape : Vect rk Nat -> Vect rk CRangeNB -> Bool
validateShape = all id .: zipWith validateCRangeNB
export export
getNewPos : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat -> Vect rk Nat getNewPos : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat -> Vect rk Nat
getNewPos = zipWith3 (\d,r,i => assert_total $ getNewPos = zipWith3 (\d,r,i => assert_total $