Refactor LU and LUP decomposition
This commit is contained in:
parent
b74734fbc1
commit
c9dada5206
|
@ -2,7 +2,7 @@ module Data.NumIdr.Matrix
|
||||||
|
|
||||||
import Data.List
|
import Data.List
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.Bool.Xor
|
import Data.Permutation
|
||||||
import Data.NumIdr.Multiply
|
import Data.NumIdr.Multiply
|
||||||
import public Data.NumIdr.Array
|
import public Data.NumIdr.Array
|
||||||
import Data.NumIdr.Vector
|
import Data.NumIdr.Vector
|
||||||
|
@ -61,6 +61,11 @@ fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
||||||
eq (FS _) FZ = Nothing
|
eq (FS _) FZ = Nothing
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
permutationMatrix : {n : _} -> Num a => Permutation n -> Matrix' n a
|
||||||
|
permutationMatrix p = permuteInAxis 0 p (repeatDiag 1 0)
|
||||||
|
|
||||||
|
|
||||||
||| Construct the matrix that scales a vector by the given value.
|
||| Construct the matrix that scales a vector by the given value.
|
||||||
export
|
export
|
||||||
scaling : {n : _} -> Num a => a -> Matrix' n a
|
scaling : {n : _} -> Num a => a -> Matrix' n a
|
||||||
|
@ -117,6 +122,27 @@ minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a
|
||||||
minor i j mat = believe_me $ mat!!..[Filter (/=i), Filter (/=j)]
|
minor i j mat = believe_me $ mat!!..[Filter (/=i), Filter (/=j)]
|
||||||
|
|
||||||
|
|
||||||
|
filterInd : Num a => (Nat -> Nat -> Bool) -> Matrix m n a -> Matrix m n a
|
||||||
|
filterInd p mat with (viewShape mat)
|
||||||
|
_ | Shape [m,n] = fromFunctionNB [m,n] (\[i,j] => if p i j then mat!#[i,j] else 0)
|
||||||
|
|
||||||
|
export
|
||||||
|
upperTriangle : Num a => Matrix m n a -> Matrix m n a
|
||||||
|
upperTriangle = filterInd (<=)
|
||||||
|
|
||||||
|
export
|
||||||
|
lowerTriangle : Num a => Matrix m n a -> Matrix m n a
|
||||||
|
lowerTriangle = filterInd (>=)
|
||||||
|
|
||||||
|
export
|
||||||
|
upperTriangleStrict : Num a => Matrix m n a -> Matrix m n a
|
||||||
|
upperTriangleStrict = filterInd (<)
|
||||||
|
|
||||||
|
export
|
||||||
|
lowerTriangleStrict : Num a => Matrix m n a -> Matrix m n a
|
||||||
|
lowerTriangleStrict = filterInd (>)
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Basic operations
|
-- Basic operations
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -141,6 +167,23 @@ hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a
|
||||||
hstack = stack 1
|
hstack = stack 1
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
swapRows : (i,j : Fin m) -> Matrix m n a -> Matrix m n a
|
||||||
|
swapRows = swapInAxis 0
|
||||||
|
|
||||||
|
export
|
||||||
|
swapColumns : (i,j : Fin n) -> Matrix m n a -> Matrix m n a
|
||||||
|
swapColumns = swapInAxis 1
|
||||||
|
|
||||||
|
export
|
||||||
|
permuteRows : Permutation m -> Matrix m n a -> Matrix m n a
|
||||||
|
permuteRows = permuteInAxis 0
|
||||||
|
|
||||||
|
export
|
||||||
|
permuteColumns : Permutation n -> Matrix m n a -> Matrix m n a
|
||||||
|
permuteColumns = permuteInAxis 1
|
||||||
|
|
||||||
|
|
||||||
||| Calculate the outer product of two vectors as a matrix.
|
||| Calculate the outer product of two vectors as a matrix.
|
||||||
export
|
export
|
||||||
outer : Num a => Vector m a -> Vector n a -> Matrix m n a
|
outer : Num a => Vector m a -> Vector n a -> Matrix m n a
|
||||||
|
@ -181,14 +224,29 @@ export
|
||||||
|
|
||||||
|
|
||||||
-- LU Decomposition
|
-- LU Decomposition
|
||||||
public export
|
export
|
||||||
record DecompLU {0 n,a : _} (mat : Matrix' n a) where
|
record DecompLU {0 n,a : _} (mat : Matrix' n a) where
|
||||||
constructor MkLU
|
constructor MkLU
|
||||||
lower, upper : Matrix' n a
|
lu : Matrix' n a
|
||||||
|
|
||||||
export
|
|
||||||
Show a => Show (DecompLU {a} mat) where
|
namespace DecompLU
|
||||||
showPrec p (MkLU l u) = showCon p "MkLU" $ showArg l ++ showArg u
|
export
|
||||||
|
lower : Num a => DecompLU {n,a} mat -> Matrix' n a
|
||||||
|
lower (MkLU lu) with (viewShape lu)
|
||||||
|
_ | Shape [n,n] = lowerTriangleStrict lu + identity
|
||||||
|
|
||||||
|
export %inline
|
||||||
|
(.lower) : Num a => DecompLU {n,a} mat -> Matrix' n a
|
||||||
|
(.lower) = lower
|
||||||
|
|
||||||
|
export
|
||||||
|
upper : Num a => DecompLU {n,a} mat -> Matrix' n a
|
||||||
|
upper (MkLU lu) = upperTriangle lu
|
||||||
|
|
||||||
|
export %inline
|
||||||
|
(.upper) : Num a => DecompLU {n,a} mat -> Matrix' n a
|
||||||
|
(.upper) = upper
|
||||||
|
|
||||||
|
|
||||||
iterateN : (n : Nat) -> (Fin n -> a -> a) -> a -> a
|
iterateN : (n : Nat) -> (Fin n -> a -> a) -> a -> a
|
||||||
|
@ -196,45 +254,75 @@ iterateN 0 f x = x
|
||||||
iterateN 1 f x = f FZ x
|
iterateN 1 f x = f FZ x
|
||||||
iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x
|
iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x
|
||||||
|
|
||||||
|
|
||||||
|
gaussStep : (Eq a, Neg a, Fractional a) => Fin n -> Matrix' n a -> Matrix' n a
|
||||||
|
gaussStep {n} i lu with (viewShape lu)
|
||||||
|
_ | Shape [n,n] =
|
||||||
|
if all (==0) $ getColumn i lu then lu else
|
||||||
|
let diag = lu!![i,i]
|
||||||
|
coeffs = map (/diag) $ lu!!..[StartBound (FS i), One i]
|
||||||
|
lu' = indexSetRange [StartBound (FS i), One i]
|
||||||
|
coeffs lu
|
||||||
|
pivot = lu!!..[One i, StartBound (FS i)]
|
||||||
|
offsets = negate $ outer coeffs pivot
|
||||||
|
in indexUpdateRange [StartBound (FS i), StartBound (FS i)] (+offsets) lu'
|
||||||
|
|
||||||
export
|
export
|
||||||
decompLU : Neg a => Fractional a => (mat : Matrix' n a) -> DecompLU mat
|
decompLU : (Eq a, Neg a, Fractional a) => (mat : Matrix' n a) -> DecompLU mat
|
||||||
decompLU {n} mat with (viewShape mat)
|
decompLU {n} mat with (viewShape mat)
|
||||||
_ | Shape [n,n] = iterateN n doolittle (MkLU identity mat)
|
_ | Shape [n,n] = MkLU $ iterateN n gaussStep mat
|
||||||
where
|
|
||||||
doolittle : Fin n -> DecompLU mat -> DecompLU mat
|
|
||||||
doolittle i (MkLU l u) =
|
|
||||||
let v = rewrite rangeLen (S i') n in fromFunctionNB [minus n (S i')]
|
|
||||||
(\[x] => u!#[S i' + x,i'] / u!#[i',i'])
|
|
||||||
low = indexSetRange [StartBound (FS i), One i] (-v) identity
|
|
||||||
in MkLU (indexSetRange [StartBound (FS i), One i] v l) (low *. u)
|
|
||||||
where i' : Nat
|
|
||||||
i' = cast i
|
|
||||||
|
|
||||||
|
|
||||||
-- LUP Decomposition
|
-- LUP Decomposition
|
||||||
public export
|
public export
|
||||||
record DecompLUP {0 n,a : _} (mat : Matrix' n a) where
|
record DecompLUP {0 n,a : _} (mat : Matrix' n a) where
|
||||||
constructor MkLUP
|
constructor MkLUP
|
||||||
lower, upper, permute : Matrix' n a
|
lu : Matrix' n a
|
||||||
swaps : Nat
|
p : Permutation n
|
||||||
|
sw : Nat
|
||||||
|
|
||||||
|
namespace DecompLUP
|
||||||
|
export
|
||||||
|
lower : Num a => DecompLUP {n,a} mat -> Matrix' n a
|
||||||
|
lower (MkLUP lu p sw) with (viewShape lu)
|
||||||
|
_ | Shape [n,n] = lowerTriangleStrict lu + identity
|
||||||
|
|
||||||
|
export %inline
|
||||||
|
(.lower) : Num a => DecompLUP {n,a} mat -> Matrix' n a
|
||||||
|
(.lower) = lower
|
||||||
|
|
||||||
|
export
|
||||||
|
upper : Num a => DecompLUP {n,a} mat -> Matrix' n a
|
||||||
|
upper (MkLUP lu p sw) = upperTriangle lu
|
||||||
|
|
||||||
|
export %inline
|
||||||
|
(.upper) : Num a => DecompLUP {n,a} mat -> Matrix' n a
|
||||||
|
(.upper) = upper
|
||||||
|
|
||||||
|
export
|
||||||
|
permute : DecompLUP {n} mat -> Permutation n
|
||||||
|
permute (MkLUP lu p sw) = p
|
||||||
|
|
||||||
|
export %inline
|
||||||
|
(.permute) : DecompLUP {n} mat -> Permutation n
|
||||||
|
(.permute) = permute
|
||||||
|
|
||||||
|
export
|
||||||
|
numSwaps : DecompLUP {n} mat -> Nat
|
||||||
|
numSwaps (MkLUP lu p sw) = sw
|
||||||
|
|
||||||
export
|
export
|
||||||
Show a => Show (DecompLUP {a} mat) where
|
fromLU : DecompLU mat -> DecompLUP mat
|
||||||
showPrec p (MkLUP l u pr b) = showCon p "MkLUP" $
|
fromLU (MkLU lu) = MkLUP lu identity 0
|
||||||
showArg l ++ showArg u ++ showArg pr ++ showArg b
|
|
||||||
|
|
||||||
export
|
export
|
||||||
fromLU : Num a => DecompLU {n,a} mat -> DecompLUP mat
|
decompLUP : (Ord a, Abs a, Neg a, Fractional a) =>
|
||||||
fromLU {n} (MkLU l u) with (viewShape l)
|
|
||||||
_ | Shape [n,n] = MkLUP l u identity 0
|
|
||||||
|
|
||||||
export
|
|
||||||
decompLUP : Ord a => Abs a => Neg a => Fractional a =>
|
|
||||||
(mat : Matrix' n a) -> DecompLUP mat
|
(mat : Matrix' n a) -> DecompLUP mat
|
||||||
decompLUP {n} mat with (viewShape mat)
|
decompLUP {n} mat with (viewShape mat)
|
||||||
decompLUP {n=0} mat | Shape [0,0] = MkLUP mat mat mat 0
|
decompLUP {n=0} mat | Shape [0,0] = MkLUP mat identity 0
|
||||||
decompLUP {n=S n} mat | Shape [S n,S n] =
|
decompLUP {n=S n} mat | Shape [S n,S n] =
|
||||||
iterateN (S n) doolittle (MkLUP identity mat identity 0)
|
iterateN (S n) gaussStepSwap (MkLUP mat identity 0)
|
||||||
where
|
where
|
||||||
maxIndex : (s,a) -> List (s,a) -> (s,a)
|
maxIndex : (s,a) -> List (s,a) -> (s,a)
|
||||||
maxIndex x [] = x
|
maxIndex x [] = x
|
||||||
|
@ -243,29 +331,13 @@ decompLUP {n} mat with (viewShape mat)
|
||||||
if abs b < abs d then assert_total $ maxIndex x ((c,d)::xs)
|
if abs b < abs d then assert_total $ maxIndex x ((c,d)::xs)
|
||||||
else assert_total $ maxIndex x ((a,b)::xs)
|
else assert_total $ maxIndex x ((a,b)::xs)
|
||||||
|
|
||||||
doolittle : Fin (S n) -> DecompLUP mat -> DecompLUP mat
|
gaussStepSwap : Fin (S n) -> DecompLUP mat -> DecompLUP mat
|
||||||
doolittle i (MkLUP l u p sw) =
|
gaussStepSwap i (MkLUP lu p sw) =
|
||||||
let (maxi, maxv) = mapFst ((+i') . head)
|
let (maxi, maxv) = mapFst head
|
||||||
(maxIndex ([0],0) $ enumerateNB $
|
(maxIndex ([0],0) $ enumerate $
|
||||||
u !!.. [StartBound (weaken i), One i])
|
indexSetRange [EndBound (weaken i)] 0 $ getColumn i lu)
|
||||||
u' = if maxi == i' then u
|
in if maxi == i then MkLUP (gaussStep i lu) p sw
|
||||||
else fromFunctionNB _ (\[x,y] =>
|
else MkLUP (gaussStep i $ swapRows maxi i lu) (appendSwap maxi i p) (S sw)
|
||||||
if x==i' then u!#[maxi,y]
|
|
||||||
else if x==maxi then u!#[i',y]
|
|
||||||
else u!#[x,y])
|
|
||||||
p' = if maxi == i' then p
|
|
||||||
else fromFunctionNB _ (\[x,y] =>
|
|
||||||
if x==i' then p!#[maxi,y]
|
|
||||||
else if x==maxi then p!#[i',y]
|
|
||||||
else p!#[x,y])
|
|
||||||
v = rewrite rangeLen (S i') (S n) in fromFunctionNB [minus n i']
|
|
||||||
(\[x] => u'!#[S i' + x,i'] / u'!#[i',i'])
|
|
||||||
low = indexSetRange [StartBound (FS i), One i] (-v) identity
|
|
||||||
in if maxv == 0 then MkLUP l u p sw else
|
|
||||||
MkLUP (indexSetRange [StartBound (FS i), One i] v l)
|
|
||||||
(low *. u') p' (if maxi==i' then S sw else sw)
|
|
||||||
where i' : Nat
|
|
||||||
i' = cast i
|
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -274,12 +346,16 @@ decompLUP {n} mat with (viewShape mat)
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
det : Ord a => Abs a => Neg a => Fractional a =>
|
detWithLUP : (Ord a, Abs a, Neg a, Fractional a) =>
|
||||||
Matrix' n a -> a
|
(mat : Matrix' n a) -> DecompLUP mat -> a
|
||||||
|
detWithLUP {n} mat lup =
|
||||||
|
(if numSwaps lup `mod` 2 == 0 then 1 else -1)
|
||||||
|
* product (diagonal lup.lower) * product (diagonal lup.upper)
|
||||||
|
|
||||||
|
export
|
||||||
|
det : (Ord a, Abs a, Neg a, Fractional a) => Matrix' n a -> a
|
||||||
det {n} mat with (viewShape mat)
|
det {n} mat with (viewShape mat)
|
||||||
det {n=0} mat | Shape [0,0] = 1
|
det {n=0} mat | Shape [0,0] = 1
|
||||||
det {n=1} mat | Shape [1,1] = mat!![0,0]
|
det {n=1} mat | Shape [1,1] = mat!![0,0]
|
||||||
det {n=2} mat | Shape [2,2] = let [a,b,c,d] = elements mat in a*d - b*c
|
det {n=2} mat | Shape [2,2] = let [a,b,c,d] = elements mat in a*d - b*c
|
||||||
_ | Shape [n,n] = let MkLUP l u p sw = decompLUP mat
|
_ | Shape [n,n] = detWithLUP mat (decompLUP mat)
|
||||||
in (if sw `mod` 2 == 0 then 1 else -1)
|
|
||||||
* product (diagonal l) * product (diagonal u)
|
|
||||||
|
|
|
@ -15,6 +15,10 @@ export
|
||||||
swap : (i,j : Fin n) -> Permutation n
|
swap : (i,j : Fin n) -> Permutation n
|
||||||
swap x y = MkPerm [(x,y)]
|
swap x y = MkPerm [(x,y)]
|
||||||
|
|
||||||
|
export
|
||||||
|
swaps : List (Fin n, Fin n) -> Permutation n
|
||||||
|
swaps = MkPerm
|
||||||
|
|
||||||
export
|
export
|
||||||
appendSwap : (i,j : Fin n) -> Permutation n -> Permutation n
|
appendSwap : (i,j : Fin n) -> Permutation n -> Permutation n
|
||||||
appendSwap i j (MkPerm a) = MkPerm ((i,j)::a)
|
appendSwap i j (MkPerm a) = MkPerm ((i,j)::a)
|
||||||
|
@ -51,6 +55,12 @@ export
|
||||||
permuteValues : Permutation n -> Nat -> Nat
|
permuteValues : Permutation n -> Nat -> Nat
|
||||||
permuteValues p = foldMap @{%search} @{mon} (\(i,j) => swapValues i j) p.swaps
|
permuteValues p = foldMap @{%search} @{mon} (\(i,j) => swapValues i j) p.swaps
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
Show (Permutation n) where
|
||||||
|
showPrec p (MkPerm a) = showCon p "swaps" $ showArg a
|
||||||
|
|
||||||
export
|
export
|
||||||
Mult (Permutation n) (Permutation n) (Permutation n) where
|
Mult (Permutation n) (Permutation n) (Permutation n) where
|
||||||
MkPerm a *. MkPerm b = MkPerm (a ++ b)
|
MkPerm a *. MkPerm b = MkPerm (a ++ b)
|
||||||
|
|
Loading…
Reference in a new issue