Create Field and Scalar interfaces

This commit is contained in:
Kiana Sheibani 2022-09-06 13:37:50 -04:00
parent 47e889992d
commit 3e12505377
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
8 changed files with 56 additions and 19 deletions

View file

@ -17,7 +17,7 @@ modules = Data.NP,
Data.NumIdr.Array.Order, Data.NumIdr.Array.Order,
Data.NumIdr.Homogeneous, Data.NumIdr.Homogeneous,
Data.NumIdr.Matrix, Data.NumIdr.Matrix,
Data.NumIdr.Multiply, Data.NumIdr.Interfaces,
Data.NumIdr.PrimArray, Data.NumIdr.PrimArray,
Data.NumIdr.Scalar, Data.NumIdr.Scalar,
Data.NumIdr.Vector Data.NumIdr.Vector

View file

@ -6,7 +6,7 @@ import Data.Vect
import Data.Zippable import Data.Zippable
import Data.NP import Data.NP
import Data.Permutation import Data.Permutation
import Data.NumIdr.Multiply import Data.NumIdr.Interfaces
import Data.NumIdr.PrimArray import Data.NumIdr.PrimArray
import Data.NumIdr.Array.Order import Data.NumIdr.Array.Order
import Data.NumIdr.Array.Coords import Data.NumIdr.Array.Coords

View file

@ -1,7 +1,7 @@
module Data.NumIdr.Homogeneous module Data.NumIdr.Homogeneous
import Data.Vect import Data.Vect
import Data.NumIdr.Multiply import Data.NumIdr.Interfaces
import Data.NumIdr.Vector import Data.NumIdr.Vector
import Data.NumIdr.Matrix import Data.NumIdr.Matrix

View file

@ -1,8 +1,32 @@
module Data.NumIdr.Multiply module Data.NumIdr.Interfaces
%default total %default total
--------------------------------------------------------------------------------
-- Field
--------------------------------------------------------------------------------
public export
Field : Type -> Type
Field a = (Eq a, Neg a, Fractional a)
public export
interface (Eq a, Neg a, Fractional a) => Scalar a where
abscmp : a -> a -> Ordering
export
(Ord a, Abs a, Neg a, Fractional a) => Scalar a where
abscmp x y = compare (abs x) (abs y)
--------------------------------------------------------------------------------
-- Multiplication
--------------------------------------------------------------------------------
infixr 9 *. infixr 9 *.
infixr 10 ^ infixr 10 ^

View file

