Add vector & matrix modules
This commit is contained in:
parent
1e7660b1f2
commit
861f1e29f2
|
@ -1,6 +1,8 @@
|
||||||
# NumIdr
|
# NumIdr
|
||||||
|
|
||||||
NumIdr is a linear algebra and matrix math library for Idris 2. It features
|
NumIdr is a linear algebra and data manipulation library for Idris 2. It features
|
||||||
an efficient, type-safe array system, as well as utilities for working with
|
an efficient, type-safe array data structure, as well as utilities for working with
|
||||||
vector spaces. It is heavily inspired by Python's [NumPy](https://numpy.org/),
|
vector spaces and matrices. It borrows concepts heavily from Python's [NumPy](https://numpy.org/),
|
||||||
as well as Rust's [nalgebra](https://www.nalgebra.org/).
|
as well as Rust's [nalgebra](https://www.nalgebra.org/).
|
||||||
|
|
||||||
|
The name is pronounced like "num-idge".
|
||||||
|
|
5
src/Data/NumIdr/Array.idr
Normal file
5
src/Data/NumIdr/Array.idr
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
module Data.NumIdr.Array
|
||||||
|
|
||||||
|
import public Data.NumIdr.Array.Array
|
||||||
|
import public Data.NumIdr.Array.Coords
|
||||||
|
import public Data.NumIdr.Array.Order
|
|
@ -78,15 +78,10 @@ rank : Array s a -> Nat
|
||||||
rank = length . shape
|
rank = length . shape
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
-- Array constructors
|
|
||||||
--------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
shapeEq : (arr : Array s a) -> s = shape arr
|
shapeEq : (arr : Array s a) -> s = shape arr
|
||||||
shapeEq (MkArray _ _ _ _) = Refl
|
shapeEq (MkArray _ _ _ _) = Refl
|
||||||
|
|
||||||
|
|
||||||
-- Get a list of all coordinates
|
-- Get a list of all coordinates
|
||||||
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||||
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
||||||
|
@ -97,11 +92,35 @@ getAllCoords (Z :: s) = []
|
||||||
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||||
|
|
||||||
|
|
||||||
constant' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a
|
--------------------------------------------------------------------------------
|
||||||
constant' s ord x = MkArray ord (calcStrides ord s) s (create (product s) (const x))
|
-- Array constructors
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
constant : (s : Vect rk Nat) -> a -> Array s a
|
|
||||||
constant s = constant' s COrder
|
||| Create an array by repeating a single value.
|
||||||
|
|||
|
||||||
|
||| @ s The shape of the constructed array
|
||||||
|
||| @ ord The order of the constructed array
|
||||||
|
export
|
||||||
|
repeat' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a
|
||||||
|
repeat' s ord x = MkArray ord (calcStrides ord s) s (constant (product s) x)
|
||||||
|
|
||||||
|
||| Create an array by repeating a single value.
|
||||||
|
||| To specify the order of the array, use `repeat'`.
|
||||||
|
|||
|
||||||
|
||| @ s The shape of the constructed array
|
||||||
|
export
|
||||||
|
repeat : (s : Vect rk Nat) -> a -> Array s a
|
||||||
|
repeat s = repeat' s COrder
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
zeros : Num a => (s : Vect rk Nat) -> Array s a
|
||||||
|
zeros s = repeat s 0
|
||||||
|
|
||||||
|
export
|
||||||
|
ones : Num a => (s : Vect rk Nat) -> Array s a
|
||||||
|
ones s = repeat s 1
|
||||||
|
|
||||||
||| Create an array given a vector of its elements. The elements of the vector
|
||| Create an array given a vector of its elements. The elements of the vector
|
||||||
||| are arranged into the provided shape using the provided order.
|
||| are arranged into the provided shape using the provided order.
|
||||||
|
@ -113,15 +132,31 @@ fromVect' : (s : Vect rk Nat) -> (ord : Order) -> Vect (product s) a -> Array s
|
||||||
fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
|
fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
|
||||||
|
|
||||||
||| Create an array given a vector of its elements. The elements of the vector
|
||| Create an array given a vector of its elements. The elements of the vector
|
||||||
||| are assembled into the provided shape using row-major order (the last axis is the
|
||| are arranged into the provided shape using row-major order (the last axis is the
|
||||||
||| least significant).
|
||| least significant).
|
||||||
||| To specify the order of the array, see `fromVect'`.
|
||| To specify the order of the array, use `fromVect'`.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
export
|
export
|
||||||
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
||||||
fromVect s = fromVect' s COrder
|
fromVect s = fromVect' s COrder
|
||||||
|
|
||||||
|
||| Create an array by taking values from a stream.
|
||||||
|
|||
|
||||||
|
||| @ s The shape of the constructed array
|
||||||
|
||| @ ord The order to interpret the elements
|
||||||
|
export
|
||||||
|
fromStream' : (s : Vect rk Nat) -> (ord : Order) -> Stream a -> Array s a
|
||||||
|
fromStream' s ord st = MkArray ord (calcStrides ord s) s (fromList $ take (product s) st)
|
||||||
|
|
||||||
|
||| Create an array by taking values from a stream.
|
||||||
|
||| To specify the order of the array, use `fromStream'`.
|
||||||
|
|||
|
||||||
|
||| @ s The shape of the constructed array
|
||||||
|
export
|
||||||
|
fromStream : (s : Vect rk Nat) -> Stream a -> Array s a
|
||||||
|
fromStream s = fromStream' s COrder
|
||||||
|
|
||||||
||| Create an array given a function to generate its elements.
|
||| Create an array given a function to generate its elements.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
|
@ -292,7 +327,7 @@ Functor (Array s) where
|
||||||
|
|
||||||
export
|
export
|
||||||
{s : _} -> Applicative (Array s) where
|
{s : _} -> Applicative (Array s) where
|
||||||
pure = constant s
|
pure = repeat s
|
||||||
(<*>) = zipWith apply
|
(<*>) = zipWith apply
|
||||||
|
|
||||||
export
|
export
|
||||||
|
@ -328,7 +363,7 @@ Semigroup a => Semigroup (Array s a) where
|
||||||
|
|
||||||
export
|
export
|
||||||
{s : _} -> Monoid a => Monoid (Array s a) where
|
{s : _} -> Monoid a => Monoid (Array s a) where
|
||||||
neutral = constant s neutral
|
neutral = repeat s neutral
|
||||||
|
|
||||||
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
|
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
|
||||||
-- were moved into its own interface, this constraint could be removed.
|
-- were moved into its own interface, this constraint could be removed.
|
||||||
|
@ -338,7 +373,7 @@ export
|
||||||
(+) = zipWith (+)
|
(+) = zipWith (+)
|
||||||
(*) = zipWith (*)
|
(*) = zipWith (*)
|
||||||
|
|
||||||
fromInteger = constant s . fromInteger
|
fromInteger = repeat s . fromInteger
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
module Data.NumIdr.Array.Coords
|
module Data.NumIdr.Array.Coords
|
||||||
|
|
||||||
|
import Data.Either
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
@ -42,11 +43,6 @@ namespace Coords
|
||||||
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
||||||
|
|
||||||
|
|
||||||
index : Coords s -> Vects s a -> a
|
|
||||||
index [] x = x
|
|
||||||
index (i::is) v = index is $ index i v
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
|
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
|
||||||
getLocation' = sum .: zipWith (*)
|
getLocation' = sum .: zipWith (*)
|
||||||
|
@ -88,6 +84,10 @@ namespace CoordsRange
|
||||||
Left _ => newRank rs
|
Left _ => newRank rs
|
||||||
Right _ => S (newRank rs)
|
Right _ => S (newRank rs)
|
||||||
|
|
||||||
|
export
|
||||||
|
newDim : {d : _} -> (r : CRange d) -> Maybe Nat
|
||||||
|
newDim = map (uncurry $ flip minus) . eitherToMaybe . cRangeToBounds
|
||||||
|
|
||||||
||| Calculate the new shape given by a coordinate range.
|
||| Calculate the new shape given by a coordinate range.
|
||||||
export
|
export
|
||||||
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
|
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
module Data.NumIdr.Array.Utils
|
|
||||||
|
|
||||||
import Data.NumIdr.Array.Array
|
|
||||||
|
|
||||||
|
|
||||||
export
|
|
||||||
zeros : Num a => (s : Vect rk Nat) -> Array s a
|
|
||||||
zeros = constant 0
|
|
||||||
|
|
||||||
export
|
|
||||||
ones : Num a => (s : Vect rk Nat) -> Array s a
|
|
||||||
ones = constant 1
|
|
53
src/Data/NumIdr/Matrix.idr
Normal file
53
src/Data/NumIdr/Matrix.idr
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
module Data.NumIdr.Matrix
|
||||||
|
|
||||||
|
import Data.Vect
|
||||||
|
import public Data.NumIdr.Array
|
||||||
|
|
||||||
|
%default total
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
Matrix : Nat -> Nat -> Type -> Type
|
||||||
|
Matrix m n = Array [m,n]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Matrix constructors
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
matrix' : {m, n : _} -> Order -> Vect m (Vect n a) -> Matrix m n a
|
||||||
|
matrix' ord x = array' [m,n] ord x
|
||||||
|
|
||||||
|
export
|
||||||
|
matrix : {m, n : _} -> Vect m (Vect n a) -> Matrix m n a
|
||||||
|
matrix = matrix' COrder
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
|
||||||
|
repeatDiag d o = fromFunction [m,n]
|
||||||
|
(\[i,j] => if finToNat i == finToNat j then d else o)
|
||||||
|
|
||||||
|
export
|
||||||
|
identity : Num a => {n : _} -> Matrix n n a
|
||||||
|
identity = repeatDiag 1 0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Indexing
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
index : Fin m -> Fin n -> Matrix m n a -> a
|
||||||
|
index m n = index [m,n]
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Operations
|
||||||
|
--------------------------------------------------------------------------------
|
|
@ -30,6 +30,11 @@ arrayDataSet : Nat -> a -> ArrayData a -> IO ()
|
||||||
arrayDataSet n x arr = fromPrim $ prim__arraySet arr (cast n) x
|
arrayDataSet n x arr = fromPrim $ prim__arraySet arr (cast n) x
|
||||||
|
|
||||||
|
|
||||||
|
||| Construct an array with a constant value.
|
||||||
|
export
|
||||||
|
constant : Nat -> a -> PrimArray a
|
||||||
|
constant size x = MkPrimArray size $ unsafePerformIO $ newArrayData size x
|
||||||
|
|
||||||
||| Construct an array from a list of "instructions" to write a
|
||| Construct an array from a list of "instructions" to write a
|
||||||
||| value to a particular index.
|
||| value to a particular index.
|
||||||
export
|
export
|
||||||
|
|
23
src/Data/NumIdr/Scalar.idr
Normal file
23
src/Data/NumIdr/Scalar.idr
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
module Data.NumIdr.Scalar
|
||||||
|
|
||||||
|
import Data.Vect
|
||||||
|
import Data.NumIdr.PrimArray
|
||||||
|
import public Data.NumIdr.Array
|
||||||
|
|
||||||
|
%default total
|
||||||
|
|
||||||
|
||| Scalars are `Array []`, the unique 0-rank array type. They hold a single value.
|
||||||
|
||| Scalars are not particularly useful as container types, but they are
|
||||||
|
||| included here anyways.
|
||||||
|
public export
|
||||||
|
Scalar : Type -> Type
|
||||||
|
Scalar = Array []
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
scalar : a -> Scalar a
|
||||||
|
scalar x = fromVect _ [x]
|
||||||
|
|
||||||
|
export
|
||||||
|
unwrap : Scalar a -> a
|
||||||
|
unwrap = index 0 . getPrim
|
91
src/Data/NumIdr/Vector.idr
Normal file
91
src/Data/NumIdr/Vector.idr
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
module Data.NumIdr.Vector
|
||||||
|
|
||||||
|
import Data.Vect
|
||||||
|
import public Data.NumIdr.Array
|
||||||
|
|
||||||
|
%default total
|
||||||
|
|
||||||
|
public export
|
||||||
|
Vector : Nat -> Type -> Type
|
||||||
|
Vector n = Array [n]
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Vector constructors
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
vector : Vect n a -> Vector n a
|
||||||
|
vector v = rewrite sym (lengthCorrect v)
|
||||||
|
in fromVect [length v] $ -- the order doesn't matter here, as
|
||||||
|
rewrite lengthCorrect v in -- there is only 1 axis
|
||||||
|
rewrite multOneLeftNeutral n in v
|
||||||
|
|
||||||
|
export
|
||||||
|
basis : Num a => {n : _} -> (i : Fin n) -> Vector n a
|
||||||
|
basis i = fromFunction _ (\[j] => if i == j then 1 else 0)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
unit2D : (ang : Double) -> Vector 2 Double
|
||||||
|
unit2D ang = vector [cos ang, sin ang]
|
||||||
|
|
||||||
|
export
|
||||||
|
unit3D : (pol, az : Double) -> Vector 3 Double
|
||||||
|
unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Indexing
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
index : Fin n -> Vector n a -> a
|
||||||
|
index n = Array.index [n]
|
||||||
|
|
||||||
|
|
||||||
|
-- Named projections
|
||||||
|
export
|
||||||
|
(.x) : Vector (1 + n) a -> a
|
||||||
|
(.x) = index FZ
|
||||||
|
|
||||||
|
export
|
||||||
|
(.y) : Vector (2 + n) a -> a
|
||||||
|
(.y) = index (FS FZ)
|
||||||
|
|
||||||
|
export
|
||||||
|
(.z) : Vector (3 + n) a -> a
|
||||||
|
(.z) = index (FS (FS FZ))
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Swizzling
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
swizzle : Vect n (Fin m) -> Vector m a -> Vector n a
|
||||||
|
swizzle p v = rewrite sym (lengthCorrect p)
|
||||||
|
in fromFunction [length p] (\[i] =>
|
||||||
|
index (index (rewrite sym (lengthCorrect p) in i) p) v
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Vector operations
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
dot : Num a => Vector n a -> Vector n a -> a
|
||||||
|
dot = sum .: zipWith (*)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
perp : Neg a => Vector 2 a -> Vector 2 a -> a
|
||||||
|
perp a b = a.x * b.y - a.y * b.x
|
||||||
|
|
Loading…
Reference in a new issue