Create Data.NumIdr.Multiply
This commit is contained in:
parent
97d1bdb538
commit
87d8814c38
|
@ -4,12 +4,17 @@ import Data.List
|
|||
import Data.List1
|
||||
import Data.Vect
|
||||
import Data.Zippable
|
||||
import Data.NumIdr.Multiply
|
||||
import Data.NumIdr.PrimArray
|
||||
import Data.NumIdr.Array.Order
|
||||
import Data.NumIdr.Array.Coords
|
||||
|
||||
%default total
|
||||
|
||||
infix 2 !!
|
||||
infix 2 !?
|
||||
infixl 3 !!..
|
||||
infix 3 !?..
|
||||
|
||||
||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional
|
||||
||| grid of values, where `n` is a value known as the *rank* of the array. Arrays
|
||||
|
@ -201,11 +206,6 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
|||
-- Indexing
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
infix 2 !!
|
||||
infix 2 !?
|
||||
infixl 3 !!..
|
||||
infix 3 !?..
|
||||
|
||||
|
||||
||| Index the array using the given `Coords` object.
|
||||
export
|
||||
|
@ -290,7 +290,7 @@ enumerate arr = map (\is => (is, index is arr))
|
|||
|
||||
export
|
||||
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||
Array (replaceAt axis (index axis s + d) s) a
|
||||
Array (updateAt axis (+d) s) a
|
||||
stack axis a b = let sA = shape a
|
||||
sB = shape b
|
||||
dA = index axis sA
|
||||
|
@ -374,6 +374,10 @@ Traversable (Array s) where
|
|||
map (MkArray ord sts s) (traverse f arr)
|
||||
|
||||
|
||||
export
|
||||
Cast a b => Cast (Array s a) (Array s b) where
|
||||
cast = map cast
|
||||
|
||||
export
|
||||
Eq a => Eq (Array s a) where
|
||||
a == b = if getOrder a == getOrder b
|
||||
|
@ -398,6 +402,27 @@ export
|
|||
|
||||
fromInteger = repeat s . fromInteger
|
||||
|
||||
export
|
||||
{s : _} -> Neg a => Neg (Array s a) where
|
||||
negate = map negate
|
||||
(-) = zipWith (-)
|
||||
|
||||
{s : _} -> Fractional a => Fractional (Array s a) where
|
||||
recip = map recip
|
||||
(/) = zipWith (/)
|
||||
|
||||
|
||||
export
|
||||
Num a => Mult a (Array {rk} s a) where
|
||||
Result = Array {rk} s a
|
||||
(*.) x = map (*x)
|
||||
|
||||
export
|
||||
Num a => Mult (Array {rk} s a) a where
|
||||
Result = Array {rk} s a
|
||||
(*.) = flip (*.)
|
||||
|
||||
|
||||
|
||||
export
|
||||
Show a => Show (Array s a) where
|
||||
|
|
|
@ -116,4 +116,4 @@ namespace CoordsRange
|
|||
go [] = pure []
|
||||
go (r :: rs) = [| (case cRangeToBounds r of
|
||||
Left x => pure x
|
||||
Right (x,y) => [x..pred y]) :: go rs |]
|
||||
Right (x,y) => [x,S x..pred y]) :: go rs |]
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
module Data.NumIdr.Matrix
|
||||
|
||||
import Data.Vect
|
||||
import Data.NumIdr.Multiply
|
||||
import public Data.NumIdr.Array
|
||||
import Data.NumIdr.Vector
|
||||
|
||||
%default total
|
||||
|
||||
|
@ -10,6 +12,9 @@ public export
|
|||
Matrix : Nat -> Nat -> Type -> Type
|
||||
Matrix m n = Array [m,n]
|
||||
|
||||
public export
|
||||
Matrix' : Nat -> Type -> Type
|
||||
Matrix' n = Matrix n n
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
@ -29,13 +34,37 @@ matrix = matrix' COrder
|
|||
export
|
||||
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
|
||||
repeatDiag d o = fromFunction [m,n]
|
||||
(\[i,j] => if finToNat i == finToNat j then d else o)
|
||||
(\[i,j] => if i `eq` j then d else o)
|
||||
where
|
||||
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Bool
|
||||
eq FZ FZ = True
|
||||
eq (FS x) (FS y) = eq x y
|
||||
eq FZ (FS _) = False
|
||||
eq (FS _) FZ = False
|
||||
|
||||
export
|
||||
fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
|
||||
fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
||||
where
|
||||
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Maybe (Fin (minimum m n))
|
||||
eq FZ FZ = Just FZ
|
||||
eq (FS x) (FS y) = map FS (eq x y)
|
||||
eq FZ (FS _) = Nothing
|
||||
eq (FS _) FZ = Nothing
|
||||
|
||||
|
||||
export
|
||||
identity : Num a => {n : _} -> Matrix n n a
|
||||
identity = repeatDiag 1 0
|
||||
|
||||
|
||||
export
|
||||
scaling : Num a => {n : _} -> a -> Matrix n n a
|
||||
scaling x = repeatDiag x 0
|
||||
|
||||
export
|
||||
rotation2D : Double -> Matrix' 2 Double
|
||||
rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]]
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
@ -47,7 +76,29 @@ export
|
|||
index : Fin m -> Fin n -> Matrix m n a -> a
|
||||
index m n = index [m,n]
|
||||
|
||||
export
|
||||
indexMaybe : Nat -> Nat -> Matrix m n a -> Maybe a
|
||||
indexMaybe m n = indexMaybe [m,n]
|
||||
|
||||
|
||||
export
|
||||
getRow : Fin m -> Matrix m n a -> Vector n a
|
||||
getRow r mat = rewrite sym (minusZeroRight n) in indexRange [One r, All] mat
|
||||
|
||||
export
|
||||
getColumn : Fin n -> Matrix m n a -> Vector m a
|
||||
getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Operations
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
export
|
||||
vstack : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
|
||||
vstack = stack 0
|
||||
|
||||
export
|
||||
hstack : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
||||
hstack = stack 1
|
||||
|
||||
|
|
13
src/Data/NumIdr/Multiply.idr
Normal file
13
src/Data/NumIdr/Multiply.idr
Normal file
|
@ -0,0 +1,13 @@
|
|||
module Data.NumIdr.Multiply
|
||||
|
||||
%default total
|
||||
|
||||
|
||||
infixr 7 *.
|
||||
|
||||
||| A generalized multiplication/transformation operator. This interface is
|
||||
||| necessary since the standard multiplication operator is homogenous.
|
||||
public export
|
||||
interface Mult a b where
|
||||
0 Result : Type
|
||||
(*.) : a -> b -> Result
|
|
@ -1,6 +1,7 @@
|
|||
module Data.NumIdr.Scalar
|
||||
|
||||
import Data.Vect
|
||||
import Data.NumIdr.Multiply
|
||||
import Data.NumIdr.PrimArray
|
||||
import public Data.NumIdr.Array
|
||||
|
||||
|
@ -21,3 +22,9 @@ scalar x = fromVect _ [x]
|
|||
export
|
||||
unwrap : Scalar a -> a
|
||||
unwrap = index 0 . getPrim
|
||||
|
||||
|
||||
export
|
||||
Num a => Mult (Scalar a) (Scalar a) where
|
||||
Result = Scalar a
|
||||
(*.) = (*)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
module Data.NumIdr.Vector
|
||||
|
||||
import Data.Vect
|
||||
import Data.NumIdr.Multiply
|
||||
import public Data.NumIdr.Array
|
||||
|
||||
%default total
|
||||
|
@ -46,6 +47,18 @@ export
|
|||
index : Fin n -> Vector n a -> a
|
||||
index n = Array.index [n]
|
||||
|
||||
export
|
||||
(!!) : Vector n a -> Fin n -> a
|
||||
(!!) = flip index
|
||||
|
||||
export
|
||||
indexMaybe : Nat -> Vector n a -> Maybe a
|
||||
indexMaybe n = Array.indexMaybe [n]
|
||||
|
||||
export
|
||||
(!?) : Vector n a -> Nat -> Maybe a
|
||||
(!?) = flip indexMaybe
|
||||
|
||||
|
||||
-- Named projections
|
||||
export
|
||||
|
@ -80,6 +93,11 @@ swizzle p v = rewrite sym (lengthCorrect p)
|
|||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
export
|
||||
concat : Vector m a -> Vector n a -> Vector (m + n) a
|
||||
concat = stack 0
|
||||
|
||||
|
||||
export
|
||||
dot : Num a => Vector n a -> Vector n a -> a
|
||||
dot = sum .: zipWith (*)
|
||||
|
|
Loading…
Reference in a new issue