@ -3,7 +3,7 @@ module Data.NumIdr.Matrix
import Data.List import Data.List
import Data.Vect import Data.Vect
import Data.Permutation import Data.Permutation
import Data.NumIdr.Multiply import Data.NumIdr.Interfaces
import public Data.NumIdr.Array import public Data.NumIdr.Array
import Data.NumIdr.Vector import Data.NumIdr.Vector
@ -278,7 +278,7 @@ iterateN 1 f x = f FZ x
iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x iterateN (S n@(S _)) f x = iterateN n (f . FS) $ f FZ x
gaussStep : {m,n : _} -> (Eq a, Neg a, Fractional a) => gaussStep : {m,n : _} -> Field a =>
Fin (minimum m n) -> Matrix m n a -> Matrix m n a Fin (minimum m n) -> Matrix m n a -> Matrix m n a
gaussStep i lu = gaussStep i lu =
if all (==0) $ getColumn (minWeakenRight i) lu then lu else if all (==0) $ getColumn (minWeakenRight i) lu then lu else
@ -294,7 +294,7 @@ gaussStep i lu =
(flip (-) offsets) lu' (flip (-) offsets) lu'
export export
decompLU : (Eq a, Neg a, Fractional a) => (mat : Matrix m n a) -> Maybe (DecompLU mat) decompLU : Field a => (mat : Matrix m n a) -> Maybe (DecompLU mat)
decompLU mat with (viewShape mat) decompLU mat with (viewShape mat)
_ | Shape [m,n] = map MkLU $ iterateN (minimum m n) (\i => (>>= gaussStepMaybe i)) (Just mat) _ | Shape [m,n] = map MkLU $ iterateN (minimum m n) (\i => (>>= gaussStepMaybe i)) (Just mat)
where where
@ -353,8 +353,7 @@ fromLU (MkLU lu) = MkLUP lu identity 0
export export
decompLUP : (Ord a, Abs a, Neg a, Fractional a) => decompLUP : Scalar a => (mat : Matrix m n a) -> DecompLUP mat
(mat : Matrix m n a) -> DecompLUP mat
decompLUP {m,n} mat with (viewShape mat) decompLUP {m,n} mat with (viewShape mat)
decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0 decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0
decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0 decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0
@ -365,8 +364,8 @@ decompLUP {m,n} mat with (viewShape mat)
maxIndex x [] = x maxIndex x [] = x
maxIndex _ [x] = x maxIndex _ [x] = x
maxIndex x ((a,b)::(c,d)::xs) = maxIndex x ((a,b)::(c,d)::xs) =
if abs b < abs d then assert_total $ maxIndex x ((c,d)::xs) if abscmp b d == LT then assert_total $ maxIndex x ((c,d)::xs)
else assert_total $ maxIndex x ((a,b)::xs) else assert_total $ maxIndex x ((a,b)::xs)
gaussStepSwap : Fin (minimum (S m) (S n)) -> DecompLUP mat -> DecompLUP mat gaussStepSwap : Fin (minimum (S m) (S n)) -> DecompLUP mat -> DecompLUP mat
gaussStepSwap i (MkLUP lu p sw) = gaussStepSwap i (MkLUP lu p sw) =
@ -385,16 +384,20 @@ decompLUP {m,n} mat with (viewShape mat)
export export
detWithLUP : (Ord a, Abs a, Neg a, Fractional a) => detWithLUP : Scalar a => (mat : Matrix' n a) -> DecompLUP mat -> a
(mat : Matrix' n a) -> DecompLUP mat -> a detWithLUP mat lup =
detWithLUP {n} mat lup =
(if numSwaps lup `mod` 2 == 0 then 1 else -1) (if numSwaps lup `mod` 2 == 0 then 1 else -1)
* product (diagonal lup.lu) * product (diagonal lup.lu)
export export
det : (Ord a, Abs a, Neg a, Fractional a) => Matrix' n a -> a det : Scalar a => Matrix' n a -> a
det {n} mat with (viewShape mat) det {n} mat with (viewShape mat)
det {n=0} mat | Shape [0,0] = 1 det {n=0} mat | Shape [0,0] = 1
det {n=1} mat | Shape [1,1] = mat!![0,0] det {n=1} mat | Shape [1,1] = mat!![0,0]
det {n=2} mat | Shape [2,2] = let [a,b,c,d] = elements mat in a*d - b*c det {n=2} mat | Shape [2,2] = let [a,b,c,d] = elements mat in a*d - b*c
_ | Shape [n,n] = detWithLUP mat (decompLUP mat) _ | Shape [n,n] = detWithLUP mat (decompLUP mat)
solveWithLUP : Scalar a => (mat : Matrix m n a) -> DecompLUP mat ->
Vector m a -> Maybe (Vector n a)
solveWithLUP mat lup b = ?h

View file

@ -1,7 +1,7 @@
module Data.NumIdr.Scalar module Data.NumIdr.Scalar
import Data.Vect import Data.Vect
import Data.NumIdr.Multiply import Data.NumIdr.Interfaces
import Data.NumIdr.PrimArray import Data.NumIdr.PrimArray
import public Data.NumIdr.Array import public Data.NumIdr.Array

View file

@ -1,7 +1,8 @@
module Data.NumIdr.Vector module Data.NumIdr.Vector
import Data.Vect import Data.Vect
import Data.NumIdr.Multiply import Data.Permutation
import Data.NumIdr.Interfaces
import public Data.NumIdr.Array import public Data.NumIdr.Array
%default total %default total
@ -117,7 +118,7 @@ export
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Swizzling -- Swizzling & permuting elements
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -130,6 +131,14 @@ swizzle p v = rewrite sym (lengthCorrect p)
) )
export
swapCoords : (i,j : Fin n) -> Vector n a -> Vector n a
swapCoords = swapInAxis 0
export
permuteCoords : Permutation n -> Vector n a -> Vector n a
permuteCoords = permuteInAxis 0
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Vector operations -- Vector operations

View file

@ -2,8 +2,9 @@ module Data.Permutation
import Data.List import Data.List
import Data.Vect import Data.Vect
import Data.NumIdr.Multiply import Data.NumIdr.Interfaces
%default total
export export
record Permutation n where record Permutation n where