From 3c2ebdc04848640e094b7815993e8592e465e850 Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Fri, 21 Oct 2022 17:49:01 -0400 Subject: [PATCH] Tweak how runtime-checked indexing works --- src/Data/NumIdr/Array/Array.idr | 43 ++++++++++++++++++-------------- src/Data/NumIdr/Array/Coords.idr | 19 +++++++++++--- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index bcbeda5..2f0ef01 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -333,7 +333,7 @@ indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr ||| coordinates are out of bounds. export indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a -indexNB is arr = if and $ map delay $ zipWith (<) is (shape arr) +indexNB is arr = if all id $ zipWith (<) is (shape arr) then Just $ index (getLocation' (strides arr) is) (getPrim arr) else Nothing @@ -345,42 +345,47 @@ export %inline (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a arr !? is = indexNB is arr -||| Update the entry at the given coordinates using the function. The array is -||| returned unchanged if the coordinate is out of bounds. +||| Update the entry at the given coordinates using the function. `Nothing` is +||| returned if the coordinates are out of bounds. export -indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a +indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a) indexUpdateNB is f (MkArray ord sts s arr) = - MkArray ord sts s (updateAt (getLocation' sts is) f arr) + if all id $ zipWith (<) is s + then Just $ MkArray ord sts s (updateAt (getLocation' sts is) f arr) + else Nothing -||| Set the entry at the given coordinates to the given value. The array is -||| returned unchanged if the coordinate is out of bounds. +||| Set the entry at the given coordinates using the function. `Nothing` is +||| returned if the coordinates are out of bounds. export -indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a +indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Maybe (Array s a) indexSetNB is = indexUpdateNB is . const ||| Index the array using the given range of coordinates, returning a new array. -||| Entries outside of the array's bounds are not included. +||| If any of the given indices are out of bounds, then `Nothing` is returned. export -indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a +indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a) indexRangeNB {s} rs arr with (viewShape arr) _ | Shape s = - let ord = getOrder arr - sts = calcStrides ord s' - in MkArray ord sts s' - (unsafeFromIns (product s') $ - map (\(is,is') => (getLocation' sts is', - index (getLocation' (strides arr) is) (getPrim arr))) - $ getCoordsList s rs) + if validateShape s rs + then + let ord = getOrder arr + sts = calcStrides ord s' + in Just $ MkArray ord sts s' + (unsafeFromIns (product s') $ + map (\(is,is') => (getLocation' sts is', + index (getLocation' (strides arr) is) (getPrim arr))) + $ getCoordsList s rs) + else Nothing where s' : Vect ? Nat s' = newShape s rs ||| Index the array using the given range of coordinates, returning a new array. -||| Entries outside of the array's bounds are not included. +||| If any of the given indices are out of bounds, then `Nothing` is returned. ||| ||| This is the operator form of `indexRangeNB`. export %inline -(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a +(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a) arr !?.. rs = indexRangeNB rs arr diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr index afb8482..d47f182 100644 --- a/src/Data/NumIdr/Array/Coords.idr +++ b/src/Data/NumIdr/Array/Coords.idr @@ -121,7 +121,7 @@ namespace Strict cRangeToList (StartBound x) = Right $ range (cast x) n cRangeToList (EndBound x) = Right $ range 0 (cast x) cRangeToList (Bounds x y) = Right $ range (cast x) (cast y) - cRangeToList (Indices xs) = Right $ map cast xs + cRangeToList (Indices xs) = Right $ map cast $ nub xs cRangeToList (Filter p) = Right $ map cast $ filter p $ toList Fin.range @@ -163,15 +163,28 @@ namespace NB cRangeNBToList s All = range 0 s cRangeNBToList s (StartBound x) = range x s cRangeNBToList s (EndBound x) = range 0 x - cRangeNBToList s (Bounds x y) = range x (minimum y s) - cRangeNBToList s (Indices xs) = filter ( CRangeNB -> Bool + validateCRangeNB s All = True + validateCRangeNB s (StartBound x) = x < s + validateCRangeNB s (EndBound x) = x <= s + validateCRangeNB s (Bounds x y) = x < s && y <= s + validateCRangeNB s (Indices xs) = all ( Vect rk CRangeNB -> Vect rk Nat newShape = zipWith (length .: cRangeNBToList) + export + validateShape : Vect rk Nat -> Vect rk CRangeNB -> Bool + validateShape = all id .: zipWith validateCRangeNB + export getNewPos : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat -> Vect rk Nat getNewPos = zipWith3 (\d,r,i => assert_total $