From 98223b180f8a9319de2a42443d7f6bdc903557da Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Wed, 14 Sep 2022 13:39:12 -0400 Subject: [PATCH] Add documentation --- src/Data/NP.idr | 1 + src/Data/NumIdr/Array/Array.idr | 66 ++++++++++++++++---- src/Data/NumIdr/Array/Coords.idr | 1 + src/Data/NumIdr/Interfaces.idr | 23 ++++++- src/Data/NumIdr/Matrix.idr | 103 +++++++++++++++++++++++++++---- src/Data/NumIdr/Vector.idr | 4 +- 6 files changed, 170 insertions(+), 28 deletions(-) diff --git a/src/Data/NP.idr b/src/Data/NP.idr index 6493541..678a29a 100644 --- a/src/Data/NP.idr +++ b/src/Data/NP.idr @@ -5,6 +5,7 @@ import Data.Vect %default total +||| A type constructor for heterogenous lists. public export data NP : (f : k -> Type) -> (ts : Vect n k) -> Type where Nil : NP {k} f [] diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index dfd9c00..9294ce8 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -102,10 +102,13 @@ shapeEq : (arr : Array s a) -> s = shape arr shapeEq (MkArray _ _ _ _) = Refl +||| A view for extracting the shape of an array. public export data ShapeView : Array s a -> Type where Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr +||| The covering function for the view `ShapeView`. This function takes an array +||| of type `Array s a` and returns `Shape s`. export viewShape : (arr : Array s a) -> ShapeView arr viewShape arr = rewrite shapeEq arr in @@ -134,11 +137,16 @@ export repeat : (s : Vect rk Nat) -> a -> Array s a repeat s = repeat' s COrder - +||| Create an array filled with zeros. +||| +||| @ s The shape of the constructed array export zeros : Num a => (s : Vect rk Nat) -> Array s a zeros s = repeat s 0 +||| Create an array filled with ones. +||| +||| @ s The shape of the constructed array export ones : Num a => (s : Vect rk Nat) -> Array s a ones s = repeat s 1 @@ -265,13 +273,13 @@ arr !! is = index is arr -- TODO: Create set/update at index functions -||| Update the value at the given coordinates using the function. +||| Update the entry at the given coordinates using the function. export indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a indexUpdate is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation sts is) f arr) -||| Set the value at the given coordinates to the given value. +||| Set the entry at the given coordinates to the given value. export indexSet : Coords s -> a -> Array s a -> Array s a indexSet is = indexUpdate is . const @@ -299,6 +307,7 @@ export %inline (!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a arr !!.. rs = indexRange rs arr +||| Set the sub-array at the given range of coordinates to the given array. export indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a -> Array s a -> Array s a @@ -311,6 +320,8 @@ indexSetRange rs rpl (MkArray ord sts s arr) = pure arr') +||| Update the sub-array at the given range of coordinates by applying +||| a function to it. export indexUpdateRange : (rs : CoordsRange s) -> (Array (newShape rs) a -> Array (newShape rs) a) -> @@ -334,20 +345,22 @@ export %inline (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a arr !? is = indexNB is arr -||| Update the value at the given coordinates using the function. The array is +||| Update the entry at the given coordinates using the function. The array is ||| returned unchanged if the coordinate is out of bounds. export indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a indexUpdateNB is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation' sts is) f arr) -||| Set the value at the given coordinates to the given value. The array is +||| Set the entry at the given coordinates to the given value. The array is ||| returned unchanged if the coordinate is out of bounds. export indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a indexSetNB is = indexUpdateNB is . const +||| Index the array using the given range of coordinates, returning a new array. +||| Entries outside of the array's bounds are not included. export indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a indexRangeNB {s} rs arr with (viewShape arr) @@ -362,15 +375,27 @@ indexRangeNB {s} rs arr with (viewShape arr) where s' : Vect ? Nat s' = newShape s rs +||| Index the array using the given range of coordinates, returning a new array. +||| Entries outside of the array's bounds are not included. +||| +||| This is the operator form of `indexRangeNB`. export %inline (!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a arr !?.. rs = indexRangeNB rs arr +||| Index the array using the given coordinates. +||| WARNING: This function does not perform any bounds check on its inputs. +||| Misuse of this function can easily break memory safety. export indexUnsafe : Vect rk Nat -> Array {rk} s a -> a indexUnsafe is arr = index (getLocation' (strides arr) is) (getPrim arr) +||| Index the array using the given coordinates. +||| WARNING: This function does not perform any bounds check on its inputs. +||| Misuse of this function can easily break memory safety. +||| +||| This is the operator form of `indexUnsafe`. export %inline (!#) : Array {rk} s a -> Vect rk Nat -> a arr !# is = indexUnsafe is arr @@ -455,7 +480,8 @@ enumerate {s} arr with (viewShape arr) ||| Join two arrays along a particular axis, e.g. combining two matrices -||| vertically or horizontally. The arrays must have the same shape on all other axes. +||| vertically or horizontally. All other axes of the arrays must have the +||| same dimensions. ||| ||| @ axis The axis to join the arrays on export @@ -485,35 +511,51 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in getAxisInd FZ (i :: is) = (i, is) getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is) + +||| Construct the transpose of an array by reversing the order of its axes. export transpose : Array s a -> Array (reverse s) a transpose {s} arr with (viewShape arr) _ | Shape s = fromFunctionNB (reverse s) (\is => arr !# reverse is) +||| Construct the transpose of an array by reversing the order of its axes. +||| +||| This is the postfix form of `transpose`. export (.T) : Array s a -> Array (reverse s) a (.T) = transpose +||| Swap two axes in an array. export swapAxes : (i,j : Fin rk) -> Array s a -> Array (swapElems i j s) a swapAxes {s} i j arr with (viewShape arr) _ | Shape s = fromFunctionNB _ (\is => arr !# swapElems i j is) +||| Apply a permutation to the axes of an array. export permuteAxes : (p : Permutation rk) -> Array s a -> Array (permuteVect p s) a permuteAxes {s} p arr with (viewShape arr) _ | Shape s = fromFunctionNB _ (\is => arr !# permuteVect p s) +||| Swap two coordinates along a specific axis (e.g. swapping two rows in a matrix). +||| +||| @ axis The axis to swap the coordinates along. Slices of the array +||| perpendicular to this axis are taken when swapping. export -swapInAxis : (ax : Fin rk) -> (i,j : Fin (index ax s)) -> Array s a -> Array s a -swapInAxis {s} ax i j arr with (viewShape arr) - _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt ax (swapValues i j) is) +swapInAxis : (axis : Fin rk) -> (i,j : Fin (index axis s)) -> Array s a -> Array s a +swapInAxis {s} axis i j arr with (viewShape arr) + _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt axis (swapValues i j) is) +||| Permute the coordinates along a specific axis (e.g. permuting the rows in +||| a matrix). +||| +||| @ axis The axis to permute the coordinates along. Slices of the array +||| perpendicular to this axis are taken when permuting. export -permuteInAxis : (ax : Fin rk) -> Permutation (index ax s) -> Array s a -> Array s a -permuteInAxis {s} ax p arr with (viewShape arr) - _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt ax (permuteValues p) is) +permuteInAxis : (axis : Fin rk) -> Permutation (index axis s) -> Array s a -> Array s a +permuteInAxis {s} axis p arr with (viewShape arr) + _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt axis (permuteValues p) is) -------------------------------------------------------------------------------- diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr index 8877de3..afb8482 100644 --- a/src/Data/NumIdr/Array/Coords.idr +++ b/src/Data/NumIdr/Array/Coords.idr @@ -167,6 +167,7 @@ namespace NB cRangeNBToList s (Indices xs) = filter ( Vect rk CRangeNB -> Vect rk Nat newShape = zipWith (length .: cRangeNBToList) diff --git a/src/Data/NumIdr/Interfaces.idr b/src/Data/NumIdr/Interfaces.idr index f092408..1e9a9eb 100644 --- a/src/Data/NumIdr/Interfaces.idr +++ b/src/Data/NumIdr/Interfaces.idr @@ -8,19 +8,27 @@ module Data.NumIdr.Interfaces -------------------------------------------------------------------------------- +||| An interface synonym for types which form a field, such as `Double`. public export Field : Type -> Type Field a = (Eq a, Neg a, Fractional a) +||| An interface for number types which allow for LUP decomposition to be performed. +||| +||| The function `abslt` is required for the internal decomposition algorithm. +||| An instance of this interface must satisfy: +||| * The relation `abslt` is a preorder. +||| * `0` is a minimum of `abslt`. public export interface (Eq a, Neg a, Fractional a) => FieldCmp a where - abscmp : a -> a -> Ordering + ||| Compare the absolute values of two inputs. + abslt : a -> a -> Bool export (Ord a, Abs a, Neg a, Fractional a) => FieldCmp a where - abscmp x y = compare (abs x) (abs y) + abslt x y = abs x < abs y -------------------------------------------------------------------------------- @@ -38,8 +46,11 @@ infixr 10 ^ ||| * If `x *. (y *. z)` is defined, then `(x *. y) *. z` is defined and equal. public export interface Mult a b c | a,b where + ||| A generalized multiplication/application operator for matrices and + ||| vector transformations. (*.) : a -> b -> c +||| A synonym for `Mult a a a`, or homogenous multiplication. public export Mult' : Type -> Type Mult' a = Mult a a a @@ -52,6 +63,10 @@ Mult' a = Mult a a a ||| * `identity *. x == x` public export interface Mult' a => MultMonoid a where + ||| Construct an identity matrix or transformation. + ||| + ||| NOTE: Creating an identity element for an `n`-dimensional transformation + ||| usually requires `n` to be available at runtime. identity : a ||| An interface for groups using the `*.` operator. @@ -61,6 +76,10 @@ interface Mult' a => MultMonoid a where ||| * `inverse x *. x == identity` public export interface MultMonoid a => MultGroup a where + ||| Calculate the inverse of the matrix or transformation. + ||| WARNING: This function will not check if an inverse exists for the given + ||| input, and will happily return results containing NaN values. To avoid + ||| this, use `tryInverse` instead. inverse : a -> a diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index acc012d..a4ca435 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -61,9 +61,10 @@ fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j) eq (FS _) FZ = Nothing +||| Construct a permutation matrix based on the given permutation. export -permutationMatrix : {n : _} -> Num a => Permutation n -> Matrix' n a -permutationMatrix p = permuteInAxis 0 p (repeatDiag 1 0) +permuteM : {n : _} -> Num a => Permutation n -> Matrix' n a +permuteM p = permuteInAxis 0 p (repeatDiag 1 0) ||| Construct the matrix that scales a vector by the given value. @@ -105,19 +106,23 @@ getColumn : Fin n -> Matrix m n a -> Vector m a getColumn c mat = rewrite sym (rangeLenZ m) in mat!!..[All, One c] +||| Return the diagonal elements of the matrix as a vector. export diagonal' : Matrix m n a -> Vector (minimum m n) a diagonal' {m,n} mat with (viewShape mat) _ | Shape [m,n] = fromFunctionNB _ (\[i] => mat!#[i,i]) +||| Return the diagonal elements of the matrix as a vector. export diagonal : Matrix' n a -> Vector n a diagonal {n} mat with (viewShape mat) _ | Shape [n,n] = fromFunctionNB [n] (\[i] => mat!#[i,i]) --- TODO: throw an actual proof in here to avoid the unsafety +||| Return a minor of the matrix, i.e. the matrix formed by removing a +||| single row and column. export +-- TODO: throw an actual proof in here to avoid the unsafety minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a minor i j mat = believe_me $ mat!!..[Filter (/=i), Filter (/=j)] @@ -157,28 +162,33 @@ export hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a hconcat = concat 1 - +||| Stack row vectors to form a matrix. export vstack : {n : _} -> Vect m (Vector n a) -> Matrix m n a vstack = stack 0 +||| Stack column vectors to form a matrix. export hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a hstack = stack 1 +||| Swap two rows of a matrix. export swapRows : (i,j : Fin m) -> Matrix m n a -> Matrix m n a swapRows = swapInAxis 0 +||| Swap two columns of a matrix. export swapColumns : (i,j : Fin n) -> Matrix m n a -> Matrix m n a swapColumns = swapInAxis 1 +||| Permute the rows of a matrix. export permuteRows : Permutation m -> Matrix m n a -> Matrix m n a permuteRows = permuteInAxis 0 +||| Permute the columns of a matrix. export permuteColumns : Permutation n -> Matrix m n a -> Matrix m n a permuteColumns = permuteInAxis 1 @@ -191,6 +201,7 @@ outer {m,n} a b with (viewShape a, viewShape b) _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => a!!i * b!!j) +||| Calculate the trace of a matrix, i.e. the sum of its diagonal elements. export trace : Num a => Matrix m n a -> a trace = sum . diagonal' @@ -223,14 +234,20 @@ export -------------------------------------------------------------------------------- --- LU Decomposition +||| The LU decomposition of a matrix. +||| +||| LU decomposition factors a matrix A into two matrices: a lower triangular +||| matrix L, and an upper triangular matrix U, such that A = LU. export record DecompLU {0 m,n,a : _} (mat : Matrix m n a) where constructor MkLU + -- The lower and upper triangular matrix elements are stored + -- together for efficiency reasons lu : Matrix m n a namespace DecompLU + ||| The lower triangular matrix L of the LU decomposition. export lower : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a lower {m,n} (MkLU lu) with (viewShape lu) @@ -240,34 +257,49 @@ namespace DecompLU EQ => 1 GT => lu!#[i,j]) + ||| The lower triangular matrix L of the LU decomposition. export %inline (.lower) : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a (.lower) = lower + ||| The lower triangular matrix L of the LU decomposition. + ||| + ||| This accessor is intended to be used for square matrix decompositions. export lower' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n) in lower lu + ||| The lower triangular matrix L of the LU decomposition. + ||| + ||| This accessor is intended to be used for square matrix decompositions. export %inline (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a (.lower') = lower' + ||| The upper triangular matrix U of the LU decomposition. export upper : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a upper {m,n} (MkLU lu) with (viewShape lu) _ | Shape [m,n] = fromFunctionNB _ (\[i,j] => if i <= j then lu!#[i,j] else 0) + ||| The upper triangular matrix U of the LU decomposition. export %inline (.upper) : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a (.upper) = upper + ||| The upper triangular matrix U of the LU decomposition. + ||| + ||| This accessor is intended to be used for square matrix decompositions. export upper' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n) in upper lu + ||| The upper triangular matrix U of the LU decomposition. + ||| + ||| This accessor is intended to be used for square matrix decompositions. export %inline (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a (.upper') = upper' @@ -296,6 +328,7 @@ iterateN 1 f x = f FZ x iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x +||| Perform a single step of Gaussian elimination on the `i`-th row and column. gaussStep : {m,n : _} -> Field a => Fin (minimum m n) -> Matrix m n a -> Matrix m n a gaussStep i lu = @@ -304,13 +337,14 @@ gaussStep i lu = ic = minWeakenRight i diag = lu!![ir,ic] coeffs = map (/diag) $ lu!!..[StartBound (FS ir), One ic] - lu' = indexSetRange [StartBound (FS ir), One ic] - coeffs lu + lu' = indexSetRange [StartBound (FS ir), One ic] coeffs lu pivot = lu!!..[One ir, StartBound (FS ic)] offsets = outer coeffs pivot in indexUpdateRange [StartBound (FS ir), StartBound (FS ic)] (flip (-) offsets) lu' +||| Calculate the LU decomposition of a matrix, returning `Nothing` if one +||| does not exist. export decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat) decompLU {m,n} mat with (viewShape mat) @@ -321,7 +355,12 @@ decompLU {m,n} mat with (viewShape mat) else Just (gaussStep i mat) --- LUP Decomposition +||| The LUP decomposition of a matrix. +||| +||| LUP decomposition is similar to LU decomposition, but the matrix may have +||| its rows permuted before being factored. More formally, an LUP decomposition +||| of a matrix A consists of a lower triangular matrix L, an upper triangular +||| matrix U, and a permutation matrix P, such that PA = LU. export record DecompLUP {0 m,n,a : _} (mat : Matrix m n a) where constructor MkLUP @@ -330,6 +369,7 @@ record DecompLUP {0 m,n,a : _} (mat : Matrix m n a) where sw : Nat namespace DecompLUP + ||| The lower triangular matrix L of the LUP decomposition. export lower : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a lower {m,n} (MkLUP lu _ _) with (viewShape lu) @@ -339,55 +379,78 @@ namespace DecompLUP EQ => 1 GT => lu!#[i,j]) + ||| The lower triangular matrix L of the LUP decomposition. export %inline (.lower) : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a (.lower) = lower + ||| The lower triangular matrix L of the LUP decomposition. + ||| + ||| This accessor is intended to be used for square matrix decompositions. export lower' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n) in lower lu + ||| The lower triangular matrix L of the LUP decomposition. + ||| + ||| This accessor is intended to be used for square matrix decompositions. export %inline (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a (.lower') = lower' + ||| The upper triangular matrix U of the LUP decomposition. export upper : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a upper {m,n} (MkLUP lu _ _) with (viewShape lu) _ | Shape [m,n] = fromFunctionNB _ (\[i,j] => if i <= j then lu!#[i,j] else 0) + ||| The upper triangular matrix U of the LUP decomposition. export %inline (.upper) : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a (.upper) = upper + ||| The upper triangular matrix U of the LUP decomposition. + ||| + ||| This accessor is intended to be used for square matrix decompositions. export upper' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n) in upper lu + ||| The upper triangular matrix U of the LUP decomposition. + ||| + ||| This accessor is intended to be used for square matrix decompositions. export %inline (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a (.upper') = upper' + ||| The row permutation of the LUP decomposition. export permute : DecompLUP {m} mat -> Permutation m permute (MkLUP lu p sw) = p + ||| The row permutation of the LUP decomposition. export %inline (.permute) : DecompLUP {m} mat -> Permutation m (.permute) = permute + ||| The number of swaps in the permutation of the LUP decomposition. + ||| + ||| This is stored along with the permutation in order to increase the + ||| efficiency of certain algorithms. export numSwaps : DecompLUP mat -> Nat numSwaps (MkLUP lu p sw) = sw + +||| Convert an LU decomposition into an LUP decomposition. export fromLU : DecompLU mat -> DecompLUP mat fromLU (MkLU lu) = MkLUP lu identity 0 - +||| Calculate the LUP decomposition of a matrix. export decompLUP : FieldCmp a => (mat : Matrix m n a) -> DecompLUP mat decompLUP {m,n} mat with (viewShape mat) @@ -400,8 +463,8 @@ decompLUP {m,n} mat with (viewShape mat) maxIndex x [] = x maxIndex _ [x] = x maxIndex x ((a,b)::(c,d)::xs) = - if abscmp b d == LT then assert_total $ maxIndex x ((c,d)::xs) - else assert_total $ maxIndex x ((a,b)::xs) + if abslt b d then assert_total $ maxIndex x ((c,d)::xs) + else assert_total $ maxIndex x ((a,b)::xs) gaussStepSwap : Fin (S $ minimum m n) -> DecompLUP mat -> DecompLUP mat gaussStepSwap i (MkLUP lu p sw) = @@ -418,12 +481,14 @@ decompLUP {m,n} mat with (viewShape mat) -------------------------------------------------------------------------------- +||| Calculate the determinant of a matrix given its LUP decomposition. export detWithLUP : Num a => (mat : Matrix' n a) -> DecompLUP mat -> a detWithLUP mat lup = (if numSwaps lup `mod` 2 == 0 then 1 else -1) * product (diagonal lup.lu) +||| Calculate the determinant of a matrix. export det : FieldCmp a => Matrix' n a -> a det {n} mat with (viewShape mat) @@ -464,12 +529,16 @@ solveUpperTri' {n} mat b with (viewShape b) mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs +||| Solve a linear equation, assuming the matrix is lower triangular. +||| Any entries other than those below the diagonal are ignored. export solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solveLowerTri mat b = if all (/=0) (diagonal mat) then Just $ solveLowerTri' mat b else Nothing +||| Solve a linear equation, assuming the matrix is upper triangular. +||| Any entries other than those above the diagonal are ignored. export solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solveUpperTri mat b = if all (/=0) (diagonal mat) @@ -483,19 +552,27 @@ solveWithLUP' mat lup b = let b' = permuteCoords (inverse lup.permute) b in solveUpperTri' lup.upper' $ solveLowerTri' lup.lower' b' +||| Solve a linear equation, given a matrix and its LUP decomposition. export solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat -> Vector n a -> Maybe (Vector n a) solveWithLUP mat lup b = let b' = permuteCoords (inverse lup.permute) b - in solveLowerTri lup.lower' b' >>= solveUpperTri lup.upper' - + in solveUpperTri lup.upper' $ solveLowerTri' lup.lower' b' +||| Solve a linear equation given a matrix. export solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solve mat = solveWithLUP mat (decompLUP mat) +-------------------------------------------------------------------------------- +-- Matrix inversion +-------------------------------------------------------------------------------- + + +||| Try to invert a square matrix, returning `Nothing` if an inverse +||| does not exist. export tryInverse : FieldCmp a => Matrix' n a -> Maybe (Matrix' n a) tryInverse {n} mat with (viewShape mat) diff --git a/src/Data/NumIdr/Vector.idr b/src/Data/NumIdr/Vector.idr index 620d5d6..37ddc69 100644 --- a/src/Data/NumIdr/Vector.idr +++ b/src/Data/NumIdr/Vector.idr @@ -14,7 +14,7 @@ Vector : Nat -> Type -> Type Vector n = Array [n] -||| The length (number of dimensions) of the vector +||| The length (number of dimensions) of the vector. public export dim : Vector n a -> Nat dim = head . shape @@ -131,10 +131,12 @@ swizzle p v = rewrite sym (lengthCorrect p) ) +||| Swap two entries in a vector. export swapCoords : (i,j : Fin n) -> Vector n a -> Vector n a swapCoords = swapInAxis 0 +||| Permute the entries in a vector. export permuteCoords : Permutation n -> Vector n a -> Vector n a permuteCoords = permuteInAxis 0