Add documentation
This commit is contained in:
parent
11d771b926
commit
59af31cdd7
|
@ -1,5 +1,6 @@
|
||||||
module Data.NumIdr.Array
|
module Data.NumIdr.Array
|
||||||
|
|
||||||
|
import public Data.NP
|
||||||
import public Data.NumIdr.Array.Array
|
import public Data.NumIdr.Array.Array
|
||||||
import public Data.NumIdr.Array.Coords
|
import public Data.NumIdr.Array.Coords
|
||||||
import public Data.NumIdr.Array.Order
|
import public Data.NumIdr.Array.Order
|
||||||
|
|
|
@ -4,6 +4,7 @@ import Data.List
|
||||||
import Data.List1
|
import Data.List1
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.Zippable
|
import Data.Zippable
|
||||||
|
import Data.NP
|
||||||
import Data.NumIdr.Multiply
|
import Data.NumIdr.Multiply
|
||||||
import Data.NumIdr.PrimArray
|
import Data.NumIdr.PrimArray
|
||||||
import Data.NumIdr.Array.Order
|
import Data.NumIdr.Array.Order
|
||||||
|
@ -63,6 +64,7 @@ strides : Array {rk} s a -> Vect rk Nat
|
||||||
strides (MkArray _ sts _ _) = sts
|
strides (MkArray _ sts _ _) = sts
|
||||||
|
|
||||||
||| The total number of elements of the array
|
||| The total number of elements of the array
|
||||||
|
|||
|
||||||
||| This is equivalent to `product s`.
|
||| This is equivalent to `product s`.
|
||||||
export
|
export
|
||||||
size : Array s a -> Nat
|
size : Array s a -> Nat
|
||||||
|
@ -175,6 +177,26 @@ export
|
||||||
fromStream : (s : Vect rk Nat) -> Stream a -> Array s a
|
fromStream : (s : Vect rk Nat) -> Stream a -> Array s a
|
||||||
fromStream s = fromStream' s COrder
|
fromStream s = fromStream' s COrder
|
||||||
|
|
||||||
|
|
||||||
|
||| Create an array given a function to generate its elements.
|
||||||
|
|||
|
||||||
|
||| @ s The shape of the constructed array
|
||||||
|
||| @ ord The order to interpret the elements
|
||||||
|
export
|
||||||
|
fromFunctionNB' : (s : Vect rk Nat) -> (ord : Order) -> (Vect rk Nat -> a) -> Array s a
|
||||||
|
fromFunctionNB' s ord f = let sts = calcStrides ord s
|
||||||
|
in MkArray ord sts s (unsafeFromIns (product s) $
|
||||||
|
map (\is => (getLocation' sts is, f is)) $ getAllCoords' s)
|
||||||
|
|
||||||
|
||| Create an array given a function to generate its elements.
|
||||||
|
||| To specify the order of the array, use `fromFunctionNB'`.
|
||||||
|
|||
|
||||||
|
||| @ s The shape of the constructed array
|
||||||
|
||| @ ord The order to interpret the elements
|
||||||
|
export
|
||||||
|
fromFunctionNB : (s : Vect rk Nat) -> (Vect rk Nat -> a) -> Array s a
|
||||||
|
fromFunctionNB s = fromFunctionNB' s COrder
|
||||||
|
|
||||||
||| Create an array given a function to generate its elements.
|
||| Create an array given a function to generate its elements.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
|
@ -219,32 +241,39 @@ array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
||||||
-- Indexing
|
-- Indexing
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
infix 10 !!
|
infixl 10 !!
|
||||||
infix 10 !?
|
infixl 10 !?
|
||||||
infixl 11 !!..
|
infixl 11 !!..
|
||||||
infix 11 !?..
|
infixl 11 !?..
|
||||||
|
|
||||||
||| Index the array using the given `Coords` object.
|
||| Index the array using the given coordinates.
|
||||||
export
|
export
|
||||||
index : Coords s -> Array s a -> a
|
index : Coords s -> Array s a -> a
|
||||||
index is arr = index (getLocation (strides arr) is) (getPrim arr)
|
index is arr = index (getLocation (strides arr) is) (getPrim arr)
|
||||||
|
|
||||||
|
||| Index the array using the given coordinates.
|
||||||
|
|||
|
||||||
|
||| This is the operator form of `index`.
|
||||||
export
|
export
|
||||||
(!!) : Array s a -> Coords s -> a
|
(!!) : Array s a -> Coords s -> a
|
||||||
(!!) = flip index
|
(!!) = flip index
|
||||||
|
|
||||||
-- TODO: Create set/update at index functions
|
-- TODO: Create set/update at index functions
|
||||||
|
|
||||||
|
||| Update the value at the given coordinates using the function.
|
||||||
export
|
export
|
||||||
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
||||||
indexUpdate is f (MkArray ord sts s arr) =
|
indexUpdate is f (MkArray ord sts s arr) =
|
||||||
MkArray ord sts s (updateAt (getLocation sts is) f arr)
|
MkArray ord sts s (updateAt (getLocation sts is) f arr)
|
||||||
|
|
||||||
|
||| Set the value at the given coordinates to the given value.
|
||||||
export
|
export
|
||||||
indexSet : Coords s -> a -> Array s a -> Array s a
|
indexSet : Coords s -> a -> Array s a -> Array s a
|
||||||
indexSet is = indexUpdate is . const
|
indexSet is = indexUpdate is . const
|
||||||
|
|
||||||
|
|
||||||
|
||| Index the array using the given coordinates, returning `Nothing` if the
|
||||||
|
||| coordinates are out of bounds.
|
||||||
export
|
export
|
||||||
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
|
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
|
||||||
indexNB is arr with (viewShape arr)
|
indexNB is arr with (viewShape arr)
|
||||||
|
@ -252,21 +281,29 @@ indexNB is arr with (viewShape arr)
|
||||||
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
|
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
|
||||||
else Nothing
|
else Nothing
|
||||||
|
|
||||||
|
||| Index the array using the given coordinates, returning `Nothing` if the
|
||||||
|
||| coordinates are out of bounds.
|
||||||
|
|||
|
||||||
|
||| This is the operator form of `indexNB`.
|
||||||
export
|
export
|
||||||
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
|
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
|
||||||
(!?) = flip indexNB
|
(!?) = flip indexNB
|
||||||
|
|
||||||
|
||| Update the value at the given coordinates using the function. The array is
|
||||||
|
||| returned unchanged if the coordinate is out of bounds.
|
||||||
export
|
export
|
||||||
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
|
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
|
||||||
indexUpdateNB is f (MkArray ord sts s arr) =
|
indexUpdateNB is f (MkArray ord sts s arr) =
|
||||||
MkArray ord sts s (updateAt (getLocation' sts is) f arr)
|
MkArray ord sts s (updateAt (getLocation' sts is) f arr)
|
||||||
|
|
||||||
|
||| Set the value at the given coordinates to the given value. The array is
|
||||||
|
||| returned unchanged if the coordinate is out of bounds.
|
||||||
export
|
export
|
||||||
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
|
indexSetNB : Vect rk Nat -> a -> Array {rk} s a -> Array s a
|
||||||
indexSetNB is = indexUpdateNB is . const
|
indexSetNB is = indexUpdateNB is . const
|
||||||
|
|
||||||
|
|
||||||
||| Index the array using the given `CoordsRange` object.
|
||| Index the array using the given range of coordinates, returning a new array.
|
||||||
export
|
export
|
||||||
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
||||||
indexRange rs arr with (viewShape arr)
|
indexRange rs arr with (viewShape arr)
|
||||||
|
@ -281,6 +318,9 @@ indexRange rs arr with (viewShape arr)
|
||||||
where s' : Vect ? Nat
|
where s' : Vect ? Nat
|
||||||
s' = newShape rs
|
s' = newShape rs
|
||||||
|
|
||||||
|
||| Index the array using the given range of coordinates, returning a new array.
|
||||||
|
|||
|
||||||
|
||| This is the operator form of `indexRange`.
|
||||||
export
|
export
|
||||||
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
|
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
|
||||||
arr !!.. rs = indexRange rs arr
|
arr !!.. rs = indexRange rs arr
|
||||||
|
@ -320,10 +360,20 @@ reorder ord' arr with (viewShape arr)
|
||||||
getAllCoords' s)
|
getAllCoords' s)
|
||||||
|
|
||||||
|
|
||||||
|
||| Resize the array to a new shape, preserving the coordinates of the original
|
||||||
|
||| elements. New coordinates are filled with a default value.
|
||||||
|
|||
|
||||||
|
||| @ s' The shape to resize the array to
|
||||||
|
||| @ def The default value to fill the array with
|
||||||
export
|
export
|
||||||
resize : (s' : Vect rk Nat) -> a -> Array {rk} s a -> Array s' a
|
resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
|
||||||
resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNB)
|
resize s' def arr = fromFunction' s' (getOrder arr) (fromMaybe def . (arr !?) . toNB)
|
||||||
|
|
||||||
|
||| Resize the array to a new shape, preserving the coordinates of the original
|
||||||
|
||| elements. This function requires a proof that the new shape is strictly
|
||||||
|
||| smaller than the current shape of the array.
|
||||||
|
|||
|
||||||
|
||| @ s' The shape to resize the array to
|
||||||
export
|
export
|
||||||
-- TODO: Come up with a solution that doesn't use `believe_me` or trip over some
|
-- TODO: Come up with a solution that doesn't use `believe_me` or trip over some
|
||||||
-- weird bug in the type-checker
|
-- weird bug in the type-checker
|
||||||
|
@ -332,17 +382,23 @@ resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) =>
|
||||||
resizeLTE s' arr = resize s' (believe_me ()) arr
|
resizeLTE s' arr = resize s' (believe_me ()) arr
|
||||||
|
|
||||||
|
|
||||||
|
||| List all of the values in an array along with their coordinates.
|
||||||
export
|
export
|
||||||
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
|
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
|
||||||
enumerateNB (MkArray _ sts sh p) =
|
enumerateNB (MkArray _ sts sh p) =
|
||||||
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
|
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
|
||||||
|
|
||||||
|
||| List all of the values in an array along with their coordinates.
|
||||||
export
|
export
|
||||||
enumerate : Array {rk} s a -> List (Coords s, a)
|
enumerate : Array s a -> List (Coords s, a)
|
||||||
enumerate arr with (viewShape arr)
|
enumerate arr with (viewShape arr)
|
||||||
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
|
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
|
||||||
|
|
||||||
|
|
||||||
|
||| Join two arrays along a particular axis, e.g. combining two matrices
|
||||||
|
||| vertically or horizontally. The arrays must have the same shape on all other axes.
|
||||||
|
|||
|
||||||
|
||| @ axis The axis to join the arrays on
|
||||||
export
|
export
|
||||||
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||||
Array (updateAt axis (+d) s) a
|
Array (updateAt axis (+d) s) a
|
||||||
|
@ -357,6 +413,9 @@ concat axis a b = let sA = shape a
|
||||||
-- TODO: prove that the type-level shape and `s` are equivalent
|
-- TODO: prove that the type-level shape and `s` are equivalent
|
||||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||||
|
|
||||||
|
||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix.
|
||||||
|
|||
|
||||||
|
||| @ axis The axis to stack the arrays along
|
||||||
export
|
export
|
||||||
stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
|
stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
|
||||||
stack axis arrs = rewrite sym (lengthCorrect arrs) in
|
stack axis arrs = rewrite sym (lengthCorrect arrs) in
|
||||||
|
@ -415,7 +474,7 @@ export
|
||||||
|
|
||||||
export
|
export
|
||||||
{s : _} -> Monad (Array s) where
|
{s : _} -> Monad (Array s) where
|
||||||
join arr = fromFunction s (\is => index is $ index is arr)
|
join arr = fromFunction s (\is => arr !! is !! is)
|
||||||
|
|
||||||
|
|
||||||
-- Foldable and Traversable operate on the primitive array directly. This means
|
-- Foldable and Traversable operate on the primitive array directly. This means
|
||||||
|
@ -466,6 +525,7 @@ export
|
||||||
negate = map negate
|
negate = map negate
|
||||||
(-) = zipWith (-)
|
(-) = zipWith (-)
|
||||||
|
|
||||||
|
export
|
||||||
{s : _} -> Fractional a => Fractional (Array s a) where
|
{s : _} -> Fractional a => Fractional (Array s a) where
|
||||||
recip = map recip
|
recip = map recip
|
||||||
(/) = zipWith (/)
|
(/) = zipWith (/)
|
||||||
|
@ -508,23 +568,30 @@ Show a => Show (Array s a) where
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
||| Linearly interpolate between two arrays.
|
||||||
export
|
export
|
||||||
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
|
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
|
||||||
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
|
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
|
||||||
|
|
||||||
|
|
||||||
|
||| Calculate the square of an array's Eulidean norm.
|
||||||
export
|
export
|
||||||
normSq : Num a => Array s a -> a
|
normSq : Num a => Array s a -> a
|
||||||
normSq arr = sum $ zipWith (*) arr arr
|
normSq arr = sum $ zipWith (*) arr arr
|
||||||
|
|
||||||
|
||| Calculate an array's Eucliean norm.
|
||||||
export
|
export
|
||||||
norm : Array s Double -> Double
|
norm : Array s Double -> Double
|
||||||
norm = sqrt . normSq
|
norm = sqrt . normSq
|
||||||
|
|
||||||
|
||| Normalize the array to a norm of 1.
|
||||||
|
|||
|
||||||
|
||| If the array contains all zeros, then it is returned unchanged.
|
||||||
export
|
export
|
||||||
normalize : Array s Double -> Array s Double
|
normalize : Array s Double -> Array s Double
|
||||||
normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
|
normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
|
||||||
|
|
||||||
|
||| Calculate the Lp-norm of an array.
|
||||||
export
|
export
|
||||||
pnorm : (p : Double) -> Array s Double -> Double
|
pnorm : (p : Double) -> Array s Double -> Double
|
||||||
pnorm p = (`pow` recip p) . sum . map (`pow` p)
|
pnorm p = (`pow` recip p) . sum . map (`pow` p)
|
||||||
|
|
|
@ -2,7 +2,7 @@ module Data.NumIdr.Array.Coords
|
||||||
|
|
||||||
import Data.Either
|
import Data.Either
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import public Data.NP
|
import Data.NP
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,6 @@ Eq Order where
|
||||||
FOrder == COrder = False
|
FOrder == COrder = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
||||||
scanr _ q0 [] = [q0]
|
scanr _ q0 [] = [q0]
|
||||||
scanr f q0 (x::xs) = f x (head qs) :: qs
|
scanr f q0 (x::xs) = f x (head qs) :: qs
|
||||||
|
|
|
@ -8,14 +8,33 @@ import Data.NumIdr.Matrix
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
public export
|
|
||||||
HVector : Nat -> Type -> Type
|
|
||||||
HVector n = Vector (S n)
|
|
||||||
|
|
||||||
|
||| A homogeneous matrix type.
|
||||||
|
|||
|
||||||
|
||| Homogeneous matrices are matrices that encode an affine transformation, as
|
||||||
|
||| opposed to the more restriced linear transformations of normal matrices.
|
||||||
|
||| Their multiplication is identical to regular matrix multiplication, but the
|
||||||
|
||| extra dimension in each axis allows the homogeneous matrix to store a
|
||||||
|
||| translation vector that is applied during multiplication.
|
||||||
|
|||
|
||||||
|
||| ```hmatrix mat tr *. v == mat *. v + tr```
|
||||||
public export
|
public export
|
||||||
HMatrix : Nat -> Nat -> Type -> Type
|
HMatrix : Nat -> Nat -> Type -> Type
|
||||||
HMatrix m n = Matrix (S m) (S n)
|
HMatrix m n = Matrix (S m) (S n)
|
||||||
|
|
||||||
|
||| A synonym for a square homogeneous matrix.
|
||||||
|
public export
|
||||||
|
HMatrix' : Nat -> Type -> Type
|
||||||
|
HMatrix' n = HMatrix n n
|
||||||
|
|
||||||
|
||| An `n`-dimensional homogeneous vector type.
|
||||||
|
|||
|
||||||
|
||| Homogeneous vectors are vectors intended to be multiplied with homogeneous
|
||||||
|
||| matrices (see `HMatrix`). They have an extra dimension compared to regular
|
||||||
|
||| vectors.
|
||||||
|
public export
|
||||||
|
HVector : Nat -> Type -> Type
|
||||||
|
HVector n = Vector (S n)
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
|
|
|
@ -8,10 +8,12 @@ import Data.NumIdr.Vector
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
|
||| A matrix is a rank-2 array.
|
||||||
public export
|
public export
|
||||||
Matrix : Nat -> Nat -> Type -> Type
|
Matrix : Nat -> Nat -> Type -> Type
|
||||||
Matrix m n = Array [m,n]
|
Matrix m n = Array [m,n]
|
||||||
|
|
||||||
|
||| A synonym for a square matrix with dimensions of length `n`.
|
||||||
public export
|
public export
|
||||||
Matrix' : Nat -> Type -> Type
|
Matrix' : Nat -> Type -> Type
|
||||||
Matrix' n = Matrix n n
|
Matrix' n = Matrix n n
|
||||||
|
@ -22,26 +24,30 @@ Matrix' n = Matrix n n
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
||| Construct a matrix with the given order and elements.
|
||||||
export
|
export
|
||||||
matrix' : {m, n : _} -> Order -> Vect m (Vect n a) -> Matrix m n a
|
matrix' : {m, n : _} -> Order -> Vect m (Vect n a) -> Matrix m n a
|
||||||
matrix' ord x = array' [m,n] ord x
|
matrix' ord x = array' [m,n] ord x
|
||||||
|
|
||||||
|
||| Construct a matrix with the given elements.
|
||||||
export
|
export
|
||||||
matrix : {m, n : _} -> Vect m (Vect n a) -> Matrix m n a
|
matrix : {m, n : _} -> Vect m (Vect n a) -> Matrix m n a
|
||||||
matrix = matrix' COrder
|
matrix = matrix' COrder
|
||||||
|
|
||||||
|
|
||||||
|
||| Construct a matrix with a specific value along the diagonal.
|
||||||
|
|||
|
||||||
|
||| @ diag The value to repeat along the diagonal
|
||||||
|
||| @ other The value to repeat elsewhere
|
||||||
export
|
export
|
||||||
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
|
repeatDiag : {m, n : _} -> (diag, other : a) -> Matrix m n a
|
||||||
repeatDiag d o = fromFunction [m,n]
|
repeatDiag d o = fromFunctionNB [m,n]
|
||||||
(\[i,j] => if i `eq` j then d else o)
|
(\[i,j] => if i == j then d else o)
|
||||||
where
|
|
||||||
eq : {0 m,n : Nat} -> Fin m -> Fin n -> Bool
|
|
||||||
eq FZ FZ = True
|
|
||||||
eq (FS x) (FS y) = eq x y
|
|
||||||
eq FZ (FS _) = False
|
|
||||||
eq (FS _) FZ = False
|
|
||||||
|
|
||||||
|
||| Construct a matrix given its diagonal elements.
|
||||||
|
|||
|
||||||
|
||| @ diag The elements of the matrix's diagonal
|
||||||
|
||| @ other The value to repeat elsewhere
|
||||||
export
|
export
|
||||||
fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
|
fromDiag : {m, n : _} -> (diag : Vect (minimum m n) a) -> (other : a) -> Matrix m n a
|
||||||
fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
||||||
|
@ -53,15 +59,18 @@ fromDiag ds o = fromFunction [m,n] (\[i,j] => maybe o (`index` ds) $ i `eq` j)
|
||||||
eq (FS _) FZ = Nothing
|
eq (FS _) FZ = Nothing
|
||||||
|
|
||||||
|
|
||||||
|
||| The `n`-dimensional identity matrix.
|
||||||
export
|
export
|
||||||
identity : Num a => {n : _} -> Matrix' n a
|
identity : Num a => {n : _} -> Matrix' n a
|
||||||
identity = repeatDiag 1 0
|
identity = repeatDiag 1 0
|
||||||
|
|
||||||
|
|
||||||
|
||| Calculate the matrix that scales a vector by the given value.
|
||||||
export
|
export
|
||||||
scaling : Num a => {n : _} -> a -> Matrix' n a
|
scaling : Num a => {n : _} -> a -> Matrix' n a
|
||||||
scaling x = repeatDiag x 0
|
scaling x = repeatDiag x 0
|
||||||
|
|
||||||
|
||| Calculate the rotation matrix of an angle.
|
||||||
export
|
export
|
||||||
rotation2D : Double -> Matrix' 2 Double
|
rotation2D : Double -> Matrix' 2 Double
|
||||||
rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]]
|
rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]]
|
||||||
|
@ -72,19 +81,24 @@ rotation2D a = matrix [[cos a, - sin a], [sin a, cos a]]
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
||| Index the matrix at the given coordinates.
|
||||||
export
|
export
|
||||||
index : Fin m -> Fin n -> Matrix m n a -> a
|
index : Fin m -> Fin n -> Matrix m n a -> a
|
||||||
index m n = index [m,n]
|
index m n = index [m,n]
|
||||||
|
|
||||||
|
||| Index the matrix at the given coordinates, returning `Nothing` if the
|
||||||
|
||| coordinates are out of bounds.
|
||||||
export
|
export
|
||||||
indexNB : Nat -> Nat -> Matrix m n a -> Maybe a
|
indexNB : Nat -> Nat -> Matrix m n a -> Maybe a
|
||||||
indexNB m n = indexNB [m,n]
|
indexNB m n = indexNB [m,n]
|
||||||
|
|
||||||
|
|
||||||
|
||| Return a row of the matrix as a vector.
|
||||||
export
|
export
|
||||||
getRow : Fin m -> Matrix m n a -> Vector n a
|
getRow : Fin m -> Matrix m n a -> Vector n a
|
||||||
getRow r mat = rewrite sym (minusZeroRight n) in indexRange [One r, All] mat
|
getRow r mat = rewrite sym (minusZeroRight n) in indexRange [One r, All] mat
|
||||||
|
|
||||||
|
||| Return a column of the matrix as a vector.
|
||||||
export
|
export
|
||||||
getColumn : Fin n -> Matrix m n a -> Vector m a
|
getColumn : Fin n -> Matrix m n a -> Vector m a
|
||||||
getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
|
getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
|
||||||
|
@ -94,19 +108,22 @@ getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
|
||||||
-- Operations
|
-- Operations
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
||| Concatenate two matrices vertically.
|
||||||
export
|
export
|
||||||
vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
|
vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
|
||||||
vconcat = concat 0
|
vconcat = concat 0
|
||||||
|
|
||||||
|
||| Concatenate two matrices horizontally.
|
||||||
export
|
export
|
||||||
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
||||||
hconcat = concat 1
|
hconcat = concat 1
|
||||||
|
|
||||||
|
|
||||||
|
||| Calculate the kronecker product of two vectors as a matrix.
|
||||||
export
|
export
|
||||||
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
|
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
|
||||||
kronecker a b with (viewShape a, viewShape b)
|
kronecker a b with (viewShape a, viewShape b)
|
||||||
_ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b)
|
_ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => a !! i * b !! j)
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
|
@ -6,32 +6,49 @@ module Data.NumIdr.Multiply
|
||||||
infixr 9 *.
|
infixr 9 *.
|
||||||
infixr 10 ^
|
infixr 10 ^
|
||||||
|
|
||||||
||| A generalized multiplication/transformation operator. This interface is
|
||| A generalized multiplication/application operator. This interface is
|
||||||
||| necessary since the standard multiplication operator is homogenous.
|
||| necessary since the standard multiplication operator is homogenous.
|
||||||
|
|||
|
||||||
|
||| All instances of this interface must collectively satisfy these axioms:
|
||||||
|
||| * If `(x *. y) *. z` is defined, then `x *. (y *. z)` is defined and equal.
|
||||||
|
||| * If `x *. (y *. z)` is defined, then `(x *. y) *. z` is defined and equal.
|
||||||
public export
|
public export
|
||||||
interface Mult a b c | a,b where
|
interface Mult a b c | a,b where
|
||||||
(*.) : a -> b -> c
|
(*.) : a -> b -> c
|
||||||
|
|
||||||
|
||| An interface for monoids using the `*.` operator.
|
||||||
|
|||
|
||||||
|
||| An instance of this interface must satisfy:
|
||||||
|
||| * `x *. neutral == x`
|
||||||
|
||| * `neutral *. x == x`
|
||||||
public export
|
public export
|
||||||
interface (Mult a a a) => MultNeutral a where
|
interface Mult a a a => MultNeutral a where
|
||||||
neutral : a
|
neutral : a
|
||||||
|
|
||||||
|
|
||||||
|
||| Multiplication forms a semigroup
|
||||||
public export
|
public export
|
||||||
[MultSemigroup] Mult a a a => Semigroup a where
|
[MultSemigroup] Mult a a a => Semigroup a where
|
||||||
(<+>) = (*.)
|
(<+>) = (*.)
|
||||||
|
|
||||||
|
||| Multiplication with a neutral element forms a monoid
|
||||||
public export
|
public export
|
||||||
[MultMonoid] MultNeutral a => Monoid a using MultSemigroup where
|
[MultMonoid] MultNeutral a => Monoid a using MultSemigroup where
|
||||||
neutral = Multiply.neutral
|
neutral = Multiply.neutral
|
||||||
|
|
||||||
|
|
||||||
|
||| Raise a multiplicative value (e.g. a matrix or a transformation) to a natural
|
||||||
|
||| number power.
|
||||||
public export
|
public export
|
||||||
power : MultNeutral a => Nat -> a -> a
|
power : MultNeutral a => Nat -> a -> a
|
||||||
power 0 _ = neutral
|
power 0 _ = neutral
|
||||||
power 1 x = x
|
power 1 x = x
|
||||||
power (S n@(S _)) x = x *. power n x
|
power (S n@(S _)) x = x *. power n x
|
||||||
|
|
||||||
|
||| Raise a multiplicative value (e.g. a matrix or a transformation) to a natural
|
||||||
|
||| number power.
|
||||||
|
|||
|
||||||
|
||| This is the operator form of `power`.
|
||||||
public export
|
public export
|
||||||
(^) : MultNeutral a => a -> Nat -> a
|
(^) : MultNeutral a => a -> Nat -> a
|
||||||
(^) = flip power
|
(^) = flip power
|
||||||
|
|
|
@ -35,8 +35,8 @@ export
|
||||||
constant : Nat -> a -> PrimArray a
|
constant : Nat -> a -> PrimArray a
|
||||||
constant size x = MkPrimArray size $ unsafePerformIO $ newArrayData size x
|
constant size x = MkPrimArray size $ unsafePerformIO $ newArrayData size x
|
||||||
|
|
||||||
||| Construct an array from a list of "instructions" to write a
|
||| Construct an array from a list of "instructions" to write a value to a
|
||||||
||| value to a particular index.
|
||| particular index.
|
||||||
export
|
export
|
||||||
unsafeFromIns : Nat -> List (Nat, a) -> PrimArray a
|
unsafeFromIns : Nat -> List (Nat, a) -> PrimArray a
|
||||||
unsafeFromIns size ins = unsafePerformIO $ do
|
unsafeFromIns size ins = unsafePerformIO $ do
|
||||||
|
@ -44,8 +44,8 @@ unsafeFromIns size ins = unsafePerformIO $ do
|
||||||
for_ ins $ \(i,x) => arrayDataSet i x arr
|
for_ ins $ \(i,x) => arrayDataSet i x arr
|
||||||
pure $ MkPrimArray size arr
|
pure $ MkPrimArray size arr
|
||||||
|
|
||||||
||| Create an array given its size and a function to generate
|
||| Create an array given its size and a function to generate its elements by
|
||||||
||| its elements by its index.
|
||| its index.
|
||||||
export
|
export
|
||||||
create : Nat -> (Nat -> a) -> PrimArray a
|
create : Nat -> (Nat -> a) -> PrimArray a
|
||||||
create size f = unsafePerformIO $ do
|
create size f = unsafePerformIO $ do
|
||||||
|
@ -59,8 +59,8 @@ create size f = unsafePerformIO $ do
|
||||||
= do arrayDataSet loc (f loc) arr
|
= do arrayDataSet loc (f loc) arr
|
||||||
addToArray (S loc) n arr
|
addToArray (S loc) n arr
|
||||||
|
|
||||||
||| Index into a primitive array. This function is unsafe, as it
|
||| Index into a primitive array. This function is unsafe, as it performs no
|
||||||
||| performs no boundary check on the index given.
|
||| boundary check on the index given.
|
||||||
export
|
export
|
||||||
index : Nat -> PrimArray a -> a
|
index : Nat -> PrimArray a -> a
|
||||||
index n arr = unsafePerformIO $ arrayDataGet n $ content arr
|
index n arr = unsafePerformIO $ arrayDataGet n $ content arr
|
||||||
|
@ -156,8 +156,8 @@ traverse : Applicative f => (a -> f b) -> PrimArray a -> f (PrimArray b)
|
||||||
traverse f = map fromList . traverse f . toList
|
traverse f = map fromList . traverse f . toList
|
||||||
|
|
||||||
|
|
||||||
||| Compares two primitive arrays for equal elements. This function assumes
|
||| Compares two primitive arrays for equal elements. This function assumes the
|
||||||
||| the arrays same the same length; it must not be used in any other case.
|
||| arrays have the same length; it must not be used in any other case.
|
||||||
export
|
export
|
||||||
unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool
|
unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool
|
||||||
unsafeEq a b = unsafePerformIO $
|
unsafeEq a b = unsafePerformIO $
|
||||||
|
|
|
@ -15,10 +15,12 @@ Scalar : Type -> Type
|
||||||
Scalar = Array []
|
Scalar = Array []
|
||||||
|
|
||||||
|
|
||||||
|
||| Convert a value to a scalar.
|
||||||
export
|
export
|
||||||
scalar : a -> Scalar a
|
scalar : a -> Scalar a
|
||||||
scalar x = fromVect _ [x]
|
scalar x = fromVect _ [x]
|
||||||
|
|
||||||
|
||| Unwrap the single value from a scalar.
|
||||||
export
|
export
|
||||||
unwrap : Scalar a -> a
|
unwrap : Scalar a -> a
|
||||||
unwrap = index 0 . getPrim
|
unwrap = index 0 . getPrim
|
||||||
|
|
|
@ -6,11 +6,14 @@ import public Data.NumIdr.Array
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
|
||| A vector is a rank-1 array.
|
||||||
public export
|
public export
|
||||||
Vector : Nat -> Type -> Type
|
Vector : Nat -> Type -> Type
|
||||||
Vector n = Array [n]
|
Vector n = Array [n]
|
||||||
|
|
||||||
|
|
||||||
|
||| The length (number of dimensions) of the vector
|
||||||
public export
|
public export
|
||||||
dim : Vector n a -> Nat
|
dim : Vector n a -> Nat
|
||||||
dim = head . shape
|
dim = head . shape
|
||||||
|
@ -25,6 +28,7 @@ dimEq v = cong head $ shapeEq v
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
||| Construct a vector from a `Vect`.
|
||||||
export
|
export
|
||||||
vector : Vect n a -> Vector n a
|
vector : Vect n a -> Vector n a
|
||||||
vector v = rewrite sym (lengthCorrect v)
|
vector v = rewrite sym (lengthCorrect v)
|
||||||
|
@ -32,20 +36,28 @@ vector v = rewrite sym (lengthCorrect v)
|
||||||
rewrite lengthCorrect v in -- there is only 1 axis
|
rewrite lengthCorrect v in -- there is only 1 axis
|
||||||
rewrite multOneLeftNeutral n in v
|
rewrite multOneLeftNeutral n in v
|
||||||
|
|
||||||
|
||| Convert a vector into a `Vect`.
|
||||||
export
|
export
|
||||||
toVect : Vector n a -> Vect n a
|
toVect : Vector n a -> Vect n a
|
||||||
toVect v = believe_me $ Vect.fromList $ toList v
|
toVect v = believe_me $ Vect.fromList $ toList v
|
||||||
|
|
||||||
|
|
||||||
|
||| Return the `i`-th basis vector.
|
||||||
export
|
export
|
||||||
basis : Num a => {n : _} -> (i : Fin n) -> Vector n a
|
basis : Num a => {n : _} -> (i : Fin n) -> Vector n a
|
||||||
basis i = fromFunction _ (\[j] => if i == j then 1 else 0)
|
basis i = fromFunction _ (\[j] => if i == j then 1 else 0)
|
||||||
|
|
||||||
|
|
||||||
|
||| Calculate the 2D unit vector with the given angle off the x-axis.
|
||||||
export
|
export
|
||||||
unit2D : (ang : Double) -> Vector 2 Double
|
unit2D : (ang : Double) -> Vector 2 Double
|
||||||
unit2D ang = vector [cos ang, sin ang]
|
unit2D ang = vector [cos ang, sin ang]
|
||||||
|
|
||||||
|
||| Calculate the 3D unit vector corresponding to the given spherical coordinates,
|
||||||
|
||| where the polar axis is the z-axis.
|
||||||
|
|||
|
||||||
|
||| @ pol The polar angle of the vector
|
||||||
|
||| @ az The azimuthal angle of the vector
|
||||||
export
|
export
|
||||||
unit3D : (pol, az : Double) -> Vector 3 Double
|
unit3D : (pol, az : Double) -> Vector 3 Double
|
||||||
unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol]
|
unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol]
|
||||||
|
@ -56,35 +68,49 @@ unit3D pol az = vector [cos az * sin pol, sin az * sin pol, cos pol]
|
||||||
-- Indexing
|
-- Indexing
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
infix 10 !!
|
infixl 10 !!
|
||||||
infix 10 !?
|
infixl 10 !?
|
||||||
|
|
||||||
|
||| Index the vector at the given coordinate.
|
||||||
export
|
export
|
||||||
index : Fin n -> Vector n a -> a
|
index : Fin n -> Vector n a -> a
|
||||||
index n = Array.index [n]
|
index n = Array.index [n]
|
||||||
|
|
||||||
|
||| Index the vector at the given coordinate.
|
||||||
|
|||
|
||||||
|
||| This is the operator form of `index`.
|
||||||
export
|
export
|
||||||
(!!) : Vector n a -> Fin n -> a
|
(!!) : Vector n a -> Fin n -> a
|
||||||
(!!) = flip index
|
(!!) = flip index
|
||||||
|
|
||||||
|
||| Index the vector at the given coordinate, returning `Nothing` if the
|
||||||
|
||| coordinate is out of bounds.
|
||||||
export
|
export
|
||||||
indexNB : Nat -> Vector n a -> Maybe a
|
indexNB : Nat -> Vector n a -> Maybe a
|
||||||
indexNB n = Array.indexNB [n]
|
indexNB n = Array.indexNB [n]
|
||||||
|
|
||||||
|
||| Index the vector at the given coordinate, returning `Nothing` if the
|
||||||
|
||| coordinate is out of bounds.
|
||||||
|
|||
|
||||||
|
||| This is the operator form of `indexNB`.
|
||||||
export
|
export
|
||||||
(!?) : Vector n a -> Nat -> Maybe a
|
(!?) : Vector n a -> Nat -> Maybe a
|
||||||
(!?) = flip indexNB
|
(!?) = flip indexNB
|
||||||
|
|
||||||
|
|
||||||
-- Named projections
|
-- Named projections
|
||||||
|
|
||||||
|
||| Return the x-coordinate (the first value) of the vector.
|
||||||
export
|
export
|
||||||
(.x) : Vector (1 + n) a -> a
|
(.x) : Vector (1 + n) a -> a
|
||||||
(.x) = index FZ
|
(.x) = index FZ
|
||||||
|
|
||||||
|
||| Return the y-coordinate (the second value) of the vector.
|
||||||
export
|
export
|
||||||
(.y) : Vector (2 + n) a -> a
|
(.y) : Vector (2 + n) a -> a
|
||||||
(.y) = index (FS FZ)
|
(.y) = index (FS FZ)
|
||||||
|
|
||||||
|
||| Return the z-coordinate (the third value) of the vector.
|
||||||
export
|
export
|
||||||
(.z) : Vector (3 + n) a -> a
|
(.z) : Vector (3 + n) a -> a
|
||||||
(.z) = index (FS (FS FZ))
|
(.z) = index (FS (FS FZ))
|
||||||
|
@ -109,16 +135,19 @@ swizzle p v = rewrite sym (lengthCorrect p)
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
||| Concatenate one vector with another.
|
||||||
export
|
export
|
||||||
(++) : Vector m a -> Vector n a -> Vector (m + n) a
|
(++) : Vector m a -> Vector n a -> Vector (m + n) a
|
||||||
(++) = concat 0
|
(++) = concat 0
|
||||||
|
|
||||||
|
|
||||||
|
||| Calculate the dot product of the two vectors.
|
||||||
export
|
export
|
||||||
dot : Num a => Vector n a -> Vector n a -> a
|
dot : Num a => Vector n a -> Vector n a -> a
|
||||||
dot = sum .: zipWith (*)
|
dot = sum .: zipWith (*)
|
||||||
|
|
||||||
|
|
||||||
|
||| Calculate the perpendicular product of the two vectors.
|
||||||
export
|
export
|
||||||
perp : Neg a => Vector 2 a -> Vector 2 a -> a
|
perp : Neg a => Vector 2 a -> Vector 2 a -> a
|
||||||
perp a b = a.x * b.y - a.y * b.x
|
perp a b = a.x * b.y - a.y * b.x
|
||||||
|
|
Loading…
Reference in a new issue