Add documentation

This commit is contained in:
Kiana Sheibani 2022-06-25 00:58:36 -04:00
parent 11d771b926
commit 59af31cdd7
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
10 changed files with 188 additions and 37 deletions

View file

@ -1,5 +1,6 @@
module Data.NumIdr.Array module Data.NumIdr.Array
import public Data.NP
import public Data.NumIdr.Array.Array import public Data.NumIdr.Array.Array
import public Data.NumIdr.Array.Coords import public Data.NumIdr.Array.Coords
import public Data.NumIdr.Array.Order import public Data.NumIdr.Array.Order

View file

@ -4,6 +4,7 @@ import Data.List
import Data.List1 import Data.List1
import Data.Vect import Data.Vect
import Data.Zippable import Data.Zippable
import Data.NP
import Data.NumIdr.Multiply import Data.NumIdr.Multiply
import Data.NumIdr.PrimArray import Data.NumIdr.PrimArray
import Data.NumIdr.Array.Order import Data.NumIdr.Array.Order
@ -63,6 +64,7 @@ strides : Array {rk} s a -> Vect rk Nat
strides (MkArray _ sts _ _) = sts strides (MkArray _ sts _ _) = sts
||| The total number of elements of the array ||| The total number of elements of the array
|||
||| This is equivalent to `product s`. ||| This is equivalent to `product s`.
export export
size : Array s a -> Nat size : Array s a -> Nat
@ -175,6 +177,26 @@ export
fromStream : (s : Vect rk Nat) -> Stream a -> Array s a fromStream : (s : Vect rk Nat) -> Stream a -> Array s a
fromStream s = fromStream' s COrder fromStream s = fromStream' s COrder
||| Create an array given a function to generate its elements.
|||
||| @ s The shape of the constructed array
||| @ ord The order to interpret the elements
export
fromFunctionNB' : (s : Vect rk Nat) -> (ord : Order) -> (Vect rk Nat -> a) -> Array s a
fromFunctionNB' s ord f = let sts = calcStrides ord s
in MkArray ord sts s (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, f is)) $ getAllCoords' s)
||| Create an array given a function to generate its elements.
||| To specify the order of the array, use `fromFunctionNB'`.
|||
||| @ s The shape of the constructed array
||| @ ord The order to interpret the elements
export
fromFunctionNB : (s : Vect rk Nat) -> (Vect rk Nat -> a) -> Array s a
fromFunctionNB s = fromFunctionNB' s COrder
||| Create an array given a function to generate its elements. ||| Create an array given a function to generate its elements.
||| |||
||| @ s The shape of the constructed array ||| @ s The shape of the constructed array
@ -219,32 +241,39 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
-- Indexing -- Indexing
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
infix 10 !! infixl 10 !!
infix 10 !? infixl 10 !?
infixl 11 !!.. infixl 11 !!..
infix 11 !?.. infixl 11 !?..
||| Index the array using the given `Coords` object. ||| Index the array using the given coordinates.
export export
index : Coords s -> Array s a -> a index : Coords s -> Array s a -> a
index is arr = index (getLocation (strides arr) is) (getPrim arr) index is arr = index (getLocation (strides arr) is) (getPrim arr)
||| Index the array using the given coordinates.
|||
||| This is the operator form of `index`.
export export
(!!) : Array s a -> Coords s -> a (!!) : Array s a -> Coords s -> a
(!!) = flip index (!!) = flip index
-- TODO: Create set/update at index functions -- TODO: Create set/update at index functions
||| Update the value at the given coordinates using the function.
export export
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
indexUpdate is f (MkArray ord sts s arr) = indexUpdate is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation sts is) f arr) MkArray ord sts s (updateAt (getLocation sts is) f arr)
||| Set the value at the given coordinates to the given value.
export export
indexSet : Coords s -> a -> Array s a -> Array s a indexSet : Coords s -> a -> Array s a -> Array s a
indexSet is = indexUpdate is . const indexSet is = indexUpdate is . const
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
export export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is arr with (viewShape arr) indexNB is arr with (viewShape arr)
@ -252,21 +281,29 @@ indexNB is arr with (viewShape arr)
then Just $ index (getLocation' (strides arr) is) (getPrim arr) then Just $ index (getLocation' (strides arr) is) (getPrim arr)
else Nothing else Nothing
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
|||
||| This is the operator form of `indexNB`.
export export
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
(!?) = flip indexNB (!?) = flip indexNB
||| Update the value at the given coordinates using the function. The array is
||| returned unchanged if the coordinate is out of bounds.
export export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
indexUpdateNB is f (MkArray ord sts s arr) = indexUpdateNB is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation' sts is) f arr) MkArray ord sts s (updateAt (getLocation' sts is) f arr)
||| Set the value at the given coordinates to the given value. The array is
||| returned unchanged if the coordinate is out of bounds.
export export
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
indexSetNB is = indexUpdateNB is . const indexSetNB is = indexUpdateNB is . const
||| Index the array using the given `CoordsRange` object. ||| Index the array using the given range of coordinates, returning a new array.
export export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
indexRange rs arr with (viewShape arr) indexRange rs arr with (viewShape arr)
@ -281,6 +318,9 @@ indexRange rs arr with (viewShape arr)
where s' : Vect ? Nat where s' : Vect ? Nat
s' = newShape rs s' = newShape rs
||| Index the array using the given range of coordinates, returning a new array.
|||
||| This is the operator form of `indexRange`.
export export
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a (!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
arr !!.. rs = indexRange rs arr arr !!.. rs = indexRange rs arr
@ -320,10 +360,20 @@ reorder ord' arr with (viewShape arr)
getAllCoords' s) getAllCoords' s)
||| Resize the array to a new shape, preserving the coordinates of the original
||| elements. New coordinates are filled with a default value.
|||
||| @ s' The shape to resize the array to
||| @ def The default value to fill the array with
export export
resize : (s' : Vect rk Nat) -> a -> Array {rk} s a -> Array s' a resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNB) resize s' def arr = fromFunction' s' (getOrder arr) (fromMaybe def . (arr !?) . toNB)
||| Resize the array to a new shape, preserving the coordinates of the original
||| elements. This function requires a proof that the new shape is strictly
||| smaller than the current shape of the array.
|||
||| @ s' The shape to resize the array to
export export
-- TODO: Come up with a solution that doesn't use `believe_me` or trip over some -- TODO: Come up with a solution that doesn't use `believe_me` or trip over some
-- weird bug in the type-checker -- weird bug in the type-checker
@ -332,17 +382,23 @@ resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) =>
resizeLTE s' arr = resize s' (believe_me ()) arr resizeLTE s' arr = resize s' (believe_me ()) arr
||| List all of the values in an array along with their coordinates.
export export
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a) enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
enumerateNB (MkArray _ sts sh p) = enumerateNB (MkArray _ sts sh p) =
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh) map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
||| List all of the values in an array along with their coordinates.
export export
enumerate : Array {rk} s a -> List (Coords s, a) enumerate : Array s a -> List (Coords s, a)
enumerate arr with (viewShape arr) enumerate arr with (viewShape arr)
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s) _ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
||| Join two arrays along a particular axis, e.g. combining two matrices
||| vertically or horizontally. The arrays must have the same shape on all other axes.
|||
||| @ axis The axis to join the arrays on
export export
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a -> concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
Array (updateAt axis (+d) s) a Array (updateAt axis (+d) s) a
@ -357,6 +413,9 @@ concat axis a b = let sA = shape a
-- TODO: prove that the type-level shape and `s` are equivalent -- TODO: prove that the type-level shape and `s` are equivalent
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins) in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix.
|||
||| @ axis The axis to stack the arrays along
export export
stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
stack axis arrs = rewrite sym (lengthCorrect arrs) in stack axis arrs = rewrite sym (lengthCorrect arrs) in
@ -415,7 +474,7 @@ export
export export
{s : _} -> Monad (Array s) where {s : _} -> Monad (Array s) where
join arr = fromFunction s (\is => index is $ index is arr) join arr = fromFunction s (\is => arr !! is !! is)
-- Foldable and Traversable operate on the primitive array directly. This means -- Foldable and Traversable operate on the primitive array directly. This means
@ -466,6 +525,7 @@ export
negate = map negate negate = map negate
(-) = zipWith (-) (-) = zipWith (-)
export
{s : _} -> Fractional a => Fractional (Array s a) where {s : _} -> Fractional a => Fractional (Array s a) where
recip = map recip recip = map recip
(/) = zipWith (/) (/) = zipWith (/)
@ -508,23 +568,30 @@ Show a => Show (Array s a) where
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| Linearly interpolate between two arrays.
export export
lerp : Neg a => a -> Array s a -> Array s a -> Array s a lerp : Neg a => a -> Array s a -> Array s a -> Array s a
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t) lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
||| Calculate the square of an array's Eulidean norm.
export export
normSq : Num a => Array s a -> a normSq : Num a => Array s a -> a
normSq arr = sum $ zipWith (*) arr arr normSq arr = sum $ zipWith (*) arr arr
||| Calculate an array's Eucliean norm.
export export
norm : Array s Double -> Double norm : Array s Double -> Double
norm = sqrt . normSq norm = sqrt . normSq
||| Normalize the array to a norm of 1.
|||
||| If the array contains all zeros, then it is returned unchanged.
export export
normalize : Array s Double -> Array s Double normalize : Array s Double -> Array s Double
normalize arr = if all (==0) arr then arr else map (/ norm arr) arr normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
||| Calculate the Lp-norm of an array.
export export
pnorm : (p : Double) -> Array s Double -> Double pnorm : (p : Double) -> Array s Double -> Double
pnorm p = (`pow` recip p) . sum . map (`pow` p) pnorm p = (`pow` recip p) . sum . map (`pow` p)

