From c71cd953a0d73f9b9d4fa5470e3727fa2bebe14a Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Thu, 23 Jun 2022 19:10:47 -0400 Subject: [PATCH] Implement matrix multiplication --- src/Data/NumIdr/Array/Array.idr | 20 +++++++++++++++----- src/Data/NumIdr/Matrix.idr | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 47b6bce..8fbca93 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -237,7 +237,8 @@ export export indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a -indexUpdate is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation sts is) f arr) +indexUpdate is f (MkArray ord sts s arr) = + MkArray ord sts s (updateAt (getLocation sts is) f arr) export indexSet : Coords s -> a -> Array s a -> Array s a @@ -247,7 +248,7 @@ indexSet is = indexUpdate is . const export indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a indexNB is arr with (viewShape arr) - _ | Shape s = if concat @{All} $ zipWith (<) s (shape arr) + _ | Shape s = if concat @{All} $ zipWith (<) is s then Just $ index (getLocation' (strides arr) is) (getPrim arr) else Nothing @@ -257,7 +258,8 @@ export export indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a -indexUpdateNB is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation' sts is) f arr) +indexUpdateNB is f (MkArray ord sts s arr) = + MkArray ord sts s (updateAt (getLocation' sts is) f arr) export indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a @@ -323,8 +325,16 @@ resize : (s' : Vect rk Nat) -> a -> Array {rk} s a -> Array s' a resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNB) export -enumerate' : Array {rk} s a -> List (Vect rk Nat, a) -enumerate' (MkArray _ sts sh p) = +-- TODO: Come up with a solution that doesn't use `believe_me` or trip over some +-- weird bug in the type-checker +resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) => + Array {rk} s a -> Array s' a +resizeLTE s' arr = resize s' (believe_me ()) arr + + +export +enumerateNB : Array {rk} s a -> List (Vect rk Nat, a) +enumerateNB (MkArray _ sts sh p) = map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh) diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index b4673cd..9e33de6 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -107,3 +107,27 @@ export kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a kronecker a b with (viewShape a, viewShape b) _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b) + + +-------------------------------------------------------------------------------- +-- Matrix multiplication +-------------------------------------------------------------------------------- + + +export +Num a => Mult (Matrix m n a) (Vector n a) (Vector m a) where + mat *. v with (viewShape mat) + _ | Shape [m,n] = fromFunction [m] + (\[i] => foldMap @{%search} @{Additive} + (\j => mat !! [i,j] * v !! j) range) + +export +Num a => Mult (Matrix m n a) (Matrix n p a) (Matrix m p a) where + m1 *. m2 with (viewShape m1, viewShape m2) + _ | (Shape [m,n], Shape [n,p]) = fromFunction [m,p] + (\[i,j] => foldMap @{%search} @{Additive} + (\k => m1 !! [i,k] * m2 !! [k,j]) range) + +export +{n : _} -> Num a => MultNeutral (Matrix' n a) where + neutral = identity