Update docstrings for new backend

This commit is contained in:
Kiana Sheibani 2024-03-01 18:55:43 -05:00
parent 9a5c0be049
commit 4e3c524b08
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
2 changed files with 85 additions and 17 deletions

View file

@ -13,6 +13,8 @@ import Data.NumIdr.Array.Coords
%default total
||| The type of an array.
|||
||| Arrays are the central data structure of NumIdr. They are an `n`-dimensional
||| grid of values, where `n` is a value known as the *rank* of the array. Arrays
||| of rank 0 are single values, arrays of rank 1 are vectors, and arrays of rank
@ -23,22 +25,18 @@ import Data.NumIdr.Array.Coords
||| array's total size.
|||
||| Arrays are indexed by row first, as in the standard mathematical notation for
||| matrices. This is independent of the actual order in which they are stored; the
||| default order is row-major, but this is configurable.
||| matrices.
|||
||| @ rk The rank of the array
||| @ s The shape of the array
||| @ a The type of the array's elements
export
data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
||| Internally, arrays are stored using Idris's primitive array type. This is
||| stored along with the array's shape, and a vector of values called the
||| *strides*, which determine how indexes into the internal array should be
||| performed. This is how the order of the array is configurable.
||| Internally, arrays are stored via one of a handful of representations.
|||
||| @ ord The order of the elements of the array
||| @ sts The strides of the array
||| @ s The shape of the array
||| @ rep The internal representation of the array
||| @ rc A witness that the element type satisfies the representation constraint
MkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
PrimArray rep s a @{rc} -> Array s a
@ -56,24 +54,27 @@ unsafeMkArray = MkArray
--------------------------------------------------------------------------------
||| The shape of the array
||| The shape of the array.
export
shape : Array {rk} s a -> Vect rk Nat
shape (MkArray _ s _) = s
||| The rank of the array
||| The rank of the array.
export
rank : Array s a -> Nat
rank = length . shape
||| The internal representation of the array.
export
getRep : Array s a -> Rep
getRep (MkArray rep _ _) = rep
||| The representation constraint of the array.
export
getRepC : (arr : Array s a) -> RepConstraint (getRep arr) a
getRepC (MkArray _ @{rc} _ _) = rc
||| Extract the primitive backend array.
export
getPrim : (arr : Array s a) -> PrimArray (getRep arr) s a @{getRepC arr}
getPrim (MkArray _ _ pr) = pr
@ -340,14 +341,26 @@ arr !#.. is = indexRangeUnsafe is arr
-- Operations on arrays
--------------------------------------------------------------------------------
||| Map a function over an array.
|||
||| You should almost always use `map` instead; only use this function if you
||| know what you are doing!
export
mapArray' : (a -> a) -> Array s a -> Array s a
mapArray' f (MkArray rep _ arr) = MkArray rep _ (mapPrim f arr)
||| Map a function over an array.
|||
||| You should almost always use `map` instead; only use this function if you
||| know what you are doing!
export
mapArray : (a -> b) -> (arr : Array s a) -> RepConstraint (getRep arr) b => Array s b
mapArray f (MkArray rep _ arr) @{rc} = MkArray rep @{rc} _ (mapPrim f arr)
||| Combine two arrays of the same element type using a binary function.
|||
||| You should almost always use `zipWith` instead; only use this function if
||| you know what you are doing!
export
zipWithArray' : (a -> a -> a) -> Array s a -> Array s a -> Array s a
zipWithArray' {s} f a b with (viewShape a)
@ -356,6 +369,10 @@ zipWithArray' {s} f a b with (viewShape a)
$ PrimArray.fromFunctionNB @{mergeRepConstraint (getRepC a) (getRepC b)} _
(\is => f (a !# is) (b !# is))
||| Combine two arrays using a binary function.
|||
||| You should almost always use `zipWith` instead; only use this function if
||| you know what you are doing!
export
zipWithArray : (a -> b -> c) -> (arr : Array s a) -> (arr' : Array s b) ->
RepConstraint (mergeRep (getRep arr) (getRep arr')) c => Array s c
@ -446,6 +463,7 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in
getAxisInd FZ (i :: is) = (i, is)
getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
||| Join the axes of a nested array structure to form a single array.
export
joinAxes : {s' : _} -> Array s (Array s' a) -> Array (s ++ s') a
joinAxes {s} arr with (viewShape arr)
@ -460,6 +478,7 @@ joinAxes {s} arr with (viewShape arr)
dropUpTo (x::xs) (y::ys) = dropUpTo xs ys
||| Split an array into a nested array structure along the specified axes.
export
splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} ->
Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a)
@ -577,7 +596,8 @@ export
-- Foldable and Traversable operate on the primitive array directly. This means
-- that their operation is dependent on the order of the array.
-- that their operation is dependent on the internal representation of the
-- array.
export
Foldable (Array s) where