diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index 4acf004..8e89120 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -2,7 +2,7 @@ module Data.NumIdr.Matrix import Data.List import Data.Vect -import Data.Bool.Xor +import Data.Permutation import Data.NumIdr.Multiply import public Data.NumIdr.Array import Data.NumIdr.Vector @@ -61,6 +61,11 @@ fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j) eq (FS _) FZ = Nothing +export +permutationMatrix : {n : _} -> Num a => Permutation n -> Matrix' n a +permutationMatrix p = permuteInAxis 0 p (repeatDiag 1 0) + + ||| Construct the matrix that scales a vector by the given value. export scaling : {n : _} -> Num a => a -> Matrix' n a @@ -117,6 +122,27 @@ minor : Fin (S m) -> Fin (S n) -> Matrix (S m) (S n) a -> Matrix m n a minor i j mat = believe_me $ mat!!..[Filter (/=i), Filter (/=j)] +filterInd : Num a => (Nat -> Nat -> Bool) -> Matrix m n a -> Matrix m n a +filterInd p mat with (viewShape mat) + _ | Shape [m,n] = fromFunctionNB [m,n] (\[i,j] => if p i j then mat!#[i,j] else 0) + +export +upperTriangle : Num a => Matrix m n a -> Matrix m n a +upperTriangle = filterInd (<=) + +export +lowerTriangle : Num a => Matrix m n a -> Matrix m n a +lowerTriangle = filterInd (>=) + +export +upperTriangleStrict : Num a => Matrix m n a -> Matrix m n a +upperTriangleStrict = filterInd (<) + +export +lowerTriangleStrict : Num a => Matrix m n a -> Matrix m n a +lowerTriangleStrict = filterInd (>) + + -------------------------------------------------------------------------------- -- Basic operations -------------------------------------------------------------------------------- @@ -141,6 +167,23 @@ hstack : {m : _} -> Vect n (Vector m a) -> Matrix m n a hstack = stack 1 +export +swapRows : (i,j : Fin m) -> Matrix m n a -> Matrix m n a +swapRows = swapInAxis 0 + +export +swapColumns : (i,j : Fin n) -> Matrix m n a -> Matrix m n a +swapColumns = swapInAxis 1 + +export +permuteRows : Permutation m -> Matrix m n a -> Matrix m n a +permuteRows = permuteInAxis 0 + +export +permuteColumns : Permutation n -> Matrix m n a -> Matrix m n a +permuteColumns = permuteInAxis 1 + + ||| Calculate the outer product of two vectors as a matrix. export outer : Num a => Vector m a -> Vector n a -> Matrix m n a @@ -181,14 +224,29 @@ export -- LU Decomposition -public export +export record DecompLU {0 n,a : _} (mat : Matrix' n a) where constructor MkLU - lower, upper : Matrix' n a + lu : Matrix' n a -export -Show a => Show (DecompLU {a} mat) where - showPrec p (MkLU l u) = showCon p "MkLU" $ showArg l ++ showArg u + +namespace DecompLU + export + lower : Num a => DecompLU {n,a} mat -> Matrix' n a + lower (MkLU lu) with (viewShape lu) + _ | Shape [n,n] = lowerTriangleStrict lu + identity + + export %inline + (.lower) : Num a => DecompLU {n,a} mat -> Matrix' n a + (.lower) = lower + + export + upper : Num a => DecompLU {n,a} mat -> Matrix' n a + upper (MkLU lu) = upperTriangle lu + + export %inline + (.upper) : Num a => DecompLU {n,a} mat -> Matrix' n a + (.upper) = upper iterateN : (n : Nat) -> (Fin n -> a -> a) -> a -> a @@ -196,45 +254,75 @@ iterateN 0 f x = x iterateN 1 f x = f FZ x iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x + +gaussStep : (Eq a, Neg a, Fractional a) => Fin n -> Matrix' n a -> Matrix' n a +gaussStep {n} i lu with (viewShape lu) + _ | Shape [n,n] = + if all (==0) $ getColumn i lu then lu else + let diag = lu!![i,i] + coeffs = map (/diag) $ lu!!..[StartBound (FS i), One i] + lu' = indexSetRange [StartBound (FS i), One i] + coeffs lu + pivot = lu!!..[One i, StartBound (FS i)] + offsets = negate $ outer coeffs pivot + in indexUpdateRange [StartBound (FS i), StartBound (FS i)] (+offsets) lu' + export -decompLU : Neg a => Fractional a => (mat : Matrix' n a) -> DecompLU mat +decompLU : (Eq a, Neg a, Fractional a) => (mat : Matrix' n a) -> DecompLU mat decompLU {n} mat with (viewShape mat) - _ | Shape [n,n] = iterateN n doolittle (MkLU identity mat) - where - doolittle : Fin n -> DecompLU mat -> DecompLU mat - doolittle i (MkLU l u) = - let v = rewrite rangeLen (S i') n in fromFunctionNB [minus n (S i')] - (\[x] => u!#[S i' + x,i'] / u!#[i',i']) - low = indexSetRange [StartBound (FS i), One i] (-v) identity - in MkLU (indexSetRange [StartBound (FS i), One i] v l) (low *. u) - where i' : Nat - i' = cast i + _ | Shape [n,n] = MkLU $ iterateN n gaussStep mat -- LUP Decomposition public export record DecompLUP {0 n,a : _} (mat : Matrix' n a) where constructor MkLUP - lower, upper, permute : Matrix' n a - swaps : Nat + lu : Matrix' n a + p : Permutation n + sw : Nat + +namespace DecompLUP + export + lower : Num a => DecompLUP {n,a} mat -> Matrix' n a + lower (MkLUP lu p sw) with (viewShape lu) + _ | Shape [n,n] = lowerTriangleStrict lu + identity + + export %inline + (.lower) : Num a => DecompLUP {n,a} mat -> Matrix' n a + (.lower) = lower + + export + upper : Num a => DecompLUP {n,a} mat -> Matrix' n a + upper (MkLUP lu p sw) = upperTriangle lu + + export %inline + (.upper) : Num a => DecompLUP {n,a} mat -> Matrix' n a + (.upper) = upper + + export + permute : DecompLUP {n} mat -> Permutation n + permute (MkLUP lu p sw) = p + + export %inline + (.permute) : DecompLUP {n} mat -> Permutation n + (.permute) = permute + + export + numSwaps : DecompLUP {n} mat -> Nat + numSwaps (MkLUP lu p sw) = sw export -Show a => Show (DecompLUP {a} mat) where - showPrec p (MkLUP l u pr b) = showCon p "MkLUP" $ - showArg l ++ showArg u ++ showArg pr ++ showArg b +fromLU : DecompLU mat -> DecompLUP mat +fromLU (MkLU lu) = MkLUP lu identity 0 + export -fromLU : Num a => DecompLU {n,a} mat -> DecompLUP mat -fromLU {n} (MkLU l u) with (viewShape l) - _ | Shape [n,n] = MkLUP l u identity 0 - -export -decompLUP : Ord a => Abs a => Neg a => Fractional a => +decompLUP : (Ord a, Abs a, Neg a, Fractional a) => (mat : Matrix' n a) -> DecompLUP mat decompLUP {n} mat with (viewShape mat) - decompLUP {n=0} mat | Shape [0,0] = MkLUP mat mat mat 0 + decompLUP {n=0} mat | Shape [0,0] = MkLUP mat identity 0 decompLUP {n=S n} mat | Shape [S n,S n] = - iterateN (S n) doolittle (MkLUP identity mat identity 0) + iterateN (S n) gaussStepSwap (MkLUP mat identity 0) where maxIndex : (s,a) -> List (s,a) -> (s,a) maxIndex x [] = x @@ -243,29 +331,13 @@ decompLUP {n} mat with (viewShape mat) if abs b < abs d then assert_total $ maxIndex x ((c,d)::xs) else assert_total $ maxIndex x ((a,b)::xs) - doolittle : Fin (S n) -> DecompLUP mat -> DecompLUP mat - doolittle i (MkLUP l u p sw) = - let (maxi, maxv) = mapFst ((+i') . head) - (maxIndex ([0],0) $ enumerateNB $ - u !!.. [StartBound (weaken i), One i]) - u' = if maxi == i' then u - else fromFunctionNB _ (\[x,y] => - if x==i' then u!#[maxi,y] - else if x==maxi then u!#[i',y] - else u!#[x,y]) - p' = if maxi == i' then p - else fromFunctionNB _ (\[x,y] => - if x==i' then p!#[maxi,y] - else if x==maxi then p!#[i',y] - else p!#[x,y]) - v = rewrite rangeLen (S i') (S n) in fromFunctionNB [minus n i'] - (\[x] => u'!#[S i' + x,i'] / u'!#[i',i']) - low = indexSetRange [StartBound (FS i), One i] (-v) identity - in if maxv == 0 then MkLUP l u p sw else - MkLUP (indexSetRange [StartBound (FS i), One i] v l) - (low *. u') p' (if maxi==i' then S sw else sw) - where i' : Nat - i' = cast i + gaussStepSwap : Fin (S n) -> DecompLUP mat -> DecompLUP mat + gaussStepSwap i (MkLUP lu p sw) = + let (maxi, maxv) = mapFst head + (maxIndex ([0],0) $ enumerate $ + indexSetRange [EndBound (weaken i)] 0 $ getColumn i lu) + in if maxi == i then MkLUP (gaussStep i lu) p sw + else MkLUP (gaussStep i $ swapRows maxi i lu) (appendSwap maxi i p) (S sw) -------------------------------------------------------------------------------- @@ -274,12 +346,16 @@ decompLUP {n} mat with (viewShape mat) export -det : Ord a => Abs a => Neg a => Fractional a => - Matrix' n a -> a +detWithLUP : (Ord a, Abs a, Neg a, Fractional a) => + (mat : Matrix' n a) -> DecompLUP mat -> a +detWithLUP {n} mat lup = + (if numSwaps lup `mod` 2 == 0 then 1 else -1) + * product (diagonal lup.lower) * product (diagonal lup.upper) + +export +det : (Ord a, Abs a, Neg a, Fractional a) => Matrix' n a -> a det {n} mat with (viewShape mat) det {n=0} mat | Shape [0,0] = 1 det {n=1} mat | Shape [1,1] = mat!![0,0] det {n=2} mat | Shape [2,2] = let [a,b,c,d] = elements mat in a*d - b*c - _ | Shape [n,n] = let MkLUP l u p sw = decompLUP mat - in (if sw `mod` 2 == 0 then 1 else -1) - * product (diagonal l) * product (diagonal u) + _ | Shape [n,n] = detWithLUP mat (decompLUP mat) diff --git a/src/Data/Permutation.idr b/src/Data/Permutation.idr index cb15fcd..4f3a48b 100644 --- a/src/Data/Permutation.idr +++ b/src/Data/Permutation.idr @@ -15,6 +15,10 @@ export swap : (i,j : Fin n) -> Permutation n swap x y = MkPerm [(x,y)] +export +swaps : List (Fin n, Fin n) -> Permutation n +swaps = MkPerm + export appendSwap : (i,j : Fin n) -> Permutation n -> Permutation n appendSwap i j (MkPerm a) = MkPerm ((i,j)::a) @@ -51,6 +55,12 @@ export permuteValues : Permutation n -> Nat -> Nat permuteValues p = foldMap @{%search} @{mon} (\(i,j) => swapValues i j) p.swaps + + +export +Show (Permutation n) where + showPrec p (MkPerm a) = showCon p "swaps" $ showArg a + export Mult (Permutation n) (Permutation n) (Permutation n) where MkPerm a *. MkPerm b = MkPerm (a ++ b)