Create Field and Scalar interfaces

This commit is contained in:
Kiana Sheibani 2022-09-06 13:37:50 -04:00
parent 47e889992d
commit 3e12505377
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
8 changed files with 56 additions and 19 deletions

View file

@ -17,7 +17,7 @@ modules = Data.NP,
Data.NumIdr.Array.Order,
Data.NumIdr.Homogeneous,
Data.NumIdr.Matrix,
Data.NumIdr.Multiply,
Data.NumIdr.Interfaces,
Data.NumIdr.PrimArray,
Data.NumIdr.Scalar,
Data.NumIdr.Vector

View file

@ -6,7 +6,7 @@ import Data.Vect
import Data.Zippable
import Data.NP
import Data.Permutation
import Data.NumIdr.Multiply
import Data.NumIdr.Interfaces
import Data.NumIdr.PrimArray
import Data.NumIdr.Array.Order
import Data.NumIdr.Array.Coords

View file

@ -1,7 +1,7 @@
module Data.NumIdr.Homogeneous
import Data.Vect
import Data.NumIdr.Multiply
import Data.NumIdr.Interfaces
import Data.NumIdr.Vector
import Data.NumIdr.Matrix

View file

@ -1,8 +1,32 @@
module Data.NumIdr.Multiply
module Data.NumIdr.Interfaces
%default total
--------------------------------------------------------------------------------
-- Field
--------------------------------------------------------------------------------
public export
Field : Type -> Type
Field a = (Eq a, Neg a, Fractional a)
public export
interface (Eq a, Neg a, Fractional a) => Scalar a where
abscmp : a -> a -> Ordering
export
(Ord a, Abs a, Neg a, Fractional a) => Scalar a where
abscmp x y = compare (abs x) (abs y)
--------------------------------------------------------------------------------
-- Multiplication
--------------------------------------------------------------------------------
infixr 9 *.
infixr 10 ^

View file

@ -3,7 +3,7 @@ module Data.NumIdr.Matrix
import Data.List
import Data.Vect
import Data.Permutation
import Data.NumIdr.Multiply
import Data.NumIdr.Interfaces
import public Data.NumIdr.Array
import Data.NumIdr.Vector
@ -278,7 +278,7 @@ iterateN 1 f x = f FZ x
iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x
gaussStep : {m,n : _} -> (Eq a, Neg a, Fractional a) =>
gaussStep : {m,n : _} -> Field a =>
Fin (minimum m n) -> Matrix m n a -> Matrix m n a
gaussStep i lu =
if all (==0) $ getColumn (minWeakenRight i) lu then lu else
@ -294,7 +294,7 @@ gaussStep i lu =
(flip (-) offsets) lu'
export
decompLU : (Eq a, Neg a, Fractional a) => (mat : Matrix m n a) -> Maybe (DecompLU mat)
decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat)
decompLU mat with (viewShape mat)
_ | Shape [m,n] = map MkLU $ iterateN (minimum m n) (\i => (>>= gaussStepMaybe i)) (Just mat)
where
@ -353,8 +353,7 @@ fromLU (MkLU lu) = MkLUP lu identity 0
export
decompLUP : (Ord a, Abs a, Neg a, Fractional a) =>
(mat : Matrix m n a) -> DecompLUP mat
decompLUP : Scalar a => (mat : Matrix m n a) -> DecompLUP mat
decompLUP {m,n} mat with (viewShape mat)
decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0
decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0
@ -365,7 +364,7 @@ decompLUP {m,n} mat with (viewShape mat)
maxIndex x [] = x
maxIndex _ [x] = x
maxIndex x ((a,b)::(c,d)::xs) =
if abs b < abs d then assert_total $ maxIndex x ((c,d)::xs)
if abscmp b d == LT then assert_total $ maxIndex x ((c,d)::xs)
else assert_total $ maxIndex x ((a,b)::xs)
gaussStepSwap : Fin (minimum (S m) (S n)) -> DecompLUP mat -> DecompLUP mat
@ -385,16 +384,20 @@ decompLUP {m,n} mat with (viewShape mat)
export
detWithLUP : (Ord a, Abs a, Neg a, Fractional a) =>
(mat : Matrix' n a) -> DecompLUP mat -> a
detWithLUP {n} mat lup =
detWithLUP : Scalar a => (mat : Matrix' n a) -> DecompLUP mat -> a
detWithLUP mat lup =
(if numSwaps lup `mod` 2 == 0 then 1 else -1)
* product (diagonal lup.lu)
export
det : (Ord a, Abs a, Neg a, Fractional a) => Matrix' n a -> a
det : Scalar a => Matrix' n a -> a
det {n} mat with (viewShape mat)
det {n=0} mat | Shape [0,0] = 1
det {n=1} mat | Shape [1,1] = mat!![0,0]
det {n=2} mat | Shape [2,2] = let [a,b,c,d] = elements mat in a*d - b*c
_ | Shape [n,n] = detWithLUP mat (decompLUP mat)
solveWithLUP : Scalar a => (mat : Matrix m n a) -> DecompLUP mat ->
Vector m a -> Maybe (Vector n a)
solveWithLUP mat lup b = ?h

View file

@ -1,7 +1,7 @@
module Data.NumIdr.Scalar
import Data.Vect
import Data.NumIdr.Multiply
import Data.NumIdr.Interfaces
import Data.NumIdr.PrimArray
import public Data.NumIdr.Array

View file

@ -1,7 +1,8 @@
module Data.NumIdr.Vector
import Data.Vect
import Data.NumIdr.Multiply
import Data.Permutation
import Data.NumIdr.Interfaces
import public Data.NumIdr.Array
%default total
@ -117,7 +118,7 @@ export
--------------------------------------------------------------------------------
-- Swizzling
-- Swizzling & permuting elements
--------------------------------------------------------------------------------
@ -130,6 +131,14 @@ swizzle p v = rewrite sym (lengthCorrect p)
)
export
swapCoords : (i,j : Fin n) -> Vector n a -> Vector n a
swapCoords = swapInAxis 0
export
permuteCoords : Permutation n -> Vector n a -> Vector n a
permuteCoords = permuteInAxis 0
--------------------------------------------------------------------------------
-- Vector operations

View file

@ -2,8 +2,9 @@ module Data.Permutation
import Data.List
import Data.Vect
import Data.NumIdr.Multiply
import Data.NumIdr.Interfaces
%default total
export
record Permutation n where