Create Data.NumIdr.Multiply
This commit is contained in:
parent
97d1bdb538
commit
87d8814c38
|
@ -4,12 +4,17 @@ import Data.List
|
||||||
import Data.List1
|
import Data.List1
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.Zippable
|
import Data.Zippable
|
||||||
|
import Data.NumIdr.Multiply
|
||||||
import Data.NumIdr.PrimArray
|
import Data.NumIdr.PrimArray
|
||||||
import Data.NumIdr.Array.Order
|
import Data.NumIdr.Array.Order
|
||||||
import Data.NumIdr.Array.Coords
|
import Data.NumIdr.Array.Coords
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
infix 2 !!
|
||||||
|
infix 2 !?
|
||||||
|
infixl 3 !!..
|
||||||
|
infix 3 !?..
|
||||||
|
|
||||||
||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional
|
||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional
|
||||||
||| grid of values, where `n` is a value known as the *rank* of the array. Arrays
|
||| grid of values, where `n` is a value known as the *rank* of the array. Arrays
|
||||||
|
@ -201,11 +206,6 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
||||||
-- Indexing
|
-- Indexing
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
infix 2 !!
|
|
||||||
infix 2 !?
|
|
||||||
infixl 3 !!..
|
|
||||||
infix 3 !?..
|
|
||||||
|
|
||||||
|
|
||||||
||| Index the array using the given `Coords` object.
|
||| Index the array using the given `Coords` object.
|
||||||
export
|
export
|
||||||
|
@ -290,7 +290,7 @@ enumerate arr = map (\is => (is, index is arr))
|
||||||
|
|
||||||
export
|
export
|
||||||
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||||
Array (replaceAt axis (index axis s + d) s) a
|
Array (updateAt axis (+d) s) a
|
||||||
stack axis a b = let sA = shape a
|
stack axis a b = let sA = shape a
|
||||||
sB = shape b
|
sB = shape b
|
||||||
dA = index axis sA
|
dA = index axis sA
|
||||||
|
@ -374,6 +374,10 @@ Traversable (Array s) where
|
||||||
map (MkArray ord sts s) (traverse f arr)
|
map (MkArray ord sts s) (traverse f arr)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
Cast a b => Cast (Array s a) (Array s b) where
|
||||||
|
cast = map cast
|
||||||
|
|
||||||
export
|
export
|
||||||
Eq a => Eq (Array s a) where
|
Eq a => Eq (Array s a) where
|
||||||
a == b = if getOrder a == getOrder b
|
a == b = if getOrder a == getOrder b
|
||||||
|
@ -398,6 +402,27 @@ export
|
||||||
|
|
||||||
fromInteger = repeat s . fromInteger
|
fromInteger = repeat s . fromInteger
|
||||||
|
|
||||||
|
export
|
||||||
|
{s : _} -> Neg a => Neg (Array s a) where
|
||||||
|
negate = map negate
|
||||||
|
(-) = zipWith (-)
|
||||||
|
|
||||||
|
{s : _} -> Fractional a => Fractional (Array s a) where
|
||||||
|
recip = map recip
|
||||||
|
(/) = zipWith (/)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
Num a => Mult a (Array {rk} s a) where
|
||||||
|
Result = Array {rk} s a
|
||||||
|
(*.) x = map (*x)
|
||||||
|
|
||||||
|
export
|
||||||
|
Num a => Mult (Array {rk} s a) a where
|
||||||
|
Result = Array {rk} s a
|
||||||
|
(*.) = flip (*.)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
Show a => Show (Array s a) where
|
Show a => Show (Array s a) where
|
||||||
|
|
|
@ -116,4 +116,4 @@ namespace CoordsRange
|
||||||
go [] = pure []
|
go [] = pure []
|
||||||
go (r :: rs) = [| (case cRangeToBounds r of
|
go (r :: rs) = [| (case cRangeToBounds r of
|
||||||
Left x => pure x
|
Left x => pure x
|
||||||
Right (x,y) => [x..pred y]) :: go rs |]
|
Right (x,y) => [x,S x..pred y]) :: go rs |]
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
module Data.NumIdr.Matrix
|
module Data.NumIdr.Matrix
|
||||||
|
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
|
import Data.NumIdr.Multiply
|
||||||
import public Data.NumIdr.Array
|
import public Data.NumIdr.Array
|
||||||
|
import Data.NumIdr.Vector
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
@ -10,6 +12,9 @@ public export
|
||||||
Matrix : Nat -> Nat -> Type -> Type
|
Matrix : Nat -> Nat -> Type -> Type
|
||||||
Matrix m n = Array [m,n]
|
Matrix m n = Array [m,n]
|
||||||
|
|
||||||
|
public export
|
||||||
|
Matrix' : Nat -> Type -> Type
|
||||||
|
Matrix' n = Matrix n n
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -29,13 +34,37 @@ matrix = matrix' COrder
|
||||||
export
|
export
|
||||||
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
|
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
|
||||||
repeatDiag d o = fromFunction [m,n]
|
repeatDiag d o = fromFunction [m,n]
|
||||||
(\[i,j] => if finToNat i == finToNat j then d else o)
|
(\[i,j] => if i `eq` j then d else o)
|
||||||
|
where
|
||||||
|
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Bool
|
||||||
|
eq FZ FZ = True
|
||||||
|
eq (FS x) (FS y) = eq x y
|
||||||
|
eq FZ (FS _) = False
|
||||||
|
eq (FS _) FZ = False
|
||||||
|
|
||||||
|
export
|
||||||
|
fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
|
||||||
|
fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
||||||
|
where
|
||||||
|
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Maybe (Fin (minimum m n))
|
||||||
|
eq FZ FZ = Just FZ
|
||||||
|
eq (FS x) (FS y) = map FS (eq x y)
|
||||||
|
eq FZ (FS _) = Nothing
|
||||||
|
eq (FS _) FZ = Nothing
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
identity : Num a => {n : _} -> Matrix n n a
|
identity : Num a => {n : _} -> Matrix n n a
|
||||||
identity = repeatDiag 1 0
|
identity = repeatDiag 1 0
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
scaling : Num a => {n : _} -> a -> Matrix n n a
|
||||||
|
scaling x = repeatDiag x 0
|
||||||
|
|
||||||
|
export
|
||||||
|
rotation2D : Double -> Matrix' 2 Double
|
||||||
|
rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]]
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -47,7 +76,29 @@ export
|
||||||
index : Fin m -> Fin n -> Matrix m n a -> a
|
index : Fin m -> Fin n -> Matrix m n a -> a
|
||||||
index m n = index [m,n]
|
index m n = index [m,n]
|
||||||
|
|
||||||
|
export
|
||||||
|
indexMaybe : Nat -> Nat -> Matrix m n a -> Maybe a
|
||||||
|
indexMaybe m n = indexMaybe [m,n]
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
getRow : Fin m -> Matrix m n a -> Vector n a
|
||||||
|
getRow r mat = rewrite sym (minusZeroRight n) in indexRange [One r, All] mat
|
||||||
|
|
||||||
|
export
|
||||||
|
getColumn : Fin n -> Matrix m n a -> Vector m a
|
||||||
|
getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Operations
|
-- Operations
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
export
|
||||||
|
vstack : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
|
||||||
|
vstack = stack 0
|
||||||
|
|
||||||
|
export
|
||||||
|
hstack : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
||||||
|
hstack = stack 1
|
||||||
|
|
||||||
|
|
13
src/Data/NumIdr/Multiply.idr
Normal file
13
src/Data/NumIdr/Multiply.idr
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
module Data.NumIdr.Multiply
|
||||||
|
|
||||||
|
%default total
|
||||||
|
|
||||||
|
|
||||||
|
infixr 7 *.
|
||||||
|
|
||||||
|
||| A generalized multiplication/transformation operator. This interface is
|
||||||
|
||| necessary since the standard multiplication operator is homogenous.
|
||||||
|
public export
|
||||||
|
interface Mult a b where
|
||||||
|
0 Result : Type
|
||||||
|
(*.) : a -> b -> Result
|
|
@ -1,6 +1,7 @@
|
||||||
module Data.NumIdr.Scalar
|
module Data.NumIdr.Scalar
|
||||||
|
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
|
import Data.NumIdr.Multiply
|
||||||
import Data.NumIdr.PrimArray
|
import Data.NumIdr.PrimArray
|
||||||
import public Data.NumIdr.Array
|
import public Data.NumIdr.Array
|
||||||
|
|
||||||
|
@ -21,3 +22,9 @@ scalar x = fromVect _ [x]
|
||||||
export
|
export
|
||||||
unwrap : Scalar a -> a
|
unwrap : Scalar a -> a
|
||||||
unwrap = index 0 . getPrim
|
unwrap = index 0 . getPrim
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
Num a => Mult (Scalar a) (Scalar a) where
|
||||||
|
Result = Scalar a
|
||||||
|
(*.) = (*)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
module Data.NumIdr.Vector
|
module Data.NumIdr.Vector
|
||||||
|
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
|
import Data.NumIdr.Multiply
|
||||||
import public Data.NumIdr.Array
|
import public Data.NumIdr.Array
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
@ -46,6 +47,18 @@ export
|
||||||
index : Fin n -> Vector n a -> a
|
index : Fin n -> Vector n a -> a
|
||||||
index n = Array.index [n]
|
index n = Array.index [n]
|
||||||
|
|
||||||
|
export
|
||||||
|
(!!) : Vector n a -> Fin n -> a
|
||||||
|
(!!) = flip index
|
||||||
|
|
||||||
|
export
|
||||||
|
indexMaybe : Nat -> Vector n a -> Maybe a
|
||||||
|
indexMaybe n = Array.indexMaybe [n]
|
||||||
|
|
||||||
|
export
|
||||||
|
(!?) : Vector n a -> Nat -> Maybe a
|
||||||
|
(!?) = flip indexMaybe
|
||||||
|
|
||||||
|
|
||||||
-- Named projections
|
-- Named projections
|
||||||
export
|
export
|
||||||
|
@ -80,6 +93,11 @@ swizzle p v = rewrite sym (lengthCorrect p)
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
concat : Vector m a -> Vector n a -> Vector (m + n) a
|
||||||
|
concat = stack 0
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
dot : Num a => Vector n a -> Vector n a -> a
|
dot : Num a => Vector n a -> Vector n a -> a
|
||||||
dot = sum .: zipWith (*)
|
dot = sum .: zipWith (*)
|
||||||
|
|
Loading…
Reference in a new issue