From a0d9c766c098a2e9f9975bbcd19f318c0d7cc17b Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Thu, 23 Jun 2022 19:09:10 -0400 Subject: [PATCH] Refactor Mult and add MultNeutral --- src/Data/NumIdr/Array/Array.idr | 14 ++++++-------- src/Data/NumIdr/Matrix.idr | 4 ++-- src/Data/NumIdr/Multiply.idr | 30 +++++++++++++++++++++++++++--- src/Data/NumIdr/Vector.idr | 2 ++ 4 files changed, 37 insertions(+), 13 deletions(-) diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index cec9f81..47b6bce 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -11,10 +11,6 @@ import Data.NumIdr.Array.Coords %default total -infix 2 !! -infix 2 !? -infixl 3 !!.. -infix 3 !?.. ||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional ||| grid of values, where `n` is a value known as the *rank* of the array. Arrays @@ -223,6 +219,10 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v) -- Indexing -------------------------------------------------------------------------------- +infix 10 !! +infix 10 !? +infixl 11 !!.. +infix 11 !?.. ||| Index the array using the given `Coords` object. export @@ -462,13 +462,11 @@ export export -Num a => Mult a (Array {rk} s a) where - Result = Array {rk} s a +Num a => Mult a (Array {rk} s a) (Array s a) where (*.) x = map (*x) export -Num a => Mult (Array {rk} s a) a where - Result = Array {rk} s a +Num a => Mult (Array {rk} s a) a (Array s a) where (*.) = flip (*.) diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index c9f0d67..b4673cd 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -54,12 +54,12 @@ fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j) export -identity : Num a => {n : _} -> Matrix n n a +identity : Num a => {n : _} -> Matrix' n a identity = repeatDiag 1 0 export -scaling : Num a => {n : _} -> a -> Matrix n n a +scaling : Num a => {n : _} -> a -> Matrix' n a scaling x = repeatDiag x 0 export diff --git a/src/Data/NumIdr/Multiply.idr b/src/Data/NumIdr/Multiply.idr index 63d3308..b61521f 100644 --- a/src/Data/NumIdr/Multiply.idr +++ b/src/Data/NumIdr/Multiply.idr @@ -4,10 +4,34 @@ module Data.NumIdr.Multiply infixr 9 *. +infixr 10 ^ ||| A generalized multiplication/transformation operator. This interface is ||| necessary since the standard multiplication operator is homogenous. public export -interface Mult a b where - 0 Result : Type - (*.) : a -> b -> Result +interface Mult a b c | a,b where + (*.) : a -> b -> c + +public export +interface (Mult a a a) => MultNeutral a where + neutral : a + + +public export +[MultSemigroup] Mult a a a => Semigroup a where + (<+>) = (*.) + +public export +[MultMonoid] MultNeutral a => Monoid a using MultSemigroup where + neutral = Multiply.neutral + + +public export +power : MultNeutral a => Nat -> a -> a +power 0 _ = neutral +power 1 x = x +power (S n@(S _)) x = x *. power n x + +public export +(^) : MultNeutral a => a -> Nat -> a +(^) = flip power diff --git a/src/Data/NumIdr/Vector.idr b/src/Data/NumIdr/Vector.idr index 7e35969..301bb9d 100644 --- a/src/Data/NumIdr/Vector.idr +++ b/src/Data/NumIdr/Vector.idr @@ -56,6 +56,8 @@ unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol] -- Indexing -------------------------------------------------------------------------------- +infix 10 !! +infix 10 !? export index : Fin n -> Vector n a -> a