These implementations are absolutely terrible, but they'll work for now. I'll refactor them when there are more library features available.
295 lines
9 KiB
Idris
295 lines
9 KiB
Idris
module Data.NumIdr.Matrix
|
|
|
|
import Data.List
|
|
import Data.Vect
|
|
import Data.Bool.Xor
|
|
import Data.NumIdr.Multiply
|
|
import public Data.NumIdr.Array
|
|
import Data.NumIdr.Vector
|
|
|
|
%default total
|
|
|
|
|
|
||| A matrix is a rank-2 array: `Matrix m n a` has `m` rows and `n` columns,
||| each element having type `a`.
public export
Matrix : Nat -> Nat -> Type -> Type
Matrix rows cols = Array [rows, cols]
|
|
|
|
||| A square matrix: `Matrix' n a` has `n` rows and `n` columns.
public export
Matrix' : Nat -> Type -> Type
Matrix' size = Matrix size size
|
|
|
|
|
|
--------------------------------------------------------------------------------
|
|
-- Matrix constructors
|
|
--------------------------------------------------------------------------------
|
|
|
|
|
|
||| Construct a matrix from a vector of rows, building the underlying array
||| with the given `Order` (e.g. `COrder`).
export
matrix' : {m, n : _} -> Order -> Vect m (Vect n a) -> Matrix m n a
matrix' layout rows = array' [m, n] layout rows
|
|
|
|
||| Construct a matrix from a vector of rows, using `COrder` for the
||| underlying array.
export
matrix : {m, n : _} -> Vect m (Vect n a) -> Matrix m n a
matrix rows = matrix' COrder rows
|
|
|
|
|
|
||| Construct a matrix with one value on the main diagonal (the entries whose
||| row and column index coincide) and another value everywhere else.
|||
||| @ diag  The value to repeat along the diagonal
||| @ other The value to repeat elsewhere
export
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
repeatDiag d o = fromFunctionNB [m, n] (\[row, col] => if row /= col then o else d)
|
|
|
|
||| Construct a matrix given its diagonal elements. Entry `[k,k]` takes the
||| `k`-th element of `diag`; every off-diagonal entry is `other`.
|||
||| @ diag  The elements of the matrix's diagonal
||| @ other The value to repeat elsewhere
export
fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
fromDiag ds o = fromFunction [m, n] (\[r, c] => maybe o (\k => index k ds) (sameFin r c))
  where
    -- If both indices denote the same position, return that position as an
    -- index into the diagonal (whose length is `minimum p q`).
    sameFin : {0 p, q : Nat} -> Fin p -> Fin q -> Maybe (Fin (minimum p q))
    sameFin FZ     FZ     = Just FZ
    sameFin FZ     (FS _) = Nothing
    sameFin (FS _) FZ     = Nothing
    sameFin (FS x) (FS y) = FS <$> sameFin x y
|
|
|
|
|
|
||| Construct the matrix that scales a vector by the given value: `factor`
||| on the diagonal and `0` everywhere else.
export
scaling : {n : _} -> Num a => a -> Matrix' n a
scaling factor = repeatDiag factor 0
|
|
|
|
||| Calculate the 2D rotation matrix for an angle given in radians. When it
||| multiplies a column vector, it rotates counterclockwise.
export
rotation2D : Double -> Matrix' 2 Double
rotation2D theta =
  let c = cos theta
      s = sin theta
  in matrix [[c, -s], [s, c]]
