module Data.NumIdr.Matrix import Data.List import Data.Vect import Data.Permutation import Data.NumIdr.Interfaces import public Data.NumIdr.Array import Data.NumIdr.PrimArray import Data.NumIdr.Vector %default total ||| A matrix is a rank-2 array. public export Matrix : Nat -> Nat -> Type -> Type Matrix m n = Array [m,n] %name Matrix mat ||| A synonym for a square matrix with dimensions of length `n`. public export Matrix' : Nat -> Type -> Type Matrix' n = Array [n,n] -------------------------------------------------------------------------------- -- Matrix constructors -------------------------------------------------------------------------------- ||| Construct a matrix with the given order and elements. export matrix : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} -> Vect m (Vect n a) -> Matrix m n a matrix x = array {rep, s=[m,n]} x ||| Construct a matrix with a specific value along the diagonal. ||| ||| @ diag The value to repeat along the diagonal ||| @ other The value to repeat elsewhere export repeatDiag : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} -> (diag, other : a) -> Matrix m n a repeatDiag d o = fromFunctionNB {rep} [m,n] (\[i,j] => if i == j then d else o) ||| Construct a matrix given its diagonal elements. ||| ||| @ diag The elements of the matrix's diagonal ||| @ other The value to repeat elsewhere export fromDiag : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a fromDiag ds o = fromFunction {rep} [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j) where eq : {0 m,n : Nat} -> Fin m -> Fin n -> Maybe (Fin (minimum m n)) eq FZ FZ = Just FZ eq (FS x) (FS y) = map FS (eq x y) eq FZ (FS _) = Nothing eq (FS _) FZ = Nothing ||| Construct a permutation matrix based on the given permutation. export permuteM : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Num a => Permutation n -> Matrix' n a permuteM p = permuteInAxis 0 p (repeatDiag {rep} 1 0) ||| Construct the matrix that scales a vector by the given value. export scale : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Num a => a -> Matrix' n a scale x = repeatDiag {rep} x 0 ||| Construct a 2D rotation matrix that rotates by the given angle (in radians). export rotate2D : {default B rep : Rep} -> RepConstraint rep Double => Double -> Matrix' 2 Double rotate2D a = matrix {rep} [[cos a, - sin a], [sin a, cos a]] ||| Construct a 3D rotation matrix around the x-axis. export rotate3DX : {default B rep : Rep} -> RepConstraint rep Double => Double -> Matrix' 3 Double rotate3DX a = matrix {rep} [[1,0,0], [0, cos a, - sin a], [0, sin a, cos a]] ||| Construct a 3D rotation matrix around the y-axis. export rotate3DY : {default B rep : Rep} -> RepConstraint rep Double => Double -> Matrix' 3 Double rotate3DY a = matrix {rep} [[cos a, 0, sin a], [0,1,0], [- sin a, 0, cos a]] ||| Construct a 3D rotation matrix around the z-axis. export rotate3DZ : {default B rep : Rep} -> RepConstraint rep Double => Double -> Matrix' 3 Double rotate3DZ a = matrix {rep} [[cos a, - sin a, 0], [sin a, cos a, 0], [0,0,1]] export reflect : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Neg a => Fin n -> Matrix' n a reflect i = indexSet [i, i] (-1) (repeatDiag {rep} 1 0) export reflectX : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Neg a => Matrix' (1 + n) a reflectX = reflect {rep} 0 export reflectY : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Neg a => Matrix' (2 + n) a reflectY = reflect {rep} 1 export reflectZ : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Neg a => Matrix' (3 + n) a reflectZ = reflect {rep} 2 -------------------------------------------------------------------------------- -- Indexing -------------------------------------------------------------------------------- ||| Index the matrix at the given coordinates. export index : Fin m -> Fin n -> Matrix m n a -> a index m n = index [m,n] ||| Index the matrix at the given coordinates, returning `Nothing` if the ||| coordinates are out of bounds. export indexNB : Nat -> Nat -> Matrix m n a -> Maybe a indexNB m n = indexNB [m,n] ||| Return a row of the matrix as a vector. export getRow : Fin m -> Matrix m n a -> Vector n a getRow r mat = rewrite sym (rangeLenZ n) in mat!!..[One r, All] ||| Return a column of the matrix as a vector. export getColumn : Fin n -> Matrix m n a -> Vector m a getColumn c mat = rewrite sym (rangeLenZ m) in mat!!..[All, One c] ||| Return the diagonal elements of the matrix as a vector. export diagonal' : Matrix m n a -> Vector (minimum m n) a diagonal' {m,n} mat with (viewShape mat) _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC mat} _ (\[i] => mat!#[i,i]) ||| Return the diagonal elements of the matrix as a vector. export diagonal : Matrix' n a -> Vector n a diagonal {n} mat with (viewShape mat) _ | Shape [n,n] = fromFunctionNB {rep=_} @{getRepC mat} [n] (\[i] => mat!#[i,i]) ||| Return a minor of the matrix, i.e. the matrix formed by removing a ||| single row and column. export -- TODO: throw an actual proof in here to avoid the unsafety minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a minor i j mat = replace {p = flip Array a} (believe_me $ Refl {x = ()}) $ mat!!..[Filter (/=i), Filter (/=j)] filterInd : Num a => (Nat -> Nat -> Bool) -> Matrix m n a -> Matrix m n a filterInd {m,n} p mat with (viewShape mat) _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC mat} [m,n] (\[i,j] => if p i j then mat!#[i,j] else 0) export upperTriangle : Num a => Matrix m n a -> Matrix m n a upperTriangle = filterInd (<=) export lowerTriangle : Num a => Matrix m n a -> Matrix m n a lowerTriangle = filterInd (>=) export upperTriangleStrict : Num a => Matrix m n a -> Matrix m n a upperTriangleStrict = filterInd (<) export lowerTriangleStrict : Num a => Matrix m n a -> Matrix m n a lowerTriangleStrict = filterInd (>) -------------------------------------------------------------------------------- -- Basic operations -------------------------------------------------------------------------------- ||| Concatenate two matrices vertically. export vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a vconcat = concat 0 ||| Concatenate two matrices horizontally. export hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a hconcat = concat 1 ||| Stack row vectors to form a matrix. export vstack : {n : _} -> Vect m (Vector n a) -> Matrix m n a vstack = stack 0 ||| Stack column vectors to form a matrix. export hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a hstack = stack 1 ||| Swap two rows of a matrix. export swapRows : (i,j : Fin m) -> Matrix m n a -> Matrix m n a swapRows = swapInAxis 0 ||| Swap two columns of a matrix. export swapColumns : (i,j : Fin n) -> Matrix m n a -> Matrix m n a swapColumns = swapInAxis 1 ||| Permute the rows of a matrix. export permuteRows : Permutation m -> Matrix m n a -> Matrix m n a permuteRows = permuteInAxis 0 ||| Permute the columns of a matrix. export permuteColumns : Permutation n -> Matrix m n a -> Matrix m n a permuteColumns = permuteInAxis 1 ||| Calculate the outer product of two vectors as a matrix. export outer : Num a => Vector m a -> Vector n a -> Matrix m n a outer {m,n} a b with (viewShape a, viewShape b) _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => a!!i * b!!j) ||| Calculate the trace of a matrix, i.e. the sum of its diagonal elements. export trace : Num a => Matrix m n a -> a trace = sum . diagonal' ||| Construct a matrix that reflects a vector along a hyperplane of the ||| given normal vector. The input does not have to be a unit vector. export reflectNormal : (Neg a, Fractional a) => Vector n a -> Matrix' n a reflectNormal {n} v with (viewShape v) _ | Shape [n] = repeatDiag 1 0 - (2 / normSq v) *. outer v v -------------------------------------------------------------------------------- -- Matrix multiplication -------------------------------------------------------------------------------- export Num a => Mult (Matrix m n a) (Vector n a) (Vector m a) where (*.) {m,n} mat v with (viewShape mat) _ | Shape [m,n] = fromFunction {rep=_} @{mergeRepConstraint (getRepC mat) (getRepC v)} [m] (\[i] => sum $ map (\j => mat!![i,j] * v!!j) range) export Num a => Mult (Matrix m n a) (Matrix n p a) (Matrix m p a) where (*.) {m,n,p} m1 m2 with (viewShape m1, viewShape m2) _ | (Shape [m,n], Shape [n,p]) = fromFunction {rep=_} @{mergeRepConstraint (getRepC m1) (getRepC m2)} [m,p] (\[i,j] => sum $ map (\k => m1!![i,k] * m2!![k,j]) range) export {n : _} -> Num a => MultMonoid (Matrix' n a) where identity = repeatDiag 1 0 -------------------------------------------------------------------------------- -- Matrix decomposition -------------------------------------------------------------------------------- ||| The LU decomposition of a matrix. ||| ||| LU decomposition factors a matrix A into two matrices: a lower triangular ||| matrix L, and an upper triangular matrix U, such that A = LU. export record DecompLU {0 m,n,a : _} (mat : Matrix m n a) where constructor MkLU -- The lower and upper triangular matrix elements are stored -- together for efficiency reasons lu : Matrix m n a namespace DecompLU ||| The lower triangular matrix L of the LU decomposition. export lower : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a lower {m,n} (MkLU lu) with (viewShape lu) _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] => case compare i j of LT => 0 EQ => 1 GT => lu!#[i,j]) ||| The lower triangular matrix L of the LU decomposition. export %inline (.lower) : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a (.lower) = lower ||| The lower triangular matrix L of the LU decomposition. ||| ||| This accessor is intended to be used for square matrix decompositions. export lower' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n) in lower lu ||| The lower triangular matrix L of the LU decomposition. ||| ||| This accessor is intended to be used for square matrix decompositions. export %inline (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a (.lower') = lower' ||| The upper triangular matrix U of the LU decomposition. export upper : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a upper {m,n} (MkLU lu) with (viewShape lu) _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] => if i <= j then lu!#[i,j] else 0) ||| The upper triangular matrix U of the LU decomposition. export %inline (.upper) : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a (.upper) = upper ||| The upper triangular matrix U of the LU decomposition. ||| ||| This accessor is intended to be used for square matrix decompositions. export upper' : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n) in upper lu ||| The upper triangular matrix U of the LU decomposition. ||| ||| This accessor is intended to be used for square matrix decompositions. export %inline (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLU mat -> Matrix' n a (.upper') = upper' minWeakenLeft : {m,n : _} -> Fin (minimum m n) -> Fin m minWeakenLeft x = weakenLTE x $ minLTE m n where minLTE : (m,n : _) -> minimum m n `LTE` m minLTE Z n = LTEZero minLTE (S m) Z = LTEZero minLTE (S m) (S n) = LTESucc (minLTE m n) minWeakenRight : {m,n : _} -> Fin (minimum m n) -> Fin n minWeakenRight x = weakenLTE x $ minLTE m n where minLTE : (m,n : _) -> minimum m n `LTE` n minLTE Z n = LTEZero minLTE (S m) Z = LTEZero minLTE (S m) (S n) = LTESucc (minLTE m n) iterateN : (n : Nat) -> (Fin n -> a -> a) -> a -> a iterateN 0 f x = x iterateN 1 f x = f FZ x iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x ||| Perform a single step of Gaussian elimination on the `i`-th row and column. gaussStep : {m,n : _} -> Field a => Fin (minimum m n) -> Matrix m n a -> Matrix m n a gaussStep i lu = if all (==0) $ getColumn (minWeakenRight i) lu then lu else let ir = minWeakenLeft i ic = minWeakenRight i diag = lu!![ir,ic] coeffs = map (/diag) $ lu!!..[StartBound (FS ir), One ic] lu' = indexSetRange [StartBound (FS ir), One ic] coeffs lu pivot = lu!!..[One ir, StartBound (FS ic)] offsets = outer coeffs pivot in indexUpdateRange [StartBound (FS ir), StartBound (FS ic)] (flip (-) offsets) lu' ||| Calculate the LU decomposition of a matrix, returning `Nothing` if one ||| does not exist. export decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat) decompLU {m,n} mat with (viewShape mat) _ | Shape [m,n] = map MkLU $ iterateN (minimum m n) (\i => (>>= gaussStepMaybe i)) (Just mat) where gaussStepMaybe : Fin (minimum m n) -> Matrix m n a -> Maybe (Matrix m n a) gaussStepMaybe i mat = if mat!#[cast i,cast i] == 0 then Nothing else Just $ gaussStep i mat ||| The LUP decomposition of a matrix. ||| ||| LUP decomposition is similar to LU decomposition, but the matrix may have ||| its rows permuted before being factored. More formally, an LUP decomposition ||| of a matrix A consists of a lower triangular matrix L, an upper triangular ||| matrix U, and a permutation matrix P, such that PA = LU. export record DecompLUP {0 m,n,a : _} (mat : Matrix m n a) where constructor MkLUP lu : Matrix m n a p : Permutation m sw : Nat namespace DecompLUP ||| The lower triangular matrix L of the LUP decomposition. export lower : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a lower {m,n} (MkLUP lu _ _) with (viewShape lu) _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] => case compare i j of LT => 0 EQ => 1 GT => lu!#[i,j]) ||| The lower triangular matrix L of the LUP decomposition. export %inline (.lower) : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a (.lower) = lower ||| The lower triangular matrix L of the LUP decomposition. ||| ||| This accessor is intended to be used for square matrix decompositions. export lower' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a lower' lu = rewrite cong (\i => Matrix n i a) $ sym (minimumIdempotent n) in lower lu ||| The lower triangular matrix L of the LUP decomposition. ||| ||| This accessor is intended to be used for square matrix decompositions. export %inline (.lower') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a (.lower') = lower' ||| The upper triangular matrix U of the LUP decomposition. export upper : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a upper {m,n} (MkLUP lu _ _) with (viewShape lu) _ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] => if i <= j then lu!#[i,j] else 0) ||| The upper triangular matrix U of the LUP decomposition. export %inline (.upper) : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a (.upper) = upper ||| The upper triangular matrix U of the LUP decomposition. ||| ||| This accessor is intended to be used for square matrix decompositions. export upper' : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a upper' lu = rewrite cong (\i => Matrix i n a) $ sym (minimumIdempotent n) in upper lu ||| The upper triangular matrix U of the LUP decomposition. ||| ||| This accessor is intended to be used for square matrix decompositions. export %inline (.upper') : Num a => {0 mat : Matrix' n a} -> DecompLUP mat -> Matrix' n a (.upper') = upper' ||| The row permutation of the LUP decomposition. export permute : DecompLUP {m} mat -> Permutation m permute (MkLUP lu p sw) = p ||| The row permutation of the LUP decomposition. export %inline (.permute) : DecompLUP {m} mat -> Permutation m (.permute) = permute ||| The number of swaps in the permutation of the LUP decomposition. ||| ||| This is stored along with the permutation in order to increase the ||| efficiency of certain algorithms. export numSwaps : DecompLUP mat -> Nat numSwaps (MkLUP lu p sw) = sw ||| Convert an LU decomposition into an LUP decomposition. export fromLU : DecompLU mat -> DecompLUP mat fromLU (MkLU lu) = MkLUP lu identity 0 ||| Calculate the LUP decomposition of a matrix. export decompLUP : FieldCmp a => (mat : Matrix m n a) -> DecompLUP mat decompLUP {m,n} mat with (viewShape mat) decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0 decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0 decompLUP {m=S m,n=S n} mat | Shape [S m,S n] = iterateN (S $ minimum m n) gaussStepSwap (MkLUP mat identity 0) where maxIndex : (s,a) -> List (s,a) -> (s,a) maxIndex x [] = x maxIndex _ [x] = x maxIndex x ((a,b)::(c,d)::xs) = if abslt b d then maxIndex x ((c,d)::xs) else assert_total $ maxIndex x ((a,b)::xs) gaussStepSwap : Fin (S $ minimum m n) -> DecompLUP mat -> DecompLUP mat gaussStepSwap i (MkLUP lu p sw) = let ir = minWeakenLeft {n=S n} i ic = minWeakenRight {m=S m} i maxi = head $ fst (maxIndex ([0],0) $ drop (cast i) $ enumerate $ indexSetRange [EndBound (weaken ir)] 0 $ getColumn ic lu) in if maxi == ir then MkLUP (gaussStep i lu) p sw else MkLUP (gaussStep i $ swapRows ir maxi lu) (appendSwap maxi ir p) (S sw) -------------------------------------------------------------------------------- -- Determinant -------------------------------------------------------------------------------- ||| Calculate the determinant of a matrix given its LUP decomposition. export detWithLUP : Num a => (mat : Matrix' n a) -> DecompLUP mat -> a detWithLUP mat lup = (if numSwaps lup `mod` 2 == 0 then 1 else -1) * product (diagonal lup.lu) ||| Calculate the determinant of a matrix. export det : FieldCmp a => Matrix' n a -> a det {n} mat with (viewShape mat) det {n=0} mat | Shape [0,0] = 1 det {n=1} mat | Shape [1,1] = mat!![0,0] det {n=2} mat | Shape [2,2] = let [a,b,c,d] = elements mat in a*d - b*c _ | Shape [n,n] = detWithLUP mat (decompLUP mat) -------------------------------------------------------------------------------- -- Solving matrix equations -------------------------------------------------------------------------------- solveLowerTri' : Field a => Matrix' n a -> Vector n a -> Vector n a solveLowerTri' {n} mat b with (viewShape b) _ | Shape [n] = vector $ reverse $ construct $ reverse $ toVect b where construct : {i : _} -> Vect i a -> Vect i a construct [] = [] construct {i=S i} (b :: bs) = let xs = construct bs i' = assert_total $ case natToFin i n of Just i' => i' in (b - sum (zipWith (*) xs (reverse $ toVect $ replace {p = flip Array a} (believe_me $ Refl {x=()}) $ mat !!.. [One i', EndBound (weaken i')]))) / mat!#[i,i] :: xs solveUpperTri' : Field a => Matrix' n a -> Vector n a -> Vector n a solveUpperTri' {n} mat b with (viewShape b) _ | Shape [n] = vector $ construct Z $ toVect b where construct : Nat -> Vect i a -> Vect i a construct _ [] = [] construct i (b :: bs) = let xs = construct (S i) bs i' = assert_total $ case natToFin i n of Just i' => i' in (b - sum (zipWith (*) xs (toVect $ replace {p = flip Array a} (believe_me $ Refl {x=()}) $ mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs ||| Solve a linear equation, assuming the matrix is lower triangular. ||| Any entries other than those below the diagonal are ignored. export solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solveLowerTri mat b = if all (/=0) (diagonal mat) then Just $ solveLowerTri' mat b else Nothing ||| Solve a linear equation, assuming the matrix is upper triangular. ||| Any entries other than those above the diagonal are ignored. export solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solveUpperTri mat b = if all (/=0) (diagonal mat) then Just $ solveUpperTri' mat b else Nothing solveWithLUP' : Field a => (mat : Matrix' n a) -> DecompLUP mat -> Vector n a -> Vector n a solveWithLUP' mat lup b = let b' = permuteCoords (inverse lup.permute) b in solveUpperTri' lup.upper' $ solveLowerTri' lup.lower' b' ||| Solve a linear equation, given a matrix and its LUP decomposition. export solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat -> Vector n a -> Maybe (Vector n a) solveWithLUP mat lup b = let b' = permuteCoords (inverse lup.permute) b in solveUpperTri lup.upper' $ solveLowerTri' lup.lower' b' ||| Solve a linear equation given a matrix. export solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solve mat = solveWithLUP mat (decompLUP mat) -------------------------------------------------------------------------------- -- Matrix inversion -------------------------------------------------------------------------------- export invertible : FieldCmp a => Matrix' n a -> Bool invertible {n} mat with (viewShape mat) _ | Shape [n,n] = let lup = decompLUP mat in all (/=0) (diagonal lup.lu) ||| Try to invert a square matrix, returning `Nothing` if an inverse ||| does not exist. export tryInverse : FieldCmp a => Matrix' n a -> Maybe (Matrix' n a) tryInverse {n} mat with (viewShape mat) _ | Shape [n,n] = let lup = decompLUP mat in map hstack $ traverse (solveWithLUP mat lup) $ map basis range export {n : _} -> FieldCmp a => MultGroup (Matrix' n a) where inverse mat = let lup = decompLUP mat in hstack $ map (solveWithLUP' mat lup . basis) range