Implement matrix multiplication

This commit is contained in:
Kiana Sheibani 2022-06-23 19:10:47 -04:00
parent a0d9c766c0
commit c71cd953a0
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
2 changed files with 39 additions and 5 deletions

View file

@ -237,7 +237,8 @@ export
export export
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
indexUpdate is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation sts is) f arr) indexUpdate is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation sts is) f arr)
export export
indexSet : Coords s -> a -> Array s a -> Array s a indexSet : Coords s -> a -> Array s a -> Array s a
@ -247,7 +248,7 @@ indexSet is = indexUpdate is . const
export export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is arr with (viewShape arr) indexNB is arr with (viewShape arr)
_ | Shape s = if concat @{All} $ zipWith (<) s (shape arr) _ | Shape s = if concat @{All} $ zipWith (<) is s
then Just $ index (getLocation' (strides arr) is) (getPrim arr) then Just $ index (getLocation' (strides arr) is) (getPrim arr)
else Nothing else Nothing
@ -257,7 +258,8 @@ export
export export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
indexUpdateNB is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation' sts is) f arr) indexUpdateNB is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation' sts is) f arr)
export export
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
@ -323,8 +325,16 @@ resize : (s' : Vect rk Nat) -> a -> Array {rk} s a -> Array s' a
resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNB) resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNB)
export export
enumerate' : Array {rk} s a -> List (Vect rk Nat, a) -- TODO: Come up with a solution that doesn't use `believe_me` or trip over some
enumerate' (MkArray _ sts sh p) = -- weird bug in the type-checker
resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) =>
Array {rk} s a -> Array s' a
resizeLTE s' arr = resize s' (believe_me ()) arr
export
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
enumerateNB (MkArray _ sts sh p) =
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh) map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)

View file

@ -107,3 +107,27 @@ export
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
kronecker a b with (viewShape a, viewShape b) kronecker a b with (viewShape a, viewShape b)
_ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b) _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b)
--------------------------------------------------------------------------------
-- Matrix multiplication
--------------------------------------------------------------------------------
export
Num a => Mult (Matrix m n a) (Vector n a) (Vector m a) where
mat *. v with (viewShape mat)
_ | Shape [m,n] = fromFunction [m]
(\[i] => foldMap @{%search} @{Additive}
(\j => mat !! [i,j] * v !! j) range)
export
Num a => Mult (Matrix m n a) (Matrix n p a) (Matrix m p a) where
m1 *. m2 with (viewShape m1, viewShape m2)
_ | (Shape [m,n], Shape [n,p]) = fromFunction [m,p]
(\[i,j] => foldMap @{%search} @{Additive}
(\k => m1 !! [i,k] * m2 !! [k,j]) range)
export
{n : _} -> Num a => MultNeutral (Matrix' n a) where
neutral = identity