Add documentation

This commit is contained in:
Kiana Sheibani 2022-09-14 13:39:12 -04:00
parent d610874abc
commit 98223b180f
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
6 changed files with 170 additions and 28 deletions

View file

@ -5,6 +5,7 @@ import Data.Vect
%default total %default total
||| A type constructor for heterogenous lists.
public export public export
data NP : (f : k -> Type) -> (ts : Vect n k) -> Type where data NP : (f : k -> Type) -> (ts : Vect n k) -> Type where
Nil : NP {k} f [] Nil : NP {k} f []

View file

@ -102,10 +102,13 @@ shapeEq : (arr : Array s a) -> s = shape arr
shapeEq (MkArray _ _ _ _) = Refl shapeEq (MkArray _ _ _ _) = Refl
||| A view for extracting the shape of an array.
public export public export
data ShapeView : Array s a -> Type where data ShapeView : Array s a -> Type where
Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
||| The covering function for the view `ShapeView`. This function takes an array
||| of type `Array s a` and returns `Shape s`.
export export
viewShape : (arr : Array s a) -> ShapeView arr viewShape : (arr : Array s a) -> ShapeView arr
viewShape arr = rewrite shapeEq arr in viewShape arr = rewrite shapeEq arr in
@ -134,11 +137,16 @@ export
repeat : (s : Vect rk Nat) -> a -> Array s a repeat : (s : Vect rk Nat) -> a -> Array s a
repeat s = repeat' s COrder repeat s = repeat' s COrder
||| Create an array filled with zeros.
|||
||| @ s The shape of the constructed array
export export
zeros : Num a => (s : Vect rk Nat) -> Array s a zeros : Num a => (s : Vect rk Nat) -> Array s a
zeros s = repeat s 0 zeros s = repeat s 0
||| Create an array filled with ones.
|||
||| @ s The shape of the constructed array
export export
ones : Num a => (s : Vect rk Nat) -> Array s a ones : Num a => (s : Vect rk Nat) -> Array s a
ones s = repeat s 1 ones s = repeat s 1
@ -265,13 +273,13 @@ arr !! is = index is arr
-- TODO: Create set/update at index functions -- TODO: Create set/update at index functions
||| Update the value at the given coordinates using the function. ||| Update the entry at the given coordinates using the function.
export export
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
indexUpdate is f (MkArray ord sts s arr) = indexUpdate is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation sts is) f arr) MkArray ord sts s (updateAt (getLocation sts is) f arr)
||| Set the value at the given coordinates to the given value. ||| Set the entry at the given coordinates to the given value.
export export
indexSet : Coords s -> a -> Array s a -> Array s a indexSet : Coords s -> a -> Array s a -> Array s a
indexSet is = indexUpdate is . const indexSet is = indexUpdate is . const
@ -299,6 +307,7 @@ export %inline
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a (!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
arr !!.. rs = indexRange rs arr arr !!.. rs = indexRange rs arr
||| Set the sub-array at the given range of coordinates to the given array.
export export
indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a -> indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
Array s a -> Array s a Array s a -> Array s a
@ -311,6 +320,8 @@ indexSetRange rs rpl (MkArray ord sts s arr) =
pure arr') pure arr')
||| Update the sub-array at the given range of coordinates by applying
||| a function to it.
export export
indexUpdateRange : (rs : CoordsRange s) -> indexUpdateRange : (rs : CoordsRange s) ->
(Array (newShape rs) a -> Array (newShape rs) a) -> (Array (newShape rs) a -> Array (newShape rs) a) ->
@ -334,20 +345,22 @@ export %inline
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
arr !? is = indexNB is arr arr !? is = indexNB is arr
||| Update the value at the given coordinates using the function. The array is ||| Update the entry at the given coordinates using the function. The array is
||| returned unchanged if the coordinate is out of bounds. ||| returned unchanged if the coordinate is out of bounds.
export export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
indexUpdateNB is f (MkArray ord sts s arr) = indexUpdateNB is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation' sts is) f arr) MkArray ord sts s (updateAt (getLocation' sts is) f arr)
||| Set the value at the given coordinates to the given value. The array is ||| Set the entry at the given coordinates to the given value. The array is
||| returned unchanged if the coordinate is out of bounds. ||| returned unchanged if the coordinate is out of bounds.
export export
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
indexSetNB is = indexUpdateNB is . const indexSetNB is = indexUpdateNB is . const
||| Index the array using the given range of coordinates, returning a new array.
||| Entries outside of the array's bounds are not included.
export export
indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
indexRangeNB {s} rs arr with (viewShape arr) indexRangeNB {s} rs arr with (viewShape arr)
@ -362,15 +375,27 @@ indexRangeNB {s} rs arr with (viewShape arr)
where s' : Vect ? Nat where s' : Vect ? Nat
s' = newShape s rs s' = newShape s rs
||| Index the array using the given range of coordinates, returning a new array.
||| Entries outside of the array's bounds are not included.
|||
||| This is the operator form of `indexRangeNB`.
export %inline export %inline
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a (!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
arr !?.. rs = indexRangeNB rs arr arr !?.. rs = indexRangeNB rs arr
||| Index the array using the given coordinates.
||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety.
export export
indexUnsafe : Vect rk Nat -> Array {rk} s a -> a indexUnsafe : Vect rk Nat -> Array {rk} s a -> a
indexUnsafe is arr = index (getLocation' (strides arr) is) (getPrim arr) indexUnsafe is arr = index (getLocation' (strides arr) is) (getPrim arr)
||| Index the array using the given coordinates.
||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety.
|||
||| This is the operator form of `indexUnsafe`.
export %inline export %inline
(!#) : Array {rk} s a -> Vect rk Nat -> a (!#) : Array {rk} s a -> Vect rk Nat -> a
arr !# is = indexUnsafe is arr arr !# is = indexUnsafe is arr
@ -455,7 +480,8 @@ enumerate {s} arr with (viewShape arr)
||| Join two arrays along a particular axis, e.g. combining two matrices ||| Join two arrays along a particular axis, e.g. combining two matrices
||| vertically or horizontally. The arrays must have the same shape on all other axes. ||| vertically or horizontally. All other axes of the arrays must have the
||| same dimensions.
||| |||
||| @ axis The axis to join the arrays on ||| @ axis The axis to join the arrays on
export export
@ -485,35 +511,51 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in
getAxisInd FZ (i :: is) = (i, is) getAxisInd FZ (i :: is) = (i, is)
getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is) getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
||| Construct the transpose of an array by reversing the order of its axes.
export export
transpose : Array s a -> Array (reverse s) a transpose : Array s a -> Array (reverse s) a
transpose {s} arr with (viewShape arr) transpose {s} arr with (viewShape arr)
_ | Shape s = fromFunctionNB (reverse s) (\is => arr !# reverse is) _ | Shape s = fromFunctionNB (reverse s) (\is => arr !# reverse is)
||| Construct the transpose of an array by reversing the order of its axes.
|||
||| This is the postfix form of `transpose`.
export export
(.T) : Array s a -> Array (reverse s) a (.T) : Array s a -> Array (reverse s) a
(.T) = transpose (.T) = transpose
||| Swap two axes in an array.
export export
swapAxes : (i,j : Fin rk) -> Array s a -> Array (swapElems i j s) a swapAxes : (i,j : Fin rk) -> Array s a -> Array (swapElems i j s) a
swapAxes {s} i j arr with (viewShape arr) swapAxes {s} i j arr with (viewShape arr)
_ | Shape s = fromFunctionNB _ (\is => arr !# swapElems i j is) _ | Shape s = fromFunctionNB _ (\is => arr !# swapElems i j is)
||| Apply a permutation to the axes of an array.
export export
permuteAxes : (p : Permutation rk) -> Array s a -> Array (permuteVect p s) a permuteAxes : (p : Permutation rk) -> Array s a -> Array (permuteVect p s) a
permuteAxes {s} p arr with (viewShape arr) permuteAxes {s} p arr with (viewShape arr)
_ | Shape s = fromFunctionNB _ (\is => arr !# permuteVect p s) _ | Shape s = fromFunctionNB _ (\is => arr !# permuteVect p s)
||| Swap two coordinates along a specific axis (e.g. swapping two rows in a matrix).
|||
||| @ axis The axis to swap the coordinates along. Slices of the array
||| perpendicular to this axis are taken when swapping.
export export
swapInAxis : (ax : Fin rk) -> (i,j : Fin (index ax s)) -> Array s a -> Array s a swapInAxis : (axis : Fin rk) -> (i,j : Fin (index axis s)) -> Array s a -> Array s a
swapInAxis {s} ax i j arr with (viewShape arr) swapInAxis {s} axis i j arr with (viewShape arr)
_ | Shape s = fromFunctionNB _ (\is => arr !# updateAt ax (swapValues i j) is) _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt axis (swapValues i j) is)
||| Permute the coordinates along a specific axis (e.g. permuting the rows in
||| a matrix).
|||
||| @ axis The axis to permute the coordinates along. Slices of the array
||| perpendicular to this axis are taken when permuting.
export export
permuteInAxis : (ax : Fin rk) -> Permutation (index ax s) -> Array s a -> Array s a permuteInAxis : (axis : Fin rk) -> Permutation (index axis s) -> Array s a -> Array s a
permuteInAxis {s} ax p arr with (viewShape arr) permuteInAxis {s} axis p arr with (viewShape arr)
_ | Shape s = fromFunctionNB _ (\is => arr !# updateAt ax (permuteValues p) is) _ | Shape s = fromFunctionNB _ (\is => arr !# updateAt axis (permuteValues p) is)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View file

@ -167,6 +167,7 @@ namespace NB
cRangeNBToList s (Indices xs) = filter (<s) xs cRangeNBToList s (Indices xs) = filter (<s) xs
cRangeNBToList s (Filter p) = filter p $ range 0 s cRangeNBToList s (Filter p) = filter p $ range 0 s
||| Calculate the new shape given by a coordinate range.
export export
newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat
newShape = zipWith (length .: cRangeNBToList) newShape = zipWith (length .: cRangeNBToList)

View file

@ -8,19 +8,27 @@ module Data.NumIdr.Interfaces
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| An interface synonym for types which form a field, such as `Double`.
public export public export
Field : Type -> Type Field : Type -> Type
Field a = (Eq a, Neg a, Fractional a) Field a = (Eq a, Neg a, Fractional a)
||| An interface for number types which allow for LUP decomposition to be performed.
|||
||| The function `abslt` is required for the internal decomposition algorithm.
||| An instance of this interface must satisfy:
||| * The relation `abslt` is a preorder.
||| * `0` is a minimum of `abslt`.
public export public export
interface (Eq a, Neg a, Fractional a) => FieldCmp a where interface (Eq a, Neg a, Fractional a) => FieldCmp a where
abscmp : a -> a -> Ordering ||| Compare the absolute values of two inputs.
abslt : a -> a -> Bool
export export
(Ord a, Abs a, Neg a, Fractional a) => FieldCmp a where (Ord a, Abs a, Neg a, Fractional a) => FieldCmp a where
abscmp x y = compare (abs x) (abs y) abslt x y = abs x < abs y
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -38,8 +46,11 @@ infixr 10 ^
||| * If `x *. (y *. z)` is defined, then `(x *. y) *. z` is defined and equal. ||| * If `x *. (y *. z)` is defined, then `(x *. y) *. z` is defined and equal.
public export public export
interface Mult a b c | a,b where interface Mult a b c | a,b where
||| A generalized multiplication/application operator for matrices and
||| vector transformations.
(*.) : a -> b -> c (*.) : a -> b -> c
||| A synonym for `Mult a a a`, or homogenous multiplication.
public export public export
Mult' : Type -> Type Mult' : Type -> Type
Mult' a = Mult a a a Mult' a = Mult a a a
@ -52,6 +63,10 @@ Mult' a = Mult a a a
||| * `identity *. x == x` ||| * `identity *. x == x`
public export public export
interface Mult' a => MultMonoid a where interface Mult' a => MultMonoid a where
||| Construct an identity matrix or transformation.
|||
||| NOTE: Creating an identity element for an `n`-dimensional transformation
||| usually requires `n` to be available at runtime.
identity : a identity : a
||| An interface for groups using the `*.` operator. ||| An interface for groups using the `*.` operator.
@ -61,6 +76,10 @@ interface Mult' a => MultMonoid a where
||| * `inverse x *. x == identity` ||| * `inverse x *. x == identity`
public export public export
interface MultMonoid a => MultGroup a where interface MultMonoid a => MultGroup a where
||| Calculate the inverse of the matrix or transformation.
||| WARNING: This function will not check if an inverse exists for the given
||| input, and will happily return results containing NaN values. To avoid
||| this, use `tryInverse` instead.
inverse : a -> a inverse : a -> a

View file

@ -61,9 +61,10 @@ fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
eq (FS _) FZ = Nothing eq (FS _) FZ = Nothing
||| Construct a permutation matrix based on the given permutation.
export export
permutationMatrix : {n : _} -> Num a => Permutation n -> Matrix' n a permuteM : {n : _} -> Num a => Permutation n -> Matrix' n a
permutationMatrix p = permuteInAxis 0 p (repeatDiag 1 0) permuteM p = permuteInAxis 0 p (repeatDiag 1 0)
||| Construct the matrix that scales a vector by the given value. ||| Construct the matrix that scales a vector by the given value.
@ -105,19 +106,23 @@ getColumn : Fin n -> Matrix m n a -> Vector m a
getColumn c mat = rewrite sym (rangeLenZ m) in mat!!..[All, One c] getColumn c mat = rewrite sym (rangeLenZ m) in mat!!..[All, One c]
||| Return the diagonal elements of the matrix as a vector.
export export
diagonal' : Matrix m n a -> Vector (minimum m n) a diagonal' : Matrix m n a -> Vector (minimum m n) a
diagonal' {m,n} mat with (viewShape mat) diagonal' {m,n} mat with (viewShape mat)
_ | Shape [m,n] = fromFunctionNB _ (\[i] => mat!#[i,i]) _ | Shape [m,n] = fromFunctionNB _ (\[i] => mat!#[i,i])
||| Return the diagonal elements of the matrix as a vector.
export export
diagonal : Matrix' n a -> Vector n a diagonal : Matrix' n a -> Vector n a
diagonal {n} mat with (viewShape mat) diagonal {n} mat with (viewShape mat)
_ | Shape [n,n] = fromFunctionNB [n] (\[i] => mat!#[i,i]) _ | Shape [n,n] = fromFunctionNB [n] (\[i] => mat!#[i,i])
-- TODO: throw an actual proof in here to avoid the unsafety ||| Return a minor of the matrix, i.e. the matrix formed by removing a
||| single row and column.
export export
-- TODO: throw an actual proof in here to avoid the unsafety
minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a
minor i j mat = believe_me $ mat!!..[Filter (/=i), Filter (/=j)] minor i j mat = believe_me $ mat!!..[Filter (/=i), Filter (/=j)]
@ -157,28 +162,33 @@ export
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
hconcat = concat 1 hconcat = concat 1
||| Stack row vectors to form a matrix.
export export
vstack : {n : _} -> Vect m (Vector n a) -> Matrix m n a vstack : {n : _} -> Vect m (Vector n a) -> Matrix m n a
vstack = stack 0 vstack = stack 0
||| Stack column vectors to form a matrix.
export export
hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a
hstack = stack 1 hstack = stack 1
||| Swap two rows of a matrix.
export export
swapRows : (i,j : Fin m) -> Matrix m n a -> Matrix m n a swapRows : (i,j : Fin m) -> Matrix m n a -> Matrix m n a
swapRows = swapInAxis 0 swapRows = swapInAxis 0
||| Swap two columns of a matrix.
export export
swapColumns : (i,j : Fin n) -> Matrix m n a -> Matrix m n a swapColumns : (i,j : Fin n) -> Matrix m n a -> Matrix m n a
swapColumns = swapInAxis 1 swapColumns = swapInAxis 1
||| Permute the rows of a matrix.
export export
permuteRows : Permutation m -> Matrix m n a -> Matrix m n a permuteRows : Permutation m -> Matrix m n a -> Matrix m n a
permuteRows = permuteInAxis 0 permuteRows = permuteInAxis 0
||| Permute the columns of a matrix.
export export
permuteColumns : Permutation n -> Matrix m n a -> Matrix m n a permuteColumns : Permutation n -> Matrix m n a -> Matrix m n a
permuteColumns = permuteInAxis 1 permuteColumns = permuteInAxis 1
@ -191,6 +201,7 @@ outer {m,n} a b with (viewShape a, viewShape b)
_ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => a!!i * b!!j) _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => a!!i * b!!j)
||| Calculate the trace of a matrix, i.e. the sum of its diagonal elements.
export export
trace : Num a => Matrix m n a -> a trace : Num a => Matrix m n a -> a
trace = sum . diagonal' trace = sum . diagonal'
@ -223,14 +234,20 @@ export
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- LU Decomposition ||| The LU decomposition of a matrix.
|||
||| LU decomposition factors a matrix A into two matrices: a lower triangular
||| matrix L, and an upper triangular matrix U, such that A = LU.
export export
record DecompLU {0 m,n,a : _} (mat : Matrix m n a) where record DecompLU {0 m,n,a : _} (mat : Matrix m n a) where
constructor MkLU constructor MkLU
-- The lower and upper triangular matrix elements are stored
-- together for efficiency reasons
lu : Matrix m n a lu : Matrix m n a
namespace DecompLU namespace DecompLU
||| The lower triangular matrix L of the LU decomposition.
export export
lower : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a lower : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a
lower {m,n} (MkLU lu) with (viewShape lu) lower {m,n} (MkLU lu) with (viewShape lu)
@ -240,34 +257,49 @@ namespace DecompLU
EQ => 1 EQ => 1
GT => lu!#[i,j]) GT => lu!#[i,j])
||| The lower triangular matrix L of the LU decomposition.
export %inline export %inline
(.lower) : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a (.lower) : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a
(.lower) = lower (.lower) = lower
||| The lower triangular matrix L of the LU decomposition.
|||
||| This accessor is intended to be used for square matrix decompositions.
export export
lower' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a lower' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n) lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n)
in lower lu in lower lu
||| The lower triangular matrix L of the LU decomposition.
|||
||| This accessor is intended to be used for square matrix decompositions.
export %inline export %inline
(.lower') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
(.lower') = lower' (.lower') = lower'
||| The upper triangular matrix U of the LU decomposition.
export export
upper : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a upper : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a
upper {m,n} (MkLU lu) with (viewShape lu) upper {m,n} (MkLU lu) with (viewShape lu)
_ | Shape [m,n] = fromFunctionNB _ (\[i,j] => _ | Shape [m,n] = fromFunctionNB _ (\[i,j] =>
if i <= j then lu!#[i,j] else 0) if i <= j then lu!#[i,j] else 0)
||| The upper triangular matrix U of the LU decomposition.
export %inline export %inline
(.upper) : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a (.upper) : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a
(.upper) = upper (.upper) = upper
||| The upper triangular matrix U of the LU decomposition.
|||
||| This accessor is intended to be used for square matrix decompositions.
export export
upper' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a upper' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n) upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n)
in upper lu in upper lu
||| The upper triangular matrix U of the LU decomposition.
|||
||| This accessor is intended to be used for square matrix decompositions.
export %inline export %inline
(.upper') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a
(.upper') = upper' (.upper') = upper'
@ -296,6 +328,7 @@ iterateN 1 f x = f FZ x
iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x
||| Perform a single step of Gaussian elimination on the `i`-th row and column.
gaussStep : {m,n : _} -> Field a => gaussStep : {m,n : _} -> Field a =>
Fin (minimum m n) -> Matrix m n a -> Matrix m n a Fin (minimum m n) -> Matrix m n a -> Matrix m n a
gaussStep i lu = gaussStep i lu =
@ -304,13 +337,14 @@ gaussStep i lu =
ic = minWeakenRight i ic = minWeakenRight i
diag = lu!![ir,ic] diag = lu!![ir,ic]
coeffs = map (/diag) $ lu!!..[StartBound (FS ir), One ic] coeffs = map (/diag) $ lu!!..[StartBound (FS ir), One ic]
lu' = indexSetRange [StartBound (FS ir), One ic] lu' = indexSetRange [StartBound (FS ir), One ic] coeffs lu
coeffs lu
pivot = lu!!..[One ir, StartBound (FS ic)] pivot = lu!!..[One ir, StartBound (FS ic)]
offsets = outer coeffs pivot offsets = outer coeffs pivot
in indexUpdateRange [StartBound (FS ir), StartBound (FS ic)] in indexUpdateRange [StartBound (FS ir), StartBound (FS ic)]
(flip (-) offsets) lu' (flip (-) offsets) lu'
||| Calculate the LU decomposition of a matrix, returning `Nothing` if one
||| does not exist.
export export
decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat) decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat)
decompLU {m,n} mat with (viewShape mat) decompLU {m,n} mat with (viewShape mat)
@ -321,7 +355,12 @@ decompLU {m,n} mat with (viewShape mat)
else Just (gaussStep i mat) else Just (gaussStep i mat)
-- LUP Decomposition ||| The LUP decomposition of a matrix.
|||
||| LUP decomposition is similar to LU decomposition, but the matrix may have
||| its rows permuted before being factored. More formally, an LUP decomposition
||| of a matrix A consists of a lower triangular matrix L, an upper triangular
||| matrix U, and a permutation matrix P, such that PA = LU.
export export
record DecompLUP {0 m,n,a : _} (mat : Matrix m n a) where record DecompLUP {0 m,n,a : _} (mat : Matrix m n a) where
constructor MkLUP constructor MkLUP
@ -330,6 +369,7 @@ record DecompLUP {0 m,n,a : _} (mat : Matrix m n a) where
sw : Nat sw : Nat
namespace DecompLUP namespace DecompLUP
||| The lower triangular matrix L of the LUP decomposition.
export export
lower : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a lower : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a
lower {m,n} (MkLUP lu _ _) with (viewShape lu) lower {m,n} (MkLUP lu _ _) with (viewShape lu)
@ -339,55 +379,78 @@ namespace DecompLUP
EQ => 1 EQ => 1
GT => lu!#[i,j]) GT => lu!#[i,j])
||| The lower triangular matrix L of the LUP decomposition.
export %inline export %inline
(.lower) : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a (.lower) : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a
(.lower) = lower (.lower) = lower
||| The lower triangular matrix L of the LUP decomposition.
|||
||| This accessor is intended to be used for square matrix decompositions.
export export
lower' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a lower' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n) lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n)
in lower lu in lower lu
||| The lower triangular matrix L of the LUP decomposition.
|||
||| This accessor is intended to be used for square matrix decompositions.
export %inline export %inline
(.lower') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
(.lower') = lower' (.lower') = lower'
||| The upper triangular matrix U of the LUP decomposition.
export export
upper : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a upper : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a
upper {m,n} (MkLUP lu _ _) with (viewShape lu) upper {m,n} (MkLUP lu _ _) with (viewShape lu)
_ | Shape [m,n] = fromFunctionNB _ (\[i,j] => _ | Shape [m,n] = fromFunctionNB _ (\[i,j] =>
if i <= j then lu!#[i,j] else 0) if i <= j then lu!#[i,j] else 0)
||| The upper triangular matrix U of the LUP decomposition.
export %inline export %inline
(.upper) : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a (.upper) : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a
(.upper) = upper (.upper) = upper
||| The upper triangular matrix U of the LUP decomposition.
|||
||| This accessor is intended to be used for square matrix decompositions.
export export
upper' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a upper' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n) upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n)
in upper lu in upper lu
||| The upper triangular matrix U of the LUP decomposition.
|||
||| This accessor is intended to be used for square matrix decompositions.
export %inline export %inline
(.upper') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a
(.upper') = upper' (.upper') = upper'
||| The row permutation of the LUP decomposition.
export export
permute : DecompLUP {m} mat -> Permutation m permute : DecompLUP {m} mat -> Permutation m
permute (MkLUP lu p sw) = p permute (MkLUP lu p sw) = p
||| The row permutation of the LUP decomposition.
export %inline export %inline
(.permute) : DecompLUP {m} mat -> Permutation m (.permute) : DecompLUP {m} mat -> Permutation m
(.permute) = permute (.permute) = permute
||| The number of swaps in the permutation of the LUP decomposition.
|||
||| This is stored along with the permutation in order to increase the
||| efficiency of certain algorithms.
export export
numSwaps : DecompLUP mat -> Nat numSwaps : DecompLUP mat -> Nat
numSwaps (MkLUP lu p sw) = sw numSwaps (MkLUP lu p sw) = sw
||| Convert an LU decomposition into an LUP decomposition.
export export
fromLU : DecompLU mat -> DecompLUP mat fromLU : DecompLU mat -> DecompLUP mat
fromLU (MkLU lu) = MkLUP lu identity 0 fromLU (MkLU lu) = MkLUP lu identity 0
||| Calculate the LUP decomposition of a matrix.
export export
decompLUP : FieldCmp a => (mat : Matrix m n a) -> DecompLUP mat decompLUP : FieldCmp a => (mat : Matrix m n a) -> DecompLUP mat
decompLUP {m,n} mat with (viewShape mat) decompLUP {m,n} mat with (viewShape mat)
@ -400,8 +463,8 @@ decompLUP {m,n} mat with (viewShape mat)
maxIndex x [] = x maxIndex x [] = x
maxIndex _ [x] = x maxIndex _ [x] = x
maxIndex x ((a,b)::(c,d)::xs) = maxIndex x ((a,b)::(c,d)::xs) =
if abscmp b d == LT then assert_total $ maxIndex x ((c,d)::xs) if abslt b d then assert_total $ maxIndex x ((c,d)::xs)
else assert_total $ maxIndex x ((a,b)::xs) else assert_total $ maxIndex x ((a,b)::xs)
gaussStepSwap : Fin (S $ minimum m n) -> DecompLUP mat -> DecompLUP mat gaussStepSwap : Fin (S $ minimum m n) -> DecompLUP mat -> DecompLUP mat
gaussStepSwap i (MkLUP lu p sw) = gaussStepSwap i (MkLUP lu p sw) =
@ -418,12 +481,14 @@ decompLUP {m,n} mat with (viewShape mat)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| Calculate the determinant of a matrix given its LUP decomposition.
export export
detWithLUP : Num a => (mat : Matrix' n a) -> DecompLUP mat -> a detWithLUP : Num a => (mat : Matrix' n a) -> DecompLUP mat -> a
detWithLUP mat lup = detWithLUP mat lup =
(if numSwaps lup `mod` 2 == 0 then 1 else -1) (if numSwaps lup `mod` 2 == 0 then 1 else -1)
* product (diagonal lup.lu) * product (diagonal lup.lu)
||| Calculate the determinant of a matrix.
export export
det : FieldCmp a => Matrix' n a -> a det : FieldCmp a => Matrix' n a -> a
det {n} mat with (viewShape mat) det {n} mat with (viewShape mat)
@ -464,12 +529,16 @@ solveUpperTri' {n} mat b with (viewShape b)
mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs
||| Solve a linear equation, assuming the matrix is lower triangular.
||| Any entries other than those below the diagonal are ignored.
export export
solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
solveLowerTri mat b = if all (/=0) (diagonal mat) solveLowerTri mat b = if all (/=0) (diagonal mat)
then Just $ solveLowerTri' mat b then Just $ solveLowerTri' mat b
else Nothing else Nothing
||| Solve a linear equation, assuming the matrix is upper triangular.
||| Any entries other than those above the diagonal are ignored.
export export
solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
solveUpperTri mat b = if all (/=0) (diagonal mat) solveUpperTri mat b = if all (/=0) (diagonal mat)
@ -483,19 +552,27 @@ solveWithLUP' mat lup b =
let b' = permuteCoords (inverse lup.permute) b let b' = permuteCoords (inverse lup.permute) b
in solveUpperTri' lup.upper' $ solveLowerTri' lup.lower' b' in solveUpperTri' lup.upper' $ solveLowerTri' lup.lower' b'
||| Solve a linear equation, given a matrix and its LUP decomposition.
export export
solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat -> solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat ->
Vector n a -> Maybe (Vector n a) Vector n a -> Maybe (Vector n a)
solveWithLUP mat lup b = solveWithLUP mat lup b =
let b' = permuteCoords (inverse lup.permute) b let b' = permuteCoords (inverse lup.permute) b
in solveLowerTri lup.lower' b' >>= solveUpperTri lup.upper' in solveUpperTri lup.upper' $ solveLowerTri' lup.lower' b'
||| Solve a linear equation given a matrix.
export export
solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
solve mat = solveWithLUP mat (decompLUP mat) solve mat = solveWithLUP mat (decompLUP mat)
--------------------------------------------------------------------------------
-- Matrix inversion
--------------------------------------------------------------------------------
||| Try to invert a square matrix, returning `Nothing` if an inverse
||| does not exist.
export export
tryInverse : FieldCmp a => Matrix' n a -> Maybe (Matrix' n a) tryInverse : FieldCmp a => Matrix' n a -> Maybe (Matrix' n a)
tryInverse {n} mat with (viewShape mat) tryInverse {n} mat with (viewShape mat)

View file

@ -14,7 +14,7 @@ Vector : Nat -> Type -> Type
Vector n = Array [n] Vector n = Array [n]
||| The length (number of dimensions) of the vector ||| The length (number of dimensions) of the vector.
public export public export
dim : Vector n a -> Nat dim : Vector n a -> Nat
dim = head . shape dim = head . shape
@ -131,10 +131,12 @@ swizzle p v = rewrite sym (lengthCorrect p)
) )
||| Swap two entries in a vector.
export export
swapCoords : (i,j : Fin n) -> Vector n a -> Vector n a swapCoords : (i,j : Fin n) -> Vector n a -> Vector n a
swapCoords = swapInAxis 0 swapCoords = swapInAxis 0
||| Permute the entries in a vector.
export export
permuteCoords : Permutation n -> Vector n a -> Vector n a permuteCoords : Permutation n -> Vector n a -> Vector n a
permuteCoords = permuteInAxis 0 permuteCoords = permuteInAxis 0