View file

@ -2,7 +2,7 @@ module Data.NumIdr.Array.Coords
import Data.Either import Data.Either
import Data.Vect import Data.Vect
import public Data.NP import Data.NP
%default total %default total

View file

@ -29,7 +29,6 @@ Eq Order where
FOrder == COrder = False FOrder == COrder = False
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
scanr _ q0 [] = [q0] scanr _ q0 [] = [q0]
scanr f q0 (x::xs) = f x (head qs) :: qs scanr f q0 (x::xs) = f x (head qs) :: qs

View file

@ -8,14 +8,33 @@ import Data.NumIdr.Matrix
%default total %default total
public export
HVector : Nat -> Type -> Type
HVector n = Vector (S n)
||| A homogeneous matrix type.
|||
||| Homogeneous matrices are matrices that encode an affine transformation, as
||| opposed to the more restriced linear transformations of normal matrices.
||| Their multiplication is identical to regular matrix multiplication, but the
||| extra dimension in each axis allows the homogeneous matrix to store a
||| translation vector that is applied during multiplication.
|||
||| ```hmatrix mat tr *. v == mat *. v + tr```
public export public export
HMatrix : Nat -> Nat -> Type -> Type HMatrix : Nat -> Nat -> Type -> Type
HMatrix m n = Matrix (S m) (S n) HMatrix m n = Matrix (S m) (S n)
||| A synonym for a square homogeneous matrix.
public export
HMatrix' : Nat -> Type -> Type
HMatrix' n = HMatrix n n
||| An `n`-dimensional homogeneous vector type.
|||
||| Homogeneous vectors are vectors intended to be multiplied with homogeneous
||| matrices (see `HMatrix`). They have an extra dimension compared to regular
||| vectors.
public export
HVector : Nat -> Type -> Type
HVector n = Vector (S n)
export export

