Implement matrix inverse

This commit is contained in:
Kiana Sheibani 2022-09-12 10:47:48 -04:00
parent f72826b329
commit d610874abc
Signed by: toki
GPG key ID: 6CB106C25E86A9F7

View file

@ -394,7 +394,7 @@ decompLUP {m,n} mat with (viewShape mat)
decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0 decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0
decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0 decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0
decompLUP {m=S m,n=S n} mat | Shape [S m,S n] = decompLUP {m=S m,n=S n} mat | Shape [S m,S n] =
iterateN (minimum (S m) (S n)) gaussStepSwap (MkLUP mat identity 0) iterateN (S $ minimum m n) gaussStepSwap (MkLUP mat identity 0)
where where
maxIndex : (s,a) -> List (s,a) -> (s,a) maxIndex : (s,a) -> List (s,a) -> (s,a)
maxIndex x [] = x maxIndex x [] = x
@ -403,15 +403,14 @@ decompLUP {m,n} mat with (viewShape mat)
if abscmp b d == LT then assert_total $ maxIndex x ((c,d)::xs) if abscmp b d == LT then assert_total $ maxIndex x ((c,d)::xs)
else assert_total $ maxIndex x ((a,b)::xs) else assert_total $ maxIndex x ((a,b)::xs)
gaussStepSwap : Fin (minimum (S m) (S n)) -> DecompLUP mat -> DecompLUP mat gaussStepSwap : Fin (S $ minimum m n) -> DecompLUP mat -> DecompLUP mat
gaussStepSwap i (MkLUP lu p sw) = gaussStepSwap i (MkLUP lu p sw) =
let ir = minWeakenLeft {n=S n} i let ir = minWeakenLeft {n=S n} i
ic = minWeakenRight {m=S m} i ic = minWeakenRight {m=S m} i
(maxi, maxv) = mapFst head maxi = head $ fst (maxIndex ([0],0) $ drop (cast i) $ enumerate $
(maxIndex ([0],0) $ enumerate $
indexSetRange [EndBound (weaken ir)] 0 $ getColumn ic lu) indexSetRange [EndBound (weaken ir)] 0 $ getColumn ic lu)
in if maxi == ir then MkLUP (gaussStep i lu) p sw in if maxi == ir then MkLUP (gaussStep i lu) p sw
else MkLUP (gaussStep i $ swapRows maxi ir lu) (appendSwap maxi ir p) (S sw) else MkLUP (gaussStep i $ swapRows ir maxi lu) (appendSwap maxi ir p) (S sw)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -439,13 +438,9 @@ det {n} mat with (viewShape mat)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
export solveLowerTri' : Field a => Matrix' n a -> Vector n a -> Vector n a
solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solveLowerTri' {n} mat b with (viewShape b)
solveLowerTri {n} mat b with (viewShape b) _ | Shape [n] = vector $ reverse $ construct $ reverse $ toVect b
_ | Shape [n] =
if all (/=0) (diagonal mat)
then Just $ vector $ reverse $ construct $ reverse $ toVect b
else Nothing
where where
construct : {i : _} -> Vect i a -> Vect i a construct : {i : _} -> Vect i a -> Vect i a
construct [] = [] construct [] = []
@ -455,13 +450,10 @@ solveLowerTri {n} mat b with (viewShape b)
in (b - sum (zipWith (*) xs (reverse $ toVect $ believe_me $ in (b - sum (zipWith (*) xs (reverse $ toVect $ believe_me $
mat !!.. [One i', EndBound (weaken i')]))) / mat!#[i,i] :: xs mat !!.. [One i', EndBound (weaken i')]))) / mat!#[i,i] :: xs
export
solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solveUpperTri' : Field a => Matrix' n a -> Vector n a -> Vector n a
solveUpperTri {n} mat b with (viewShape b) solveUpperTri' {n} mat b with (viewShape b)
_ | Shape [n] = _ | Shape [n] = vector $ construct Z $ toVect b
if all (/=0) (diagonal mat)
then Just $ vector $ construct Z $ toVect b
else Nothing
where where
construct : Nat -> Vect i a -> Vect i a construct : Nat -> Vect i a -> Vect i a
construct _ [] = [] construct _ [] = []
@ -471,6 +463,26 @@ solveUpperTri {n} mat b with (viewShape b)
in (b - sum (zipWith (*) xs (toVect $ believe_me $ in (b - sum (zipWith (*) xs (toVect $ believe_me $
mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs
export
solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
solveLowerTri mat b = if all (/=0) (diagonal mat)
then Just $ solveLowerTri' mat b
else Nothing
export
solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
solveUpperTri mat b = if all (/=0) (diagonal mat)
then Just $ solveUpperTri' mat b
else Nothing
solveWithLUP' : Field a => (mat : Matrix' n a) -> DecompLUP mat ->
Vector n a -> Vector n a
solveWithLUP' mat lup b =
let b' = permuteCoords (inverse lup.permute) b
in solveUpperTri' lup.upper' $ solveLowerTri' lup.lower' b'
export export
solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat -> solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat ->
Vector n a -> Maybe (Vector n a) Vector n a -> Maybe (Vector n a)
@ -478,6 +490,22 @@ solveWithLUP mat lup b =
let b' = permuteCoords (inverse lup.permute) b let b' = permuteCoords (inverse lup.permute) b
in solveLowerTri lup.lower' b' >>= solveUpperTri lup.upper' in solveLowerTri lup.lower' b' >>= solveUpperTri lup.upper'
export export
solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a)
solve mat = solveWithLUP mat (decompLUP mat) solve mat = solveWithLUP mat (decompLUP mat)
export
tryInverse : FieldCmp a => Matrix' n a -> Maybe (Matrix' n a)
tryInverse {n} mat with (viewShape mat)
_ | Shape [n,n] =
let lup = decompLUP mat
in map hstack $ traverse (solveWithLUP mat lup) $ map basis range
export
{n : _} -> FieldCmp a => MultGroup (Matrix' n a) where
inverse mat = let lup = decompLUP mat in
hstack $ map (solveWithLUP' mat lup . basis) range