From 861f1e29f2ade8a762430277b5bcd2625d9202d3 Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Sat, 21 May 2022 16:38:21 -0400 Subject: [PATCH] Add vector & matrix modules --- README.md | 8 +-- src/Data/NumIdr/Array.idr | 5 ++ src/Data/NumIdr/Array/Array.idr | 63 +++++++++++++++++----- src/Data/NumIdr/Array/Coords.idr | 10 ++-- src/Data/NumIdr/Array/Utils.idr | 12 ----- src/Data/NumIdr/Matrix.idr | 53 +++++++++++++++++++ src/Data/NumIdr/PrimArray.idr | 5 ++ src/Data/NumIdr/Scalar.idr | 23 ++++++++ src/Data/NumIdr/Vector.idr | 91 ++++++++++++++++++++++++++++++++ 9 files changed, 236 insertions(+), 34 deletions(-) create mode 100644 src/Data/NumIdr/Array.idr delete mode 100644 src/Data/NumIdr/Array/Utils.idr create mode 100644 src/Data/NumIdr/Matrix.idr create mode 100644 src/Data/NumIdr/Scalar.idr create mode 100644 src/Data/NumIdr/Vector.idr diff --git a/README.md b/README.md index 8e5b28e..9b01bd1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # NumIdr -NumIdr is a linear algebra and matrix math library for Idris 2. It features -an efficient, type-safe array system, as well as utilities for working with -vector spaces. It is heavily inspired by Python's [NumPy](https://numpy.org/), +NumIdr is a linear algebra and data manipulation library for Idris 2. It features +an efficient, type-safe array data structure, as well as utilities for working with +vector spaces and matrices. It borrows concepts heavily from Python's [NumPy](https://numpy.org/), as well as Rust's [nalgebra](https://www.nalgebra.org/). + +The name is pronounced like "num-idge". diff --git a/src/Data/NumIdr/Array.idr b/src/Data/NumIdr/Array.idr new file mode 100644 index 0000000..fecd96b --- /dev/null +++ b/src/Data/NumIdr/Array.idr @@ -0,0 +1,5 @@ +module Data.NumIdr.Array + +import public Data.NumIdr.Array.Array +import public Data.NumIdr.Array.Coords +import public Data.NumIdr.Array.Order diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 361c856..2ea76c7 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -78,15 +78,10 @@ rank : Array s a -> Nat rank = length . shape --------------------------------------------------------------------------------- --- Array constructors --------------------------------------------------------------------------------- - shapeEq : (arr : Array s a) -> s = shape arr shapeEq (MkArray _ _ _ _) = Refl - -- Get a list of all coordinates getAllCoords' : Vect rk Nat -> List (Vect rk Nat) getAllCoords' = traverse (\case Z => []; S n => [0..n]) @@ -97,11 +92,35 @@ getAllCoords (Z :: s) = [] getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |] -constant' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a -constant' s ord x = MkArray ord (calcStrides ord s) s (create (product s) (const x)) +-------------------------------------------------------------------------------- +-- Array constructors +-------------------------------------------------------------------------------- -constant : (s : Vect rk Nat) -> a -> Array s a -constant s = constant' s COrder + +||| Create an array by repeating a single value. +||| +||| @ s The shape of the constructed array +||| @ ord The order of the constructed array +export +repeat' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a +repeat' s ord x = MkArray ord (calcStrides ord s) s (constant (product s) x) + +||| Create an array by repeating a single value. +||| To specify the order of the array, use `repeat'`. +||| +||| @ s The shape of the constructed array +export +repeat : (s : Vect rk Nat) -> a -> Array s a +repeat s = repeat' s COrder + + +export +zeros : Num a => (s : Vect rk Nat) -> Array s a +zeros s = repeat s 0 + +export +ones : Num a => (s : Vect rk Nat) -> Array s a +ones s = repeat s 1 ||| Create an array given a vector of its elements. The elements of the vector ||| are arranged into the provided shape using the provided order. @@ -113,15 +132,31 @@ fromVect' : (s : Vect rk Nat) -> (ord : Order) -> Vect (product s) a -> Array s fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v) ||| Create an array given a vector of its elements. The elements of the vector -||| are assembled into the provided shape using row-major order (the last axis is the +||| are arranged into the provided shape using row-major order (the last axis is the ||| least significant). -||| To specify the order of the array, see `fromVect'`. +||| To specify the order of the array, use `fromVect'`. ||| ||| @ s The shape of the constructed array export fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a fromVect s = fromVect' s COrder +||| Create an array by taking values from a stream. +||| +||| @ s The shape of the constructed array +||| @ ord The order to interpret the elements +export +fromStream' : (s : Vect rk Nat) -> (ord : Order) -> Stream a -> Array s a +fromStream' s ord st = MkArray ord (calcStrides ord s) s (fromList $ take (product s) st) + +||| Create an array by taking values from a stream. +||| To specify the order of the array, use `fromStream'`. +||| +||| @ s The shape of the constructed array +export +fromStream : (s : Vect rk Nat) -> Stream a -> Array s a +fromStream s = fromStream' s COrder + ||| Create an array given a function to generate its elements. ||| ||| @ s The shape of the constructed array @@ -292,7 +327,7 @@ Functor (Array s) where export {s : _} -> Applicative (Array s) where - pure = constant s + pure = repeat s (<*>) = zipWith apply export @@ -328,7 +363,7 @@ Semigroup a => Semigroup (Array s a) where export {s : _} -> Monoid a => Monoid (Array s a) where - neutral = constant s neutral + neutral = repeat s neutral -- the shape must be known at runtime due to `fromInteger`. If `fromInteger` -- were moved into its own interface, this constraint could be removed. @@ -338,7 +373,7 @@ export (+) = zipWith (+) (*) = zipWith (*) - fromInteger = constant s . fromInteger + fromInteger = repeat s . fromInteger export diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr index 2882269..27258a4 100644 --- a/src/Data/NumIdr/Array/Coords.idr +++ b/src/Data/NumIdr/Array/Coords.idr @@ -1,5 +1,6 @@ module Data.NumIdr.Array.Coords +import Data.Either import Data.Vect %default total @@ -42,11 +43,6 @@ namespace Coords mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs - index : Coords s -> Vects s a -> a - index [] x = x - index (i::is) v = index is $ index i v - - export getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat getLocation' = sum .: zipWith (*) @@ -88,6 +84,10 @@ namespace CoordsRange Left _ => newRank rs Right _ => S (newRank rs) + export + newDim : {d : _} -> (r : CRange d) -> Maybe Nat + newDim = map (uncurry $ flip minus) . eitherToMaybe . cRangeToBounds + ||| Calculate the new shape given by a coordinate range. export newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat diff --git a/src/Data/NumIdr/Array/Utils.idr b/src/Data/NumIdr/Array/Utils.idr deleted file mode 100644 index 5364a5a..0000000 --- a/src/Data/NumIdr/Array/Utils.idr +++ /dev/null @@ -1,12 +0,0 @@ -module Data.NumIdr.Array.Utils - -import Data.NumIdr.Array.Array - - -export -zeros : Num a => (s : Vect rk Nat) -> Array s a -zeros = constant 0 - -export -ones : Num a => (s : Vect rk Nat) -> Array s a -ones = constant 1 diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr new file mode 100644 index 0000000..269631a --- /dev/null +++ b/src/Data/NumIdr/Matrix.idr @@ -0,0 +1,53 @@ +module Data.NumIdr.Matrix + +import Data.Vect +import public Data.NumIdr.Array + +%default total + + +public export +Matrix : Nat -> Nat -> Type -> Type +Matrix m n = Array [m,n] + + + +-------------------------------------------------------------------------------- +-- Matrix constructors +-------------------------------------------------------------------------------- + + +export +matrix' : {m, n : _} -> Order -> Vect m (Vect n a) -> Matrix m n a +matrix' ord x = array' [m,n] ord x + +export +matrix : {m, n : _} -> Vect m (Vect n a) -> Matrix m n a +matrix = matrix' COrder + + +export +repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a +repeatDiag d o = fromFunction [m,n] + (\[i,j] => if finToNat i == finToNat j then d else o) + +export +identity : Num a => {n : _} -> Matrix n n a +identity = repeatDiag 1 0 + + + + +-------------------------------------------------------------------------------- +-- Indexing +-------------------------------------------------------------------------------- + + +export +index : Fin m -> Fin n -> Matrix m n a -> a +index m n = index [m,n] + + +-------------------------------------------------------------------------------- +-- Operations +-------------------------------------------------------------------------------- diff --git a/src/Data/NumIdr/PrimArray.idr b/src/Data/NumIdr/PrimArray.idr index a033aff..c7247ee 100644 --- a/src/Data/NumIdr/PrimArray.idr +++ b/src/Data/NumIdr/PrimArray.idr @@ -30,6 +30,11 @@ arrayDataSet : Nat -> a -> ArrayData a -> IO () arrayDataSet n x arr = fromPrim $ prim__arraySet arr (cast n) x +||| Construct an array with a constant value. +export +constant : Nat -> a -> PrimArray a +constant size x = MkPrimArray size $ unsafePerformIO $ newArrayData size x + ||| Construct an array from a list of "instructions" to write a ||| value to a particular index. export diff --git a/src/Data/NumIdr/Scalar.idr b/src/Data/NumIdr/Scalar.idr new file mode 100644 index 0000000..8d0e1ac --- /dev/null +++ b/src/Data/NumIdr/Scalar.idr @@ -0,0 +1,23 @@ +module Data.NumIdr.Scalar + +import Data.Vect +import Data.NumIdr.PrimArray +import public Data.NumIdr.Array + +%default total + +||| Scalars are `Array []`, the unique 0-rank array type. They hold a single value. +||| Scalars are not particularly useful as container types, but they are +||| included here anyways. +public export +Scalar : Type -> Type +Scalar = Array [] + + +export +scalar : a -> Scalar a +scalar x = fromVect _ [x] + +export +unwrap : Scalar a -> a +unwrap = index 0 . getPrim diff --git a/src/Data/NumIdr/Vector.idr b/src/Data/NumIdr/Vector.idr new file mode 100644 index 0000000..dc012ab --- /dev/null +++ b/src/Data/NumIdr/Vector.idr @@ -0,0 +1,91 @@ +module Data.NumIdr.Vector + +import Data.Vect +import public Data.NumIdr.Array + +%default total + +public export +Vector : Nat -> Type -> Type +Vector n = Array [n] + + +-------------------------------------------------------------------------------- +-- Vector constructors +-------------------------------------------------------------------------------- + + +export +vector : Vect n a -> Vector n a +vector v = rewrite sym (lengthCorrect v) + in fromVect [length v] $ -- the order doesn't matter here, as + rewrite lengthCorrect v in -- there is only 1 axis + rewrite multOneLeftNeutral n in v + +export +basis : Num a => {n : _} -> (i : Fin n) -> Vector n a +basis i = fromFunction _ (\[j] => if i == j then 1 else 0) + + +export +unit2D : (ang : Double) -> Vector 2 Double +unit2D ang = vector [cos ang, sin ang] + +export +unit3D : (pol, az : Double) -> Vector 3 Double +unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol] + + + +-------------------------------------------------------------------------------- +-- Indexing +-------------------------------------------------------------------------------- + + +export +index : Fin n -> Vector n a -> a +index n = Array.index [n] + + +-- Named projections +export +(.x) : Vector (1 + n) a -> a +(.x) = index FZ + +export +(.y) : Vector (2 + n) a -> a +(.y) = index (FS FZ) + +export +(.z) : Vector (3 + n) a -> a +(.z) = index (FS (FS FZ)) + + +-------------------------------------------------------------------------------- +-- Swizzling +-------------------------------------------------------------------------------- + + +export +swizzle : Vect n (Fin m) -> Vector m a -> Vector n a +swizzle p v = rewrite sym (lengthCorrect p) + in fromFunction [length p] (\[i] => + index (index (rewrite sym (lengthCorrect p) in i) p) v + ) + + + +-------------------------------------------------------------------------------- +-- Vector operations +-------------------------------------------------------------------------------- + + +export +dot : Num a => Vector n a -> Vector n a -> a +dot = sum .: zipWith (*) + + +export +perp : Neg a => Vector 2 a -> Vector 2 a -> a +perp a b = a.x * b.y - a.y * b.x +