View file

@ -8,10 +8,12 @@ import Data.NumIdr.Vector
%default total %default total
||| A matrix is a rank-2 array.
public export public export
Matrix : Nat -> Nat -> Type -> Type Matrix : Nat -> Nat -> Type -> Type
Matrix m n = Array [m,n] Matrix m n = Array [m,n]
||| A synonym for a square matrix with dimensions of length `n`.
public export public export
Matrix' : Nat -> Type -> Type Matrix' : Nat -> Type -> Type
Matrix' n = Matrix n n Matrix' n = Matrix n n
@ -22,26 +24,30 @@ Matrix' n = Matrix n n
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| Construct a matrix with the given order and elements.
export export
matrix' : {m, n : _} -> Order -> Vect m (Vect n a) -> Matrix m n a matrix' : {m, n : _} -> Order -> Vect m (Vect n a) -> Matrix m n a
matrix' ord x = array' [m,n] ord x matrix' ord x = array' [m,n] ord x
||| Construct a matrix with the given elements.
export export
matrix : {m, n : _} -> Vect m (Vect n a) -> Matrix m n a matrix : {m, n : _} -> Vect m (Vect n a) -> Matrix m n a
matrix = matrix' COrder matrix = matrix' COrder
||| Construct a matrix with a specific value along the diagonal.
|||
||| @ diag The value to repeat along the diagonal
||| @ other The value to repeat elsewhere
export export
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
repeatDiag d o = fromFunction [m,n] repeatDiag d o = fromFunctionNB [m,n]
(\[i,j] => if i `eq` j then d else o) (\[i,j] => if i == j then d else o)
where
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Bool
eq FZ FZ = True
eq (FS x) (FS y) = eq x y
eq FZ (FS _) = False
eq (FS _) FZ = False
||| Construct a matrix given its diagonal elements.
|||
||| @ diag The elements of the matrix's diagonal
||| @ other The value to repeat elsewhere
export export
fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j) fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
@ -53,15 +59,18 @@ fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
eq (FS _) FZ = Nothing eq (FS _) FZ = Nothing
||| The `n`-dimensional identity matrix.
export export
identity : Num a => {n : _} -> Matrix' n a identity : Num a => {n : _} -> Matrix' n a
identity = repeatDiag 1 0 identity = repeatDiag 1 0
||| Calculate the matrix that scales a vector by the given value.
export export
scaling : Num a => {n : _} -> a -> Matrix' n a scaling : Num a => {n : _} -> a -> Matrix' n a
scaling x = repeatDiag x 0 scaling x = repeatDiag x 0
||| Calculate the rotation matrix of an angle.
export export
rotation2D : Double -> Matrix' 2 Double rotation2D : Double -> Matrix' 2 Double
rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]] rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]]
@ -72,19 +81,24 @@ rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]]
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| Index the matrix at the given coordinates.
export export
index : Fin m -> Fin n -> Matrix m n a -> a index : Fin m -> Fin n -> Matrix m n a -> a
index m n = index [m,n] index m n = index [m,n]
||| Index the matrix at the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
export export
indexNB : Nat -> Nat -> Matrix m n a -> Maybe a indexNB : Nat -> Nat -> Matrix m n a -> Maybe a
indexNB m n = indexNB [m,n] indexNB m n = indexNB [m,n]
||| Return a row of the matrix as a vector.
export export
getRow : Fin m -> Matrix m n a -> Vector n a getRow : Fin m -> Matrix m n a -> Vector n a
getRow r mat = rewrite sym (minusZeroRight n) in indexRange [One r, All] mat getRow r mat = rewrite sym (minusZeroRight n) in indexRange [One r, All] mat
||| Return a column of the matrix as a vector.
export export
getColumn : Fin n -> Matrix m n a -> Vector m a getColumn : Fin n -> Matrix m n a -> Vector m a
getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
@ -94,19 +108,22 @@ getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
-- Operations -- Operations
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| Concatenate two matrices vertically.
export export
vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
vconcat = concat 0 vconcat = concat 0
||| Concatenate two matrices horizontally.
export export
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
hconcat = concat 1 hconcat = concat 1
||| Calculate the kronecker product of two vectors as a matrix.
export export
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
kronecker a b with (viewShape a, viewShape b) kronecker a b with (viewShape a, viewShape b)
_ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b) _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => a !! i * b !! j)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View file

