diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index 37352f7..acc012d 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -394,7 +394,7 @@ decompLUP {m,n} mat with (viewShape mat) decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0 decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0 decompLUP {m=S m,n=S n} mat | Shape [S m,S n] = - iterateN (minimum (S m) (S n)) gaussStepSwap (MkLUP mat identity 0) + iterateN (S $ minimum m n) gaussStepSwap (MkLUP mat identity 0) where maxIndex : (s,a) -> List (s,a) -> (s,a) maxIndex x [] = x @@ -403,15 +403,14 @@ decompLUP {m,n} mat with (viewShape mat) if abscmp b d == LT then assert_total $ maxIndex x ((c,d)::xs) else assert_total $ maxIndex x ((a,b)::xs) - gaussStepSwap : Fin (minimum (S m) (S n)) -> DecompLUP mat -> DecompLUP mat + gaussStepSwap : Fin (S $ minimum m n) -> DecompLUP mat -> DecompLUP mat gaussStepSwap i (MkLUP lu p sw) = let ir = minWeakenLeft {n=S n} i ic = minWeakenRight {m=S m} i - (maxi, maxv) = mapFst head - (maxIndex ([0],0) $ enumerate $ + maxi = head $ fst (maxIndex ([0],0) $ drop (cast i) $ enumerate $ indexSetRange [EndBound (weaken ir)] 0 $ getColumn ic lu) in if maxi == ir then MkLUP (gaussStep i lu) p sw - else MkLUP (gaussStep i $ swapRows maxi ir lu) (appendSwap maxi ir p) (S sw) + else MkLUP (gaussStep i $ swapRows ir maxi lu) (appendSwap maxi ir p) (S sw) -------------------------------------------------------------------------------- @@ -439,13 +438,9 @@ det {n} mat with (viewShape mat) -------------------------------------------------------------------------------- -export -solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) -solveLowerTri {n} mat b with (viewShape b) - _ | Shape [n] = - if all (/=0) (diagonal mat) - then Just $ vector $ reverse $ construct $ reverse $ toVect b - else Nothing +solveLowerTri' : Field a => Matrix' n a -> Vector n a -> Vector n a +solveLowerTri' {n} mat b with (viewShape b) + _ | Shape [n] = vector $ reverse $ construct $ reverse $ toVect b where construct : {i : _} -> Vect i a -> Vect i a construct [] = [] @@ -455,13 +450,10 @@ solveLowerTri {n} mat b with (viewShape b) in (b - sum (zipWith (*) xs (reverse $ toVect $ believe_me $ mat !!.. [One i', EndBound (weaken i')]))) / mat!#[i,i] :: xs -export -solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) -solveUpperTri {n} mat b with (viewShape b) - _ | Shape [n] = - if all (/=0) (diagonal mat) - then Just $ vector $ construct Z $ toVect b - else Nothing + +solveUpperTri' : Field a => Matrix' n a -> Vector n a -> Vector n a +solveUpperTri' {n} mat b with (viewShape b) + _ | Shape [n] = vector $ construct Z $ toVect b where construct : Nat -> Vect i a -> Vect i a construct _ [] = [] @@ -471,6 +463,26 @@ solveUpperTri {n} mat b with (viewShape b) in (b - sum (zipWith (*) xs (toVect $ believe_me $ mat !!.. [One i', StartBound (FS i')]))) / mat!#[i,i] :: xs + +export +solveLowerTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) +solveLowerTri mat b = if all (/=0) (diagonal mat) + then Just $ solveLowerTri' mat b + else Nothing + +export +solveUpperTri : Field a => Matrix' n a -> Vector n a -> Maybe (Vector n a) +solveUpperTri mat b = if all (/=0) (diagonal mat) + then Just $ solveUpperTri' mat b + else Nothing + + +solveWithLUP' : Field a => (mat : Matrix' n a) -> DecompLUP mat -> + Vector n a -> Vector n a +solveWithLUP' mat lup b = + let b' = permuteCoords (inverse lup.permute) b + in solveUpperTri' lup.upper' $ solveLowerTri' lup.lower' b' + export solveWithLUP : Field a => (mat : Matrix' n a) -> DecompLUP mat -> Vector n a -> Maybe (Vector n a) @@ -478,6 +490,22 @@ solveWithLUP mat lup b = let b' = permuteCoords (inverse lup.permute) b in solveLowerTri lup.lower' b' >>= solveUpperTri lup.upper' + export solve : FieldCmp a => Matrix' n a -> Vector n a -> Maybe (Vector n a) solve mat = solveWithLUP mat (decompLUP mat) + + +export +tryInverse : FieldCmp a => Matrix' n a -> Maybe (Matrix' n a) +tryInverse {n} mat with (viewShape mat) + _ | Shape [n,n] = + let lup = decompLUP mat + in map hstack $ traverse (solveWithLUP mat lup) $ map basis range + + +export +{n : _} -> FieldCmp a => MultGroup (Matrix' n a) where + inverse mat = let lup = decompLUP mat in + hstack $ map (solveWithLUP' mat lup . basis) range +