|
|
|
|
|
|
--------------------------------------------------------------------------------
|
|
-- Indexing
|
|
--------------------------------------------------------------------------------
|
|
|
|
|
|
||| Index the matrix at the given row and column.
export
index : Fin m -> Fin n -> Matrix m n a -> a
index row col = index [row, col]
|
|
|
|
||| Index the matrix at the given row and column, returning `Nothing` if
||| either coordinate is out of bounds.
export
indexNB : Nat -> Nat -> Matrix m n a -> Maybe a
indexNB row col = indexNB [row, col]
|
|
|
|
|
|
||| Return row `row` of the matrix as a vector of length `n`.
export
getRow : Fin m -> Matrix m n a -> Vector n a
getRow row arr = rewrite sym (rangeLenZ n) in arr !!.. [One row, All]
|
|
|
|
||| Return column `col` of the matrix as a vector of length `m`.
export
getColumn : Fin n -> Matrix m n a -> Vector m a
getColumn col arr = rewrite sym (rangeLenZ m) in arr !!.. [All, One col]
|
|
|
|
|
|
||| Return the main diagonal of a (possibly non-square) matrix as a vector of
||| length `minimum m n`.
export
diagonal' : Matrix m n a -> Vector (minimum m n) a
diagonal' mat with (viewShape mat)
  _ | Shape [m,n] = fromFunctionNB _ (\[k] => mat!#[k,k])
|
|
|
|
||| Return the main diagonal of a square matrix as a vector.
export
diagonal : Matrix' n a -> Vector n a
diagonal mat with (viewShape mat)
  _ | Shape [n,n] = fromFunctionNB [n] (\[k] => mat!#[k,k])
|
|
|
|
|
|
-- TODO: throw an actual proof in here to avoid the unsafety
||| Return the submatrix obtained by deleting row `i` and column `j`. This is
||| the submatrix itself, not its determinant.
|||
||| The filtered slice's shape is not known statically to be `[m,n]`, so the
||| result is coerced with `believe_me`.
export
minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a
minor i j mat = believe_me (mat !!.. [Filter (/= i), Filter (/= j)])
|
|
|
|
|
|
--------------------------------------------------------------------------------
|
|
-- Basic operations
|
|
--------------------------------------------------------------------------------
|
|
|
|
||| Concatenate two matrices vertically, along the row axis (axis 0). The
||| result has `m + m'` rows.
export
vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
vconcat top bottom = concat 0 top bottom
|
|
|
|
||| Concatenate two matrices horizontally, along the column axis (axis 1).
||| The result has `n + n'` columns.
export
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
hconcat left right = concat 1 left right
|
|
|
|
|
|
||| Build a matrix whose rows are the given vectors, stacked along axis 0.
export
vstack : {n : _} -> Vect m (Vector n a) -> Matrix m n a
vstack rows = stack 0 rows
|
|
|
|
||| Build a matrix whose columns are the given vectors, stacked along axis 1.
export
hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a
hstack cols = stack 1 cols
|
|
|
|
|
|
||| Transpose a matrix: entry `[r,c]` of the result is entry `[c,r]` of the
||| input.
export
transpose : Matrix m n a -> Matrix n m a
transpose mat with (viewShape mat)
  _ | Shape [m,n] = fromFunctionNB [n,m] (\[r,c] => mat!#[c,r])
|
|
|
|
||| Postfix-projection shorthand for `transpose`, so `mat.T` is the transpose
||| of `mat`.
export
(.T) : Matrix m n a -> Matrix n m a
(.T) mat = transpose mat
|
|
|
|
|
|
||| Calculate the outer product of two vectors as a matrix: entry `[i,j]` is
||| `u[i] * v[j]`.
export
outer : Num a => Vector m a -> Vector n a -> Matrix m n a
outer u v with (viewShape u, viewShape v)
  _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => u!!i * v!!j)
|
|
|
|
|
|
||| Sum the entries on the main diagonal. Non-square matrices are accepted;
||| the first `minimum m n` diagonal entries are summed.
export
trace : Num a => Matrix m n a -> a
trace mat = sum (diagonal' mat)
|
|
|
|
|
|
--------------------------------------------------------------------------------
|
|
-- Matrix multiplication
|
|
--------------------------------------------------------------------------------
|
|
|
|
|
|
-- Matrix-vector product: entry `row` of the result is the sum over `col` of
-- `mtx[row,col] * vec[col]`.
export
Num a => Mult (Matrix m n a) (Vector n a) (Vector m a) where
  mtx *. vec with (viewShape mtx)
    _ | Shape [m,n] =
      fromFunction [m] (\[row] => sum $ map (\col => mtx!![row,col] * vec!!col) range)
|
|
|
|
-- Matrix-matrix product: entry `[r,c]` of the result is the sum over `k` of
-- `lhs[r,k] * rhs[k,c]`.
export
Num a => Mult (Matrix m n a) (Matrix n p a) (Matrix m p a) where
  lhs *. rhs with (viewShape lhs, viewShape rhs)
    _ | (Shape [m,n], Shape [n,p]) =
      fromFunction [m,p] (\[r,c] => sum $ map (\k => lhs!![r,k] * rhs!![k,c]) range)
|
|
|
|
-- The multiplicative identity of square matrices: ones on the diagonal,
-- zeros elsewhere (equivalently, the scaling matrix for 1).
export
{n : _} -> Num a => MultMonoid (Matrix' n a) where
  identity = scaling 1
|
|
|
|
|
|
--------------------------------------------------------------------------------
|
|
-- Matrix decomposition
|
|
--------------------------------------------------------------------------------
|
|
|
|
|
|
-- LU Decomposition

||| An LU decomposition of the square matrix `mat` (see `decompLU`). The
||| matrix appears only as a type index; the record stores the two factors.
public export
record DecompLU {0 n,a : _} (mat : Matrix' n a) where
  constructor MkLU
  ||| The lower-triangular factor (unit diagonal when built by `decompLU`)
  lower : Matrix' n a
  ||| The upper-triangular factor
  upper : Matrix' n a
|
|
|
|
-- Render as a constructor application: `MkLU <lower> <upper>`.
export
Show a => Show (DecompLU {a} mat) where
  showPrec d (MkLU lo up) = showCon d "MkLU" (showArg lo ++ showArg up)
|
|
|
|
|
|
||| Apply `f` once for every index of `Fin n`, in increasing order, threading
||| an accumulator through:
||| `iterateN n f x = f (n-1) (... (f 1 (f 0 x)))`.
iterateN : (n : Nat) -> (Fin n -> a -> a) -> a -> a
iterateN Z     _ acc = acc
iterateN (S k) f acc = iterateN k (f . FS) (f FZ acc)
|
|
|
|
||| Compute an LU decomposition of a square matrix with the Doolittle
||| algorithm, without pivoting.
|||
||| Starting from `MkLU identity mat`, step `i` computes the multipliers
||| `upper[r,i] / upper[i,i]` for every row `r` below `i` and stores them in
||| column `i` of `lower`. It then left-multiplies `upper` by the elimination
||| matrix that clears column `i` below the diagonal. No rows are exchanged,
||| so a zero pivot `upper[i,i]` leads to a division by zero; `decompLUP`
||| pivots to avoid this.
export
decompLU : Neg a => Fractional a => (mat : Matrix' n a) -> DecompLU mat
decompLU {n} mat with (viewShape mat)
  _ | Shape [n,n] = iterateN n step (MkLU identity mat)
    where
      step : Fin n -> DecompLU mat -> DecompLU mat
      step i (MkLU lo up) =
        let mults = rewrite rangeLen (S piv) n in fromFunctionNB [minus n (S piv)]
                      (\[k] => up!#[S piv + k, piv] / up!#[piv, piv])
            elim = indexSetRange [StartBound (FS i), One i] (-mults) identity
        in MkLU (indexSetRange [StartBound (FS i), One i] mults lo) (elim *. up)
        where
          piv : Nat
          piv = cast i
|
|
|
|
|
|
-- LUP Decomposition

||| An LUP decomposition of the square matrix `mat` (see `decompLUP`). The
||| matrix appears only as a type index.
public export
record DecompLUP {0 n,a : _} (mat : Matrix' n a) where
  constructor MkLUP
  ||| The lower-triangular factor
  lower : Matrix' n a
  ||| The upper-triangular factor
  upper : Matrix' n a
  ||| The row-permutation matrix
  permute : Matrix' n a
  ||| The number of row swaps; `det` uses its parity as the sign
  swaps : Nat
|
|
|
|
-- Render as a constructor application:
-- `MkLUP <lower> <upper> <permute> <swaps>`.
export
Show a => Show (DecompLUP {a} mat) where
  showPrec d (MkLUP lo up perm cnt) =
    showCon d "MkLUP" (showArg lo ++ (showArg up ++ (showArg perm ++ showArg cnt)))
|
|
|
|
||| Turn an LU decomposition into an LUP decomposition by pairing the factors
||| with the identity permutation and a swap count of 0.
export
fromLU : Num a => DecompLU {n,a} mat -> DecompLUP mat
fromLU {n} (MkLU lo up) with (viewShape lo)
  _ | Shape [n,n] = MkLUP lo up identity 0
|
|
|
|
||| Compute an LUP decomposition of a square matrix with the Doolittle
||| algorithm and partial pivoting.
|||
||| At step `i`, the row at or below `i` with the largest absolute value in
||| column `i` is swapped into row `i`. The swap is applied to:
||| - `upper`;
||| - `permute`;
||| - the multipliers already stored in columns `0 .. i-1` of `lower`.
||| This keeps the three factors consistent with each other. Column `i` is
||| then eliminated exactly as in `decompLU`.
|||
||| `swaps` counts the row exchanges actually performed, so its parity is the
||| sign of the permutation. If the whole pivot column (from row `i` down) is
||| zero, the step is skipped.
export
decompLUP : Ord a => Abs a => Neg a => Fractional a =>
            (mat : Matrix' n a) -> DecompLUP mat
decompLUP {n} mat with (viewShape mat)
  decompLUP {n=0} mat | Shape [0,0] = MkLUP mat mat mat 0
  decompLUP {n=S n} mat | Shape [S n,S n] =
      iterateN (S n) doolittle (MkLUP identity mat identity 0)
    where
      -- Pick the pair whose second component has the largest absolute value.
      -- On ties the earlier pair wins. The first argument is returned only
      -- for an empty list.
      maxIndex : (s,a) -> List (s,a) -> (s,a)
      maxIndex x [] = x
      maxIndex _ [x] = x
      maxIndex x ((a,b)::(c,d)::xs) =
        if abs b < abs d then assert_total $ maxIndex x ((c,d)::xs)
                         else assert_total $ maxIndex x ((a,b)::xs)

      doolittle : Fin (S n) -> DecompLUP mat -> DecompLUP mat
      doolittle i (MkLUP l u p sw) =
        let -- Pivot row (absolute index) and pivot value in column i.
            (maxi, maxv) = mapFst ((+i') . head)
                             (maxIndex ([0],0) $ enumerateNB $
                                u !!.. [StartBound (weaken i), One i])
            -- Swap rows i' and maxi of the upper factor.
            u' = if maxi == i' then u
                 else fromFunctionNB _ (\[x,y] =>
                        if x==i' then u!#[maxi,y]
                        else if x==maxi then u!#[i',y]
                        else u!#[x,y])
            -- Swap the same rows of the permutation matrix.
            p' = if maxi == i' then p
                 else fromFunctionNB _ (\[x,y] =>
                        if x==i' then p!#[maxi,y]
                        else if x==maxi then p!#[i',y]
                        else p!#[x,y])
            -- Swap the multipliers already stored in columns < i' of the
            -- lower factor. Columns >= i' still hold the identity pattern
            -- and must stay as they are.
            l' = if maxi == i' then l
                 else fromFunctionNB _ (\[x,y] =>
                        if y >= i' then l!#[x,y]
                        else if x==i' then l!#[maxi,y]
                        else if x==maxi then l!#[i',y]
                        else l!#[x,y])
            v = rewrite rangeLen (S i') (S n) in fromFunctionNB [minus n i']
                  (\[x] => u'!#[S i' + x,i'] / u'!#[i',i'])
            low = indexSetRange [StartBound (FS i), One i] (-v) identity
        in if maxv == 0 then MkLUP l u p sw else
             MkLUP (indexSetRange [StartBound (FS i), One i] v l')
                   (low *. u') p'
                   -- Count a swap only when rows were actually exchanged.
                   (if maxi == i' then sw else S sw)
        where i' : Nat
              i' = cast i
|
|
|
|
|
|
--------------------------------------------------------------------------------
|
|
-- Matrix properties
|
|
--------------------------------------------------------------------------------
|
|
|
|
|
|
||| Calculate the determinant of a square matrix.
|||
||| Sizes 0, 1 and 2 are computed directly; the empty matrix has determinant
||| 1. Larger matrices use `decompLUP`: the result is the product of the
||| diagonals of the `lower` and `upper` factors, negated when the recorded
||| number of row swaps is odd.
export
det : Ord a => Abs a => Neg a => Fractional a =>
      Matrix' n a -> a
det {n} mat with (viewShape mat)
  det {n=0} mat | Shape [0,0] = 1
  det {n=1} mat | Shape [1,1] = mat!![0,0]
  det {n=2} mat | Shape [2,2] =
    let [p,q,r,s] = elements mat in p*s - q*r
  _ | Shape [n,n] =
    let MkLUP lo up _ swapCount = decompLUP mat
    in (if swapCount `mod` 2 == 0 then 1 else -1)
         * product (diagonal lo) * product (diagonal up)
|