@ -6,32 +6,49 @@ module Data.NumIdr.Multiply
infixr 9 *. infixr 9 *.
infixr 10 ^ infixr 10 ^
||| A generalized multiplication/transformation operator. This interface is ||| A generalized multiplication/application operator. This interface is
||| necessary since the standard multiplication operator is homogenous. ||| necessary since the standard multiplication operator is homogenous.
|||
||| All instances of this interface must collectively satisfy these axioms:
||| * If `(x *. y) *. z` is defined, then `x *. (y *. z)` is defined and equal.
||| * If `x *. (y *. z)` is defined, then `(x *. y) *. z` is defined and equal.
public export public export
interface Mult a b c | a,b where interface Mult a b c | a,b where
(*.) : a -> b -> c (*.) : a -> b -> c
||| An interface for monoids using the `*.` operator.
|||
||| An instance of this interface must satisfy:
||| * `x *. neutral == x`
||| * `neutral *. x == x`
public export public export
interface (Mult a a a) => MultNeutral a where interface Mult a a a => MultNeutral a where
neutral : a neutral : a
||| Multiplication forms a semigroup
public export public export
[MultSemigroup] Mult a a a => Semigroup a where [MultSemigroup] Mult a a a => Semigroup a where
(<+>) = (*.) (<+>) = (*.)
||| Multiplication with a neutral element forms a monoid
public export public export
[MultMonoid] MultNeutral a => Monoid a using MultSemigroup where [MultMonoid] MultNeutral a => Monoid a using MultSemigroup where
neutral = Multiply.neutral neutral = Multiply.neutral
||| Raise a multiplicative value (e.g. a matrix or a transformation) to a natural
||| number power.
public export public export
power : MultNeutral a => Nat -> a -> a power : MultNeutral a => Nat -> a -> a
power 0 _ = neutral power 0 _ = neutral
power 1 x = x power 1 x = x
power (S n@(S _)) x = x *. power n x power (S n@(S _)) x = x *. power n x
||| Raise a multiplicative value (e.g. a matrix or a transformation) to a natural
||| number power.
|||
||| This is the operator form of `power`.
public export public export
(^) : MultNeutral a => a -> Nat -> a (^) : MultNeutral a => a -> Nat -> a
(^) = flip power (^) = flip power

