Create Data.NumIdr.Multiply

This commit is contained in:
Kiana Sheibani 2022-05-26 18:50:07 -04:00
parent 97d1bdb538
commit 87d8814c38
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
6 changed files with 122 additions and 8 deletions

View file

@ -4,12 +4,17 @@ import Data.List
import Data.List1 import Data.List1
import Data.Vect import Data.Vect
import Data.Zippable import Data.Zippable
import Data.NumIdr.Multiply
import Data.NumIdr.PrimArray import Data.NumIdr.PrimArray
import Data.NumIdr.Array.Order import Data.NumIdr.Array.Order
import Data.NumIdr.Array.Coords import Data.NumIdr.Array.Coords
%default total %default total
infix 2 !!
infix 2 !?
infixl 3 !!..
infix 3 !?..
||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional ||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional
||| grid of values, where `n` is a value known as the *rank* of the array. Arrays ||| grid of values, where `n` is a value known as the *rank* of the array. Arrays
@ -201,11 +206,6 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
-- Indexing -- Indexing
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
infix 2 !!
infix 2 !?
infixl 3 !!..
infix 3 !?..
||| Index the array using the given `Coords` object. ||| Index the array using the given `Coords` object.
export export
@ -290,7 +290,7 @@ enumerate arr = map (\is => (is, index is arr))
export export
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a -> stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
Array (replaceAt axis (index axis s + d) s) a Array (updateAt axis (+d) s) a
stack axis a b = let sA = shape a stack axis a b = let sA = shape a
sB = shape b sB = shape b
dA = index axis sA dA = index axis sA
@ -374,6 +374,10 @@ Traversable (Array s) where
map (MkArray ord sts s) (traverse f arr) map (MkArray ord sts s) (traverse f arr)
export
Cast a b => Cast (Array s a) (Array s b) where
cast = map cast
export export
Eq a => Eq (Array s a) where Eq a => Eq (Array s a) where
a == b = if getOrder a == getOrder b a == b = if getOrder a == getOrder b
@ -398,6 +402,27 @@ export
fromInteger = repeat s . fromInteger fromInteger = repeat s . fromInteger
export
{s : _} -> Neg a => Neg (Array s a) where
negate = map negate
(-) = zipWith (-)
{s : _} -> Fractional a => Fractional (Array s a) where
recip = map recip
(/) = zipWith (/)
export
Num a => Mult a (Array {rk} s a) where
Result = Array {rk} s a
(*.) x = map (*x)
export
Num a => Mult (Array {rk} s a) a where
Result = Array {rk} s a
(*.) = flip (*.)
export export
Show a => Show (Array s a) where Show a => Show (Array s a) where

View file

@ -116,4 +116,4 @@ namespace CoordsRange
go [] = pure [] go [] = pure []
go (r :: rs) = [| (case cRangeToBounds r of go (r :: rs) = [| (case cRangeToBounds r of
Left x => pure x Left x => pure x
Right (x,y) => [x..pred y]) :: go rs |] Right (x,y) => [x,S x..pred y]) :: go rs |]

View file

@ -1,7 +1,9 @@
module Data.NumIdr.Matrix module Data.NumIdr.Matrix
import Data.Vect import Data.Vect
import Data.NumIdr.Multiply
import public Data.NumIdr.Array import public Data.NumIdr.Array
import Data.NumIdr.Vector
%default total %default total
@ -10,6 +12,9 @@ public export
Matrix : Nat -> Nat -> Type -> Type Matrix : Nat -> Nat -> Type -> Type
Matrix m n = Array [m,n] Matrix m n = Array [m,n]
public export
Matrix' : Nat -> Type -> Type
Matrix' n = Matrix n n
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -29,13 +34,37 @@ matrix = matrix' COrder
export export
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
repeatDiag d o = fromFunction [m,n] repeatDiag d o = fromFunction [m,n]
(\[i,j] => if finToNat i == finToNat j then d else o) (\[i,j] => if i `eq` j then d else o)
where
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Bool
eq FZ FZ = True
eq (FS x) (FS y) = eq x y
eq FZ (FS _) = False
eq (FS _) FZ = False
export
fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
where
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Maybe (Fin (minimum m n))
eq FZ FZ = Just FZ
eq (FS x) (FS y) = map FS (eq x y)
eq FZ (FS _) = Nothing
eq (FS _) FZ = Nothing
export export
identity : Num a => {n : _} -> Matrix n n a identity : Num a => {n : _} -> Matrix n n a
identity = repeatDiag 1 0 identity = repeatDiag 1 0
export
scaling : Num a => {n : _} -> a -> Matrix n n a
scaling x = repeatDiag x 0
export
rotation2D : Double -> Matrix' 2 Double
rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]]
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -47,7 +76,29 @@ export
index : Fin m -> Fin n -> Matrix m n a -> a index : Fin m -> Fin n -> Matrix m n a -> a
index m n = index [m,n] index m n = index [m,n]
export
indexMaybe : Nat -> Nat -> Matrix m n a -> Maybe a
indexMaybe m n = indexMaybe [m,n]
export
getRow : Fin m -> Matrix m n a -> Vector n a
getRow r mat = rewrite sym (minusZeroRight n) in indexRange [One r, All] mat
export
getColumn : Fin n -> Matrix m n a -> Vector m a
getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Operations -- Operations
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
export
vstack : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
vstack = stack 0
export
hstack : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
hstack = stack 1

View file

@ -0,0 +1,13 @@
module Data.NumIdr.Multiply
%default total
infixr 7 *.
||| A generalized multiplication/transformation operator. This interface is
||| necessary since the standard multiplication operator is homogenous.
public export
interface Mult a b where
0 Result : Type
(*.) : a -> b -> Result

View file

@ -1,6 +1,7 @@
module Data.NumIdr.Scalar module Data.NumIdr.Scalar
import Data.Vect import Data.Vect
import Data.NumIdr.Multiply
import Data.NumIdr.PrimArray import Data.NumIdr.PrimArray
import public Data.NumIdr.Array import public Data.NumIdr.Array
@ -21,3 +22,9 @@ scalar x = fromVect _ [x]
export export
unwrap : Scalar a -> a unwrap : Scalar a -> a
unwrap = index 0 . getPrim unwrap = index 0 . getPrim
export
Num a => Mult (Scalar a) (Scalar a) where
Result = Scalar a
(*.) = (*)

View file

@ -1,6 +1,7 @@
module Data.NumIdr.Vector module Data.NumIdr.Vector
import Data.Vect import Data.Vect
import Data.NumIdr.Multiply
import public Data.NumIdr.Array import public Data.NumIdr.Array
%default total %default total
@ -46,6 +47,18 @@ export
index : Fin n -> Vector n a -> a index : Fin n -> Vector n a -> a
index n = Array.index [n] index n = Array.index [n]
export
(!!) : Vector n a -> Fin n -> a
(!!) = flip index
export
indexMaybe : Nat -> Vector n a -> Maybe a
indexMaybe n = Array.indexMaybe [n]
export
(!?) : Vector n a -> Nat -> Maybe a
(!?) = flip indexMaybe
-- Named projections -- Named projections
export export
@ -80,6 +93,11 @@ swizzle p v = rewrite sym (lengthCorrect p)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
export
concat : Vector m a -> Vector n a -> Vector (m + n) a
concat = stack 0
export export
dot : Num a => Vector n a -> Vector n a -> a dot : Num a => Vector n a -> Vector n a -> a
dot = sum .: zipWith (*) dot = sum .: zipWith (*)