From 87d8814c3833e44fefbdc377075cd883e9c4ff3c Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Thu, 26 May 2022 18:50:07 -0400 Subject: [PATCH] Create Data.NumIdr.Multiply --- src/Data/NumIdr/Array/Array.idr | 37 ++++++++++++++++++---- src/Data/NumIdr/Array/Coords.idr | 2 +- src/Data/NumIdr/Matrix.idr | 53 +++++++++++++++++++++++++++++++- src/Data/NumIdr/Multiply.idr | 13 ++++++++ src/Data/NumIdr/Scalar.idr | 7 +++++ src/Data/NumIdr/Vector.idr | 18 +++++++++++ 6 files changed, 122 insertions(+), 8 deletions(-) create mode 100644 src/Data/NumIdr/Multiply.idr diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 3e16731..59677f9 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -4,12 +4,17 @@ import Data.List import Data.List1 import Data.Vect import Data.Zippable +import Data.NumIdr.Multiply import Data.NumIdr.PrimArray import Data.NumIdr.Array.Order import Data.NumIdr.Array.Coords %default total +infix 2 !! +infix 2 !? +infixl 3 !!.. +infix 3 !?.. ||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional ||| grid of values, where `n` is a value known as the *rank* of the array. Arrays @@ -201,11 +206,6 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v) -- Indexing -------------------------------------------------------------------------------- -infix 2 !! -infix 2 !? -infixl 3 !!.. -infix 3 !?.. - ||| Index the array using the given `Coords` object. export @@ -290,7 +290,7 @@ enumerate arr = map (\is => (is, index is arr)) export stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a -> - Array (replaceAt axis (index axis s + d) s) a + Array (updateAt axis (+d) s) a stack axis a b = let sA = shape a sB = shape b dA = index axis sA @@ -374,6 +374,10 @@ Traversable (Array s) where map (MkArray ord sts s) (traverse f arr) +export +Cast a b => Cast (Array s a) (Array s b) where + cast = map cast + export Eq a => Eq (Array s a) where a == b = if getOrder a == getOrder b @@ -398,6 +402,27 @@ export fromInteger = repeat s . fromInteger +export +{s : _} -> Neg a => Neg (Array s a) where + negate = map negate + (-) = zipWith (-) + +{s : _} -> Fractional a => Fractional (Array s a) where + recip = map recip + (/) = zipWith (/) + + +export +Num a => Mult a (Array {rk} s a) where + Result = Array {rk} s a + (*.) x = map (*x) + +export +Num a => Mult (Array {rk} s a) a where + Result = Array {rk} s a + (*.) = flip (*.) + + export Show a => Show (Array s a) where diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr index 954ede0..db4b8e7 100644 --- a/src/Data/NumIdr/Array/Coords.idr +++ b/src/Data/NumIdr/Array/Coords.idr @@ -116,4 +116,4 @@ namespace CoordsRange go [] = pure [] go (r :: rs) = [| (case cRangeToBounds r of Left x => pure x - Right (x,y) => [x..pred y]) :: go rs |] + Right (x,y) => [x,S x..pred y]) :: go rs |] diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index 269631a..41bc346 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -1,7 +1,9 @@ module Data.NumIdr.Matrix import Data.Vect +import Data.NumIdr.Multiply import public Data.NumIdr.Array +import Data.NumIdr.Vector %default total @@ -10,6 +12,9 @@ public export Matrix : Nat -> Nat -> Type -> Type Matrix m n = Array [m,n] +public export +Matrix' : Nat -> Type -> Type +Matrix' n = Matrix n n -------------------------------------------------------------------------------- @@ -29,13 +34,37 @@ matrix = matrix' COrder export repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a repeatDiag d o = fromFunction [m,n] - (\[i,j] => if finToNat i == finToNat j then d else o) + (\[i,j] => if i `eq` j then d else o) + where + eq : {0 m,n : Nat} -> Fin m -> Fin n -> Bool + eq FZ FZ = True + eq (FS x) (FS y) = eq x y + eq FZ (FS _) = False + eq (FS _) FZ = False + +export +fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a +fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j) + where + eq : {0 m,n : Nat} -> Fin m -> Fin n -> Maybe (Fin (minimum m n)) + eq FZ FZ = Just FZ + eq (FS x) (FS y) = map FS (eq x y) + eq FZ (FS _) = Nothing + eq (FS _) FZ = Nothing + export identity : Num a => {n : _} -> Matrix n n a identity = repeatDiag 1 0 +export +scaling : Num a => {n : _} -> a -> Matrix n n a +scaling x = repeatDiag x 0 + +export +rotation2D : Double -> Matrix' 2 Double +rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]] -------------------------------------------------------------------------------- @@ -47,7 +76,29 @@ export index : Fin m -> Fin n -> Matrix m n a -> a index m n = index [m,n] +export +indexMaybe : Nat -> Nat -> Matrix m n a -> Maybe a +indexMaybe m n = indexMaybe [m,n] + + +export +getRow : Fin m -> Matrix m n a -> Vector n a +getRow r mat = rewrite sym (minusZeroRight n) in indexRange [One r, All] mat + +export +getColumn : Fin n -> Matrix m n a -> Vector m a +getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat + -------------------------------------------------------------------------------- -- Operations -------------------------------------------------------------------------------- + +export +vstack : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a +vstack = stack 0 + +export +hstack : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a +hstack = stack 1 + diff --git a/src/Data/NumIdr/Multiply.idr b/src/Data/NumIdr/Multiply.idr new file mode 100644 index 0000000..fb828b9 --- /dev/null +++ b/src/Data/NumIdr/Multiply.idr @@ -0,0 +1,13 @@ +module Data.NumIdr.Multiply + +%default total + + +infixr 7 *. + +||| A generalized multiplication/transformation operator. This interface is +||| necessary since the standard multiplication operator is homogenous. +public export +interface Mult a b where + 0 Result : Type + (*.) : a -> b -> Result diff --git a/src/Data/NumIdr/Scalar.idr b/src/Data/NumIdr/Scalar.idr index 8d0e1ac..64db639 100644 --- a/src/Data/NumIdr/Scalar.idr +++ b/src/Data/NumIdr/Scalar.idr @@ -1,6 +1,7 @@ module Data.NumIdr.Scalar import Data.Vect +import Data.NumIdr.Multiply import Data.NumIdr.PrimArray import public Data.NumIdr.Array @@ -21,3 +22,9 @@ scalar x = fromVect _ [x] export unwrap : Scalar a -> a unwrap = index 0 . getPrim + + +export +Num a => Mult (Scalar a) (Scalar a) where + Result = Scalar a + (*.) = (*) diff --git a/src/Data/NumIdr/Vector.idr b/src/Data/NumIdr/Vector.idr index dc012ab..64ec077 100644 --- a/src/Data/NumIdr/Vector.idr +++ b/src/Data/NumIdr/Vector.idr @@ -1,6 +1,7 @@ module Data.NumIdr.Vector import Data.Vect +import Data.NumIdr.Multiply import public Data.NumIdr.Array %default total @@ -46,6 +47,18 @@ export index : Fin n -> Vector n a -> a index n = Array.index [n] +export +(!!) : Vector n a -> Fin n -> a +(!!) = flip index + +export +indexMaybe : Nat -> Vector n a -> Maybe a +indexMaybe n = Array.indexMaybe [n] + +export +(!?) : Vector n a -> Nat -> Maybe a +(!?) = flip indexMaybe + -- Named projections export @@ -80,6 +93,11 @@ swizzle p v = rewrite sym (lengthCorrect p) -------------------------------------------------------------------------------- +export +concat : Vector m a -> Vector n a -> Vector (m + n) a +concat = stack 0 + + export dot : Num a => Vector n a -> Vector n a -> a dot = sum .: zipWith (*)