View file

@ -35,8 +35,8 @@ export
constant : Nat -> a -> PrimArray a constant : Nat -> a -> PrimArray a
constant size x = MkPrimArray size $ unsafePerformIO $ newArrayData size x constant size x = MkPrimArray size $ unsafePerformIO $ newArrayData size x
||| Construct an array from a list of "instructions" to write a ||| Construct an array from a list of "instructions" to write a value to a
||| value to a particular index. ||| particular index.
export export
unsafeFromIns : Nat -> List (Nat, a) -> PrimArray a unsafeFromIns : Nat -> List (Nat, a) -> PrimArray a
unsafeFromIns size ins = unsafePerformIO $ do unsafeFromIns size ins = unsafePerformIO $ do
@ -44,8 +44,8 @@ unsafeFromIns size ins = unsafePerformIO $ do
for_ ins $ \(i,x) => arrayDataSet i x arr for_ ins $ \(i,x) => arrayDataSet i x arr
pure $ MkPrimArray size arr pure $ MkPrimArray size arr
||| Create an array given its size and a function to generate ||| Create an array given its size and a function to generate its elements by
||| its elements by its index. ||| its index.
export export
create : Nat -> (Nat -> a) -> PrimArray a create : Nat -> (Nat -> a) -> PrimArray a
create size f = unsafePerformIO $ do create size f = unsafePerformIO $ do
@ -59,8 +59,8 @@ create size f = unsafePerformIO $ do
= do arrayDataSet loc (f loc) arr = do arrayDataSet loc (f loc) arr
addToArray (S loc) n arr addToArray (S loc) n arr
||| Index into a primitive array. This function is unsafe, as it ||| Index into a primitive array. This function is unsafe, as it performs no
||| performs no boundary check on the index given. ||| boundary check on the index given.
export export
index : Nat -> PrimArray a -> a index : Nat -> PrimArray a -> a
index n arr = unsafePerformIO $ arrayDataGet n $ content arr index n arr = unsafePerformIO $ arrayDataGet n $ content arr
@ -156,8 +156,8 @@ traverse : Applicative f => (a -> f b) -> PrimArray a -> f (PrimArray b)
traverse f = map fromList . traverse f . toList traverse f = map fromList . traverse f . toList
||| Compares two primitive arrays for equal elements. This function assumes ||| Compares two primitive arrays for equal elements. This function assumes the
||| the arrays same the same length; it must not be used in any other case. ||| arrays have the same length; it must not be used in any other case.
export export
unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool
unsafeEq a b = unsafePerformIO $ unsafeEq a b = unsafePerformIO $

View file

@ -15,10 +15,12 @@ Scalar : Type -> Type
Scalar = Array [] Scalar = Array []
||| Convert a value to a scalar.
export export
scalar : a -> Scalar a scalar : a -> Scalar a
scalar x = fromVect _ [x] scalar x = fromVect _ [x]
||| Unwrap the single value from a scalar.
export export
unwrap : Scalar a -> a unwrap : Scalar a -> a
unwrap = index 0 . getPrim unwrap = index 0 . getPrim

View file

