From 598e60c2dab1ea55044db1d1edf1f5b3941585e0 Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Mon, 6 May 2024 04:36:02 -0400 Subject: [PATCH] Allow `One` in non-safe ranged indexing --- src/Data/NumIdr/Array/Coords.idr | 68 ++++++++++++++++++++-------- src/Data/NumIdr/Matrix.idr | 7 ++- src/Data/NumIdr/PrimArray/Linked.idr | 4 -- 3 files changed, 52 insertions(+), 27 deletions(-) diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr index 669471a..104ae63 100644 --- a/src/Data/NumIdr/Array/Coords.idr +++ b/src/Data/NumIdr/Array/Coords.idr @@ -26,6 +26,9 @@ export rangeLenZ : (x : Nat) -> length (range 0 x) = x rangeLenZ x = rangeLen 0 x `trans` minusZeroRight x +export %unsafe +assertFin : Nat -> Fin n +assertFin n = natToFinLt n @{believe_me Oh} -------------------------------------------------------------------------------- -- Array coordinate types @@ -63,6 +66,7 @@ namespace Strict public export data CRange : Nat -> Type where One : Fin n -> CRange n + One' : Fin n -> CRange n All : CRange n StartBound : Fin (S n) -> CRange n EndBound : Fin (S n) -> CRange n @@ -79,6 +83,8 @@ namespace Strict namespace NB public export data CRangeNB : Type where + One : Nat -> CRangeNB + One' : Nat -> CRangeNB All : CRangeNB StartBound : Nat -> CRangeNB EndBound : Nat -> CRangeNB @@ -128,6 +134,7 @@ namespace Strict public export cRangeToList : {n : Nat} -> CRange n -> Either Nat (List Nat) cRangeToList (One x) = Left (cast x) + cRangeToList (One' x) = Right [cast x] cRangeToList All = Right $ range 0 n cRangeToList (StartBound x) = Right $ range (cast x) n cRangeToList (EndBound x) = Right $ range 0 (cast x) @@ -148,8 +155,8 @@ namespace Strict newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat newShape [] = [] newShape (r :: rs) with (cRangeToList r) - newShape (r :: rs) | Left _ = newShape rs - newShape (r :: rs) | Right xs = length xs :: newShape rs + _ | Left _ = newShape rs + _ | Right xs = length xs :: newShape rs getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat @@ -175,6 +182,14 @@ namespace NB validateCRange (d :: s) (r :: rs) = [| validate' d r :: validateCRange s rs |] where validate' : (n : Nat) -> CRangeNB -> Maybe (CRange n) + validate' n (One i) = + case isLT i n of + Yes _ => Just (One (natToFinLT i)) + _ => Nothing + validate' n (One' i) = + case isLT i n of + Yes _ => Just (One' (natToFinLT i)) + _ => Nothing validate' n All = Just All validate' n (StartBound x) = case isLTE x n of @@ -197,29 +212,44 @@ namespace NB export %unsafe assertCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> CoordsRange s assertCRange [] [] = [] - assertCRange (d :: s) (r :: rs) = assert' d r :: assertCRange s rs + assertCRange (d :: s) (r :: rs) = assert' r :: assertCRange s rs where - assert' : (n : Nat) -> CRangeNB -> CRange n - assert' n All = All - assert' n (StartBound x) = StartBound (believe_me x) - assert' n (EndBound x) = EndBound (believe_me x) - assert' n (Bounds x y) = Bounds (believe_me x) (believe_me y) - assert' n (Indices xs) = Indices (believe_me <$> xs) - assert' n (Filter f) = Filter (f . finToNat) + assert' : forall n. CRangeNB -> CRange n + assert' (One i) = One (assertFin i) + assert' (One' i) = One' (assertFin i) + assert' All = All + assert' (StartBound x) = StartBound (assertFin x) + assert' (EndBound x) = EndBound (assertFin x) + assert' (Bounds x y) = Bounds (assertFin x) (assertFin y) + assert' (Indices xs) = Indices (assertFin <$> xs) + assert' (Filter f) = Filter (f . finToNat) public export - cRangeNBToList : Nat -> CRangeNB -> List Nat - cRangeNBToList s All = range 0 s - cRangeNBToList s (StartBound x) = range x s - cRangeNBToList s (EndBound x) = range 0 x - cRangeNBToList s (Bounds x y) = range x y - cRangeNBToList s (Indices xs) = nub xs - cRangeNBToList s (Filter p) = filter p $ range 0 s + cRangeNBToList : Nat -> CRangeNB -> Either Nat (List Nat) + cRangeNBToList s (One i) = Left i + cRangeNBToList s (One' i) = Right [i] + cRangeNBToList s All = Right $ range 0 s + cRangeNBToList s (StartBound x) = Right $ range x s + cRangeNBToList s (EndBound x) = Right $ range 0 x + cRangeNBToList s (Bounds x y) = Right $ range x y + cRangeNBToList s (Indices xs) = Right $ nub xs + cRangeNBToList s (Filter p) = Right $ filter p $ range 0 s + + public export + newRank : Vect rk Nat -> Vect rk CRangeNB -> Nat + newRank _ [] = 0 + newRank (d :: s) (r :: rs) = + case cRangeNBToList d r of + Left _ => newRank s rs + Right _ => S (newRank s rs) ||| Calculate the new shape given by a coordinate range. public export - newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat - newShape = zipWith (length .: cRangeNBToList) + newShape : (s : Vect rk Nat) -> (is : Vect rk CRangeNB) -> Vect (newRank s is) Nat + newShape [] [] = [] + newShape (d :: s) (r :: rs) with (cRangeNBToList d r) + _ | Left _ = newShape s rs + _ | Right xs = length xs :: newShape s rs export getAllCoords' : Vect rk Nat -> List (Vect rk Nat) diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index 9a43a44..13e0316 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -568,9 +568,8 @@ solveLowerTri' {n} mat b with (viewShape b) construct [] = [] construct {i=S i} (b :: bs) = let xs = construct bs - i' = assert_total $ case natToFin i n of Just i' => i' in (b - sum (zipWith (*) xs (reverse $ toVect $ replace {p = flip Array a} (believe_me $ Refl {x=()}) $ - mat !!.. [One i', EndBound (weaken i')]))) / mat!#[i,i] :: xs + mat !#.. [One i, EndBound i]))) / mat!#[i,i] :: xs solveUpperTri' : Field a => Matrix' n a -> Vector n a -> Vector n a @@ -581,9 +580,8 @@ solveUpperTri' {n} mat b with (viewShape b) construct _ [] = [] construct i (b :: bs) = let xs = construct (S i) bs - i' = assert_total $ case natToFin i n of Just i' => i' in (b - sum (zipWith (*) xs (toVect $ replace {p = flip Array a} (believe_me $ Refl {x=()}) $ - mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs + mat !#.. [One i, StartBound (S i)]))) / mat!#[i,i] :: xs ||| Solve a linear equation, assuming the matrix is lower triangular. @@ -628,6 +626,7 @@ solve mat = solveWithLUP mat (decompLUP mat) -------------------------------------------------------------------------------- +||| Determine whether a matrix has an inverse. export invertible : FieldCmp a => Matrix' n a -> Bool invertible {n} mat with (viewShape mat) diff --git a/src/Data/NumIdr/PrimArray/Linked.idr b/src/Data/NumIdr/PrimArray/Linked.idr index 35a8712..ac5530e 100644 --- a/src/Data/NumIdr/PrimArray/Linked.idr +++ b/src/Data/NumIdr/PrimArray/Linked.idr @@ -23,10 +23,6 @@ update : Coords s -> (a -> a) -> Vects s a -> Vects s a update [] f v = f v update (i :: is) f v = updateAt i (update is f) v -export %unsafe -assertFin : Nat -> Fin n -assertFin n = natToFinLt n @{believe_me Oh} - export indexRange : {s : _} -> (rs : CoordsRange s) -> Vects s a -> Vects (newShape rs) a indexRange [] v = v