Implement matrix multiplication
This commit is contained in:
parent
a0d9c766c0
commit
c71cd953a0
|
@ -237,7 +237,8 @@ export
|
||||||
|
|
||||||
export
|
export
|
||||||
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
||||||
indexUpdate is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation sts is) f arr)
|
indexUpdate is f (MkArray ord sts s arr) =
|
||||||
|
MkArray ord sts s (updateAt (getLocation sts is) f arr)
|
||||||
|
|
||||||
export
|
export
|
||||||
indexSet : Coords s -> a -> Array s a -> Array s a
|
indexSet : Coords s -> a -> Array s a -> Array s a
|
||||||
|
@ -247,7 +248,7 @@ indexSet is = indexUpdate is . const
|
||||||
export
|
export
|
||||||
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
|
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
|
||||||
indexNB is arr with (viewShape arr)
|
indexNB is arr with (viewShape arr)
|
||||||
_ | Shape s = if concat @{All} $ zipWith (<) s (shape arr)
|
_ | Shape s = if concat @{All} $ zipWith (<) is s
|
||||||
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
|
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
|
||||||
else Nothing
|
else Nothing
|
||||||
|
|
||||||
|
@ -257,7 +258,8 @@ export
|
||||||
|
|
||||||
export
|
export
|
||||||
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
|
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
|
||||||
indexUpdateNB is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation' sts is) f arr)
|
indexUpdateNB is f (MkArray ord sts s arr) =
|
||||||
|
MkArray ord sts s (updateAt (getLocation' sts is) f arr)
|
||||||
|
|
||||||
export
|
export
|
||||||
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
|
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
|
||||||
|
@ -323,8 +325,16 @@ resize : (s' : Vect rk Nat) -> a -> Array {rk} s a -> Array s' a
|
||||||
resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNB)
|
resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNB)
|
||||||
|
|
||||||
export
|
export
|
||||||
enumerate' : Array {rk} s a -> List (Vect rk Nat, a)
|
-- TODO: Come up with a solution that doesn't use `believe_me` or trip over some
|
||||||
enumerate' (MkArray _ sts sh p) =
|
-- weird bug in the type-checker
|
||||||
|
resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) =>
|
||||||
|
Array {rk} s a -> Array s' a
|
||||||
|
resizeLTE s' arr = resize s' (believe_me ()) arr
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
|
||||||
|
enumerateNB (MkArray _ sts sh p) =
|
||||||
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
|
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -107,3 +107,27 @@ export
|
||||||
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
|
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
|
||||||
kronecker a b with (viewShape a, viewShape b)
|
kronecker a b with (viewShape a, viewShape b)
|
||||||
_ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b)
|
_ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b)
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Matrix multiplication
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
Num a => Mult (Matrix m n a) (Vector n a) (Vector m a) where
|
||||||
|
mat *. v with (viewShape mat)
|
||||||
|
_ | Shape [m,n] = fromFunction [m]
|
||||||
|
(\[i] => foldMap @{%search} @{Additive}
|
||||||
|
(\j => mat !! [i,j] * v !! j) range)
|
||||||
|
|
||||||
|
export
|
||||||
|
Num a => Mult (Matrix m n a) (Matrix n p a) (Matrix m p a) where
|
||||||
|
m1 *. m2 with (viewShape m1, viewShape m2)
|
||||||
|
_ | (Shape [m,n], Shape [n,p]) = fromFunction [m,p]
|
||||||
|
(\[i,j] => foldMap @{%search} @{Additive}
|
||||||
|
(\k => m1 !! [i,k] * m2 !! [k,j]) range)
|
||||||
|
|
||||||
|
export
|
||||||
|
{n : _} -> Num a => MultNeutral (Matrix' n a) where
|
||||||
|
neutral = identity
|
||||||
|
|
Loading…
Reference in a new issue