@ -6,11 +6,14 @@ import public Data.NumIdr.Array
%default total %default total
||| A vector is a rank-1 array.
public export public export
Vector : Nat -> Type -> Type Vector : Nat -> Type -> Type
Vector n = Array [n] Vector n = Array [n]
||| The length (number of dimensions) of the vector
public export public export
dim : Vector n a -> Nat dim : Vector n a -> Nat
dim = head . shape dim = head . shape
@ -25,6 +28,7 @@ dimEq v = cong head $ shapeEq v
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| Construct a vector from a `Vect`.
export export
vector : Vect n a -> Vector n a vector : Vect n a -> Vector n a
vector v = rewrite sym (lengthCorrect v) vector v = rewrite sym (lengthCorrect v)
@ -32,20 +36,28 @@ vector v = rewrite sym (lengthCorrect v)
rewrite lengthCorrect v in -- there is only 1 axis rewrite lengthCorrect v in -- there is only 1 axis
rewrite multOneLeftNeutral n in v rewrite multOneLeftNeutral n in v
||| Convert a vector into a `Vect`.
export export
toVect : Vector n a -> Vect n a toVect : Vector n a -> Vect n a
toVect v = believe_me $ Vect.fromList $ toList v toVect v = believe_me $ Vect.fromList $ toList v
||| Return the `i`-th basis vector.
export export
basis : Num a => {n : _} -> (i : Fin n) -> Vector n a basis : Num a => {n : _} -> (i : Fin n) -> Vector n a
basis i = fromFunction _ (\[j] => if i == j then 1 else 0) basis i = fromFunction _ (\[j] => if i == j then 1 else 0)
||| Calculate the 2D unit vector with the given angle off the x-axis.
export export
unit2D : (ang : Double) -> Vector 2 Double unit2D : (ang : Double) -> Vector 2 Double
unit2D ang = vector [cos ang, sin ang] unit2D ang = vector [cos ang, sin ang]
||| Calculate the 3D unit vector corresponding to the given spherical coordinates,
||| where the polar axis is the z-axis.
|||
||| @ pol The polar angle of the vector
||| @ az The azimuthal angle of the vector
export export
unit3D : (pol, az : Double) -> Vector 3 Double unit3D : (pol, az : Double) -> Vector 3 Double
unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol] unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol]
@ -56,35 +68,49 @@ unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol]
-- Indexing -- Indexing
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
infix 10 !! infixl 10 !!
infix 10 !? infixl 10 !?
||| Index the vector at the given coordinate.
export export
index : Fin n -> Vector n a -> a index : Fin n -> Vector n a -> a
index n = Array.index [n] index n = Array.index [n]
||| Index the vector at the given coordinate.
|||
||| This is the operator form of `index`.
export export
(!!) : Vector n a -> Fin n -> a (!!) : Vector n a -> Fin n -> a
(!!) = flip index (!!) = flip index
||| Index the vector at the given coordinate, returning `Nothing` if the
||| coordinate is out of bounds.
export export
indexNB : Nat -> Vector n a -> Maybe a indexNB : Nat -> Vector n a -> Maybe a
indexNB n = Array.indexNB [n] indexNB n = Array.indexNB [n]
||| Index the vector at the given coordinate, returning `Nothing` if the
||| coordinate is out of bounds.
|||
||| This is the operator form of `indexNB`.
export export
(!?) : Vector n a -> Nat -> Maybe a (!?) : Vector n a -> Nat -> Maybe a
(!?) = flip indexNB (!?) = flip indexNB
-- Named projections -- Named projections
||| Return the x-coordinate (the first value) of the vector.
export export
(.x) : Vector (1 + n) a -> a (.x) : Vector (1 + n) a -> a
(.x) = index FZ (.x) = index FZ
||| Return the y-coordinate (the second value) of the vector.
export export
(.y) : Vector (2 + n) a -> a (.y) : Vector (2 + n) a -> a
(.y) = index (FS FZ) (.y) = index (FS FZ)
||| Return the z-coordinate (the third value) of the vector.
export export
(.z) : Vector (3 + n) a -> a (.z) : Vector (3 + n) a -> a
(.z) = index (FS (FS FZ)) (.z) = index (FS (FS FZ))
@ -109,16 +135,19 @@ swizzle p v = rewrite sym (lengthCorrect p)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| Concatenate one vector with another.
export export
(++) : Vector m a -> Vector n a -> Vector (m + n) a (++) : Vector m a -> Vector n a -> Vector (m + n) a
(++) = concat 0 (++) = concat 0
||| Calculate the dot product of the two vectors.
export export
dot : Num a => Vector n a -> Vector n a -> a dot : Num a => Vector n a -> Vector n a -> a
dot = sum .: zipWith (*) dot = sum .: zipWith (*)
||| Calculate the perpendicular product of the two vectors.
export export
perp : Neg a => Vector 2 a -> Vector 2 a -> a perp : Neg a => Vector 2 a -> Vector 2 a -> a
perp a b = a.x * b.y - a.y * b.x perp a b = a.x * b.y - a.y * b.x