Allow One in non-safe ranged indexing

This commit is contained in:
Kiana Sheibani 2024-05-06 04:36:02 -04:00
parent 56d49256a0
commit 598e60c2da
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
3 changed files with 52 additions and 27 deletions

View file

@ -26,6 +26,9 @@ export
rangeLenZ : (x : Nat) -> length (range 0 x) = x
rangeLenZ x = rangeLen 0 x `trans` minusZeroRight x
export %unsafe
assertFin : Nat -> Fin n
assertFin n = natToFinLt n @{believe_me Oh}
--------------------------------------------------------------------------------
-- Array coordinate types
@ -63,6 +66,7 @@ namespace Strict
public export
data CRange : Nat -> Type where
One : Fin n -> CRange n
One' : Fin n -> CRange n
All : CRange n
StartBound : Fin (S n) -> CRange n
EndBound : Fin (S n) -> CRange n
@ -79,6 +83,8 @@ namespace Strict
namespace NB
public export
data CRangeNB : Type where
One : Nat -> CRangeNB
One' : Nat -> CRangeNB
All : CRangeNB
StartBound : Nat -> CRangeNB
EndBound : Nat -> CRangeNB
@ -128,6 +134,7 @@ namespace Strict
public export
cRangeToList : {n : Nat} -> CRange n -> Either Nat (List Nat)
cRangeToList (One x) = Left (cast x)
cRangeToList (One' x) = Right [cast x]
cRangeToList All = Right $ range 0 n
cRangeToList (StartBound x) = Right $ range (cast x) n
cRangeToList (EndBound x) = Right $ range 0 (cast x)
@ -148,8 +155,8 @@ namespace Strict
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
newShape [] = []
newShape (r :: rs) with (cRangeToList r)
newShape (r :: rs) | Left _ = newShape rs
newShape (r :: rs) | Right xs = length xs :: newShape rs
_ | Left _ = newShape rs
_ | Right xs = length xs :: newShape rs
getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat
@ -175,6 +182,14 @@ namespace NB
validateCRange (d :: s) (r :: rs) = [| validate' d r :: validateCRange s rs |]
where
validate' : (n : Nat) -> CRangeNB -> Maybe (CRange n)
validate' n (One i) =
case isLT i n of
Yes _ => Just (One (natToFinLT i))
_ => Nothing
validate' n (One' i) =
case isLT i n of
Yes _ => Just (One' (natToFinLT i))
_ => Nothing
validate' n All = Just All
validate' n (StartBound x) =
case isLTE x n of
@ -197,29 +212,44 @@ namespace NB
export %unsafe
assertCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> CoordsRange s
assertCRange [] [] = []
assertCRange (d :: s) (r :: rs) = assert' d r :: assertCRange s rs
assertCRange (d :: s) (r :: rs) = assert' r :: assertCRange s rs
where
assert' : (n : Nat) -> CRangeNB -> CRange n
assert' n All = All
assert' n (StartBound x) = StartBound (believe_me x)
assert' n (EndBound x) = EndBound (believe_me x)
assert' n (Bounds x y) = Bounds (believe_me x) (believe_me y)
assert' n (Indices xs) = Indices (believe_me <$> xs)
assert' n (Filter f) = Filter (f . finToNat)
assert' : forall n. CRangeNB -> CRange n
assert' (One i) = One (assertFin i)
assert' (One' i) = One' (assertFin i)
assert' All = All
assert' (StartBound x) = StartBound (assertFin x)
assert' (EndBound x) = EndBound (assertFin x)
assert' (Bounds x y) = Bounds (assertFin x) (assertFin y)
assert' (Indices xs) = Indices (assertFin <$> xs)
assert' (Filter f) = Filter (f . finToNat)
public export
cRangeNBToList : Nat -> CRangeNB -> List Nat
cRangeNBToList s All = range 0 s
cRangeNBToList s (StartBound x) = range x s
cRangeNBToList s (EndBound x) = range 0 x
cRangeNBToList s (Bounds x y) = range x y
cRangeNBToList s (Indices xs) = nub xs
cRangeNBToList s (Filter p) = filter p $ range 0 s
cRangeNBToList : Nat -> CRangeNB -> Either Nat (List Nat)
cRangeNBToList s (One i) = Left i
cRangeNBToList s (One' i) = Right [i]
cRangeNBToList s All = Right $ range 0 s
cRangeNBToList s (StartBound x) = Right $ range x s
cRangeNBToList s (EndBound x) = Right $ range 0 x
cRangeNBToList s (Bounds x y) = Right $ range x y
cRangeNBToList s (Indices xs) = Right $ nub xs
cRangeNBToList s (Filter p) = Right $ filter p $ range 0 s
public export
newRank : Vect rk Nat -> Vect rk CRangeNB -> Nat
newRank _ [] = 0
newRank (d :: s) (r :: rs) =
case cRangeNBToList d r of
Left _ => newRank s rs
Right _ => S (newRank s rs)
||| Calculate the new shape given by a coordinate range.
public export
newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat
newShape = zipWith (length .: cRangeNBToList)
newShape : (s : Vect rk Nat) -> (is : Vect rk CRangeNB) -> Vect (newRank s is) Nat
newShape [] [] = []
newShape (d :: s) (r :: rs) with (cRangeNBToList d r)
_ | Left _ = newShape s rs
_ | Right xs = length xs :: newShape s rs
export
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)

View file

@ -568,9 +568,8 @@ solveLowerTri' {n} mat b with (viewShape b)
construct [] = []
construct {i=S i} (b :: bs) =
let xs = construct bs
i' = assert_total $ case natToFin i n of Just i' => i'
in (b - sum (zipWith (*) xs (reverse $ toVect $ replace {p = flip Array a} (believe_me $ Refl {x=()}) $
mat !!.. [One i', EndBound (weaken i')]))) / mat!#[i,i] :: xs
mat !#.. [One i, EndBound i]))) / mat!#[i,i] :: xs
solveUpperTri' : Field a => Matrix' n a -> Vector n a -> Vector n a
@ -581,9 +580,8 @@ solveUpperTri' {n} mat b with (viewShape b)
construct _ [] = []
construct i (b :: bs) =
let xs = construct (S i) bs
i' = assert_total $ case natToFin i n of Just i' => i'
in (b - sum (zipWith (*) xs (toVect $ replace {p = flip Array a} (believe_me $ Refl {x=()}) $
mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs
mat !#.. [One i, StartBound (S i)]))) / mat!#[i,i] :: xs
||| Solve a linear equation, assuming the matrix is lower triangular.
@ -628,6 +626,7 @@ solve mat = solveWithLUP mat (decompLUP mat)
--------------------------------------------------------------------------------
||| Determine whether a matrix has an inverse.
export
invertible : FieldCmp a => Matrix' n a -> Bool
invertible {n} mat with (viewShape mat)

View file

@ -23,10 +23,6 @@ update : Coords s -> (a -> a) -> Vects s a -> Vects s a
update [] f v = f v
update (i :: is) f v = updateAt i (update is f) v
export %unsafe
assertFin : Nat -> Fin n
assertFin n = natToFinLt n @{believe_me Oh}
export
indexRange : {s : _} -> (rs : CoordsRange s) -> Vects s a -> Vects (newShape rs) a
indexRange [] v = v