Tweak library functions to match reps
This commit is contained in:
parent
983733f241
commit
b924d960b5
8 changed files with 123 additions and 103 deletions
|
|
@ -5,8 +5,8 @@ import Data.Vect
|
|||
import Data.Permutation
|
||||
import Data.NumIdr.Interfaces
|
||||
import public Data.NumIdr.Array
|
||||
import Data.NumIdr.PrimArray
|
||||
import Data.NumIdr.Vector
|
||||
import Data.NumIdr.LArray
|
||||
|
||||
%default total
|
||||
|
||||
|
|
@ -31,13 +31,9 @@ Matrix' n = Array [n,n]
|
|||
|
||||
||| Construct a matrix with the given order and elements.
|
||||
export
|
||||
matrix' : {m, n : _} -> Order -> Vect m (Vect n a) -> Matrix m n a
|
||||
matrix' ord x = array' [m,n] ord x
|
||||
|
||||
||| Construct a matrix with the given elements.
|
||||
export
|
||||
matrix : {m, n : _} -> Vect m (Vect n a) -> Matrix m n a
|
||||
matrix = matrix' COrder
|
||||
matrix : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
|
||||
Vect m (Vect n a) -> Matrix m n a
|
||||
matrix x = array' {rep} [m,n] x
|
||||
|
||||
|
||||
||| Construct a matrix with a specific value along the diagonal.
|
||||
|
|
@ -45,8 +41,9 @@ matrix = matrix' COrder
|
|||
||| @ diag The value to repeat along the diagonal
|
||||
||| @ other The value to repeat elsewhere
|
||||
export
|
||||
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
|
||||
repeatDiag d o = fromFunctionNB [m,n]
|
||||
repeatDiag : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
|
||||
(diag, other : a) -> Matrix m n a
|
||||
repeatDiag d o = fromFunctionNB {rep} [m,n]
|
||||
(\[i,j] => if i == j then d else o)
|
||||
|
||||
||| Construct a matrix given its diagonal elements.
|
||||
|
|
@ -54,8 +51,9 @@ repeatDiag d o = fromFunctionNB [m,n]
|
|||
||| @ diag The elements of the matrix's diagonal
|
||||
||| @ other The value to repeat elsewhere
|
||||
export
|
||||
fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
|
||||
fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
||||
fromDiag : {default B rep : Rep} -> RepConstraint rep a => {m, n : _} ->
|
||||
(diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
|
||||
fromDiag ds o = fromFunction {rep} [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
||||
where
|
||||
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Maybe (Fin (minimum m n))
|
||||
eq FZ FZ = Just FZ
|
||||
|
|
@ -66,52 +64,62 @@ fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
|||
|
||||
||| Construct a permutation matrix based on the given permutation.
|
||||
export
|
||||
permuteM : {n : _} -> Num a => Permutation n -> Matrix' n a
|
||||
permuteM p = permuteInAxis 0 p (repeatDiag 1 0)
|
||||
permuteM : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Num a =>
|
||||
Permutation n -> Matrix' n a
|
||||
permuteM p = permuteInAxis 0 p (repeatDiag {rep} 1 0)
|
||||
|
||||
|
||||
||| Construct the matrix that scales a vector by the given value.
|
||||
export
|
||||
scale : {n : _} -> Num a => a -> Matrix' n a
|
||||
scale x = repeatDiag x 0
|
||||
scale : {default B rep : Rep} -> RepConstraint rep a => {n : _} -> Num a =>
|
||||
a -> Matrix' n a
|
||||
scale x = repeatDiag {rep} x 0
|
||||
|
||||
||| Construct a 2D rotation matrix that rotates by the given angle (in radians).
|
||||
export
|
||||
rotate2D : Double -> Matrix' 2 Double
|
||||
rotate2D a = matrix [[cos a, - sin a], [sin a, cos a]]
|
||||
rotate2D : {default B rep : Rep} -> RepConstraint rep Double =>
|
||||
Double -> Matrix' 2 Double
|
||||
rotate2D a = matrix {rep} [[cos a, - sin a], [sin a, cos a]]
|
||||
|
||||
|
||||
||| Construct a 3D rotation matrix around the x-axis.
|
||||
export
|
||||
rotate3DX : Double -> Matrix' 3 Double
|
||||
rotate3DX a = matrix [[1,0,0], [0, cos a, - sin a], [0, sin a, cos a]]
|
||||
rotate3DX : {default B rep : Rep} -> RepConstraint rep Double =>
|
||||
Double -> Matrix' 3 Double
|
||||
rotate3DX a = matrix {rep} [[1,0,0], [0, cos a, - sin a], [0, sin a, cos a]]
|
||||
|
||||
||| Construct a 3D rotation matrix around the y-axis.
|
||||
export
|
||||
rotate3DY : Double -> Matrix' 3 Double
|
||||
rotate3DY a = matrix [[cos a, 0, sin a], [0,1,0], [- sin a, 0, cos a]]
|
||||
rotate3DY : {default B rep : Rep} -> RepConstraint rep Double =>
|
||||
Double -> Matrix' 3 Double
|
||||
rotate3DY a = matrix {rep} [[cos a, 0, sin a], [0,1,0], [- sin a, 0, cos a]]
|
||||
|
||||
||| Construct a 3D rotation matrix around the z-axis.
|
||||
export
|
||||
rotate3DZ : Double -> Matrix' 3 Double
|
||||
rotate3DZ a = matrix [[cos a, - sin a, 0], [sin a, cos a, 0], [0,0,1]]
|
||||
rotate3DZ : {default B rep : Rep} -> RepConstraint rep Double =>
|
||||
Double -> Matrix' 3 Double
|
||||
rotate3DZ a = matrix {rep} [[cos a, - sin a, 0], [sin a, cos a, 0], [0,0,1]]
|
||||
|
||||
|
||||
export
|
||||
reflect : {n : _} -> Neg a => Fin n -> Matrix' n a
|
||||
reflect i = indexSet [i, i] (-1) (repeatDiag 1 0)
|
||||
reflect : {default B rep : Rep} -> RepConstraint rep a =>
|
||||
{n : _} -> Neg a => Fin n -> Matrix' n a
|
||||
reflect i = indexSet [i, i] (-1) (repeatDiag {rep} 1 0)
|
||||
|
||||
export
|
||||
reflectX : {n : _} -> Neg a => Matrix' (1 + n) a
|
||||
reflectX = reflect 0
|
||||
reflectX : {default B rep : Rep} -> RepConstraint rep a =>
|
||||
{n : _} -> Neg a => Matrix' (1 + n) a
|
||||
reflectX = reflect {rep} 0
|
||||
|
||||
export
|
||||
reflectY : {n : _} -> Neg a => Matrix' (2 + n) a
|
||||
reflectY = reflect 1
|
||||
reflectY : {default B rep : Rep} -> RepConstraint rep a =>
|
||||
{n : _} -> Neg a => Matrix' (2 + n) a
|
||||
reflectY = reflect {rep} 1
|
||||
|
||||
export
|
||||
reflectZ : {n : _} -> Neg a => Matrix' (3 + n) a
|
||||
reflectZ = reflect 2
|
||||
reflectZ : {default B rep : Rep} -> RepConstraint rep a =>
|
||||
{n : _} -> Neg a => Matrix' (3 + n) a
|
||||
reflectZ = reflect {rep} 2
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
|
@ -146,13 +154,13 @@ getColumn c mat = rewrite sym (rangeLenZ m) in mat!!..[All, One c]
|
|||
export
|
||||
diagonal' : Matrix m n a -> Vector (minimum m n) a
|
||||
diagonal' {m,n} mat with (viewShape mat)
|
||||
_ | Shape [m,n] = fromFunctionNB _ (\[i] => mat!#[i,i])
|
||||
_ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC mat} _ (\[i] => mat!#[i,i])
|
||||
|
||||
||| Return the diagonal elements of the matrix as a vector.
|
||||
export
|
||||
diagonal : Matrix' n a -> Vector n a
|
||||
diagonal {n} mat with (viewShape mat)
|
||||
_ | Shape [n,n] = fromFunctionNB [n] (\[i] => mat!#[i,i])
|
||||
_ | Shape [n,n] = fromFunctionNB {rep=_} @{getRepC mat} [n] (\[i] => mat!#[i,i])
|
||||
|
||||
|
||||
||| Return a minor of the matrix, i.e. the matrix formed by removing a
|
||||
|
|
@ -165,7 +173,8 @@ minor i j mat = believe_me $ mat!!..[Filter (/=i), Filter (/=j)]
|
|||
|
||||
filterInd : Num a => (Nat -> Nat -> Bool) -> Matrix m n a -> Matrix m n a
|
||||
filterInd {m,n} p mat with (viewShape mat)
|
||||
_ | Shape [m,n] = fromFunctionNB [m,n] (\[i,j] => if p i j then mat!#[i,j] else 0)
|
||||
_ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC mat}
|
||||
[m,n] (\[i,j] => if p i j then mat!#[i,j] else 0)
|
||||
|
||||
export
|
||||
upperTriangle : Num a => Matrix m n a -> Matrix m n a
|
||||
|
|
@ -259,13 +268,15 @@ reflectNormal {n} v with (viewShape v)
|
|||
export
|
||||
Num a => Mult (Matrix m n a) (Vector n a) (Vector m a) where
|
||||
(*.) {m,n} mat v with (viewShape mat)
|
||||
_ | Shape [m,n] = fromFunction [m]
|
||||
_ | Shape [m,n] = fromFunction {rep=_}
|
||||
@{mergeRepConstraint (getRepC mat) (getRepC v)} [m]
|
||||
(\[i] => sum $ map (\j => mat!![i,j] * v!!j) range)
|
||||
|
||||
export
|
||||
Num a => Mult (Matrix m n a) (Matrix n p a) (Matrix m p a) where
|
||||
(*.) {m,n,p} m1 m2 with (viewShape m1, viewShape m2)
|
||||
_ | (Shape [m,n], Shape [n,p]) = fromFunction [m,p]
|
||||
_ | (Shape [m,n], Shape [n,p]) = fromFunction {rep=_}
|
||||
@{mergeRepConstraint (getRepC m1) (getRepC m2)} [m,p]
|
||||
(\[i,j] => sum $ map (\k => m1!![i,k] * m2!![k,j]) range)
|
||||
|
||||
export
|
||||
|
|
@ -295,7 +306,7 @@ namespace DecompLU
|
|||
export
|
||||
lower : Num a => DecompLU {m,n,a} mat -> Matrix m (minimum m n) a
|
||||
lower {m,n} (MkLU lu) with (viewShape lu)
|
||||
_ | Shape [m,n] = fromFunctionNB _ (\[i,j] =>
|
||||
_ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
|
||||
case compare i j of
|
||||
LT => 0
|
||||
EQ => 1
|
||||
|
|
@ -325,7 +336,7 @@ namespace DecompLU
|
|||
export
|
||||
upper : Num a => DecompLU {m,n,a} mat -> Matrix (minimum m n) n a
|
||||
upper {m,n} (MkLU lu) with (viewShape lu)
|
||||
_ | Shape [m,n] = fromFunctionNB _ (\[i,j] =>
|
||||
_ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
|
||||
if i <= j then lu!#[i,j] else 0)
|
||||
|
||||
||| The upper triangular matrix U of the LU decomposition.
|
||||
|
|
@ -392,7 +403,8 @@ gaussStep i lu =
|
|||
export
|
||||
decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat)
|
||||
decompLU {m,n} mat with (viewShape mat)
|
||||
_ | Shape [m,n] = map MkLU $ iterateN (minimum m n) (\i => (>>= gaussStepMaybe i)) (Just mat)
|
||||
_ | Shape [m,n] = map (MkLU . convertRep _ @{getRepC mat})
|
||||
$ iterateN (minimum m n) (\i => (>>= gaussStepMaybe i)) (Just (convertRep Delayed mat))
|
||||
where
|
||||
gaussStepMaybe : Fin (minimum m n) -> Matrix m n a -> Maybe (Matrix m n a)
|
||||
gaussStepMaybe i mat = if mat!#[cast i,cast i] == 0 then Nothing
|
||||
|
|
@ -417,7 +429,7 @@ namespace DecompLUP
|
|||
export
|
||||
lower : Num a => DecompLUP {m,n,a} mat -> Matrix m (minimum m n) a
|
||||
lower {m,n} (MkLUP lu _ _) with (viewShape lu)
|
||||
_ | Shape [m,n] = fromFunctionNB _ (\[i,j] =>
|
||||
_ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
|
||||
case compare i j of
|
||||
LT => 0
|
||||
EQ => 1
|
||||
|
|
@ -447,7 +459,7 @@ namespace DecompLUP
|
|||
export
|
||||
upper : Num a => DecompLUP {m,n,a} mat -> Matrix (minimum m n) n a
|
||||
upper {m,n} (MkLUP lu _ _) with (viewShape lu)
|
||||
_ | Shape [m,n] = fromFunctionNB _ (\[i,j] =>
|
||||
_ | Shape [m,n] = fromFunctionNB {rep=_} @{getRepC lu} _ (\[i,j] =>
|
||||
if i <= j then lu!#[i,j] else 0)
|
||||
|
||||
||| The upper triangular matrix U of the LUP decomposition.
|
||||
|
|
@ -634,4 +646,3 @@ export
|
|||
{n : _} -> FieldCmp a => MultGroup (Matrix' n a) where
|
||||
inverse mat = let lup = decompLUP mat in
|
||||
hstack $ map (solveWithLUP' mat lup . basis) range
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue