Partially reimplement main API using reps

This commit is contained in:
Kiana Sheibani 2023-05-05 13:43:14 -04:00
parent bd235754d2
commit ccb689af42
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
4 changed files with 153 additions and 327 deletions

View file

@ -1,7 +1,6 @@
module Data.NumIdr.Array.Array
import Data.List
import Data.List1
import Data.Vect
import Data.Zippable
import Data.NP
@ -10,7 +9,6 @@ import Data.NumIdr.Interfaces
import Data.NumIdr.PrimArray
import Data.NumIdr.Array.Rep
import Data.NumIdr.Array.Coords
import Data.NumIdr.Array.Shape
%default total
@ -41,14 +39,15 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
||| @ ord The order of the elements of the array
||| @ sts The strides of the array
||| @ s The shape of the array
MkArray : (ord : Order) -> (sts : Vect rk Nat) ->
(s : Vect rk Nat) -> PrimArray a -> Array s a
MkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
PrimArray rep s a @{rc} -> Array s a
%name Array arr
export
unsafeMkArray : Order -> Vect rk Nat -> (s : Vect rk Nat) -> PrimArray a -> Array s a
unsafeMkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
PrimArray rep s a @{rc} -> Array s a
unsafeMkArray = MkArray
@ -57,57 +56,35 @@ unsafeMkArray = MkArray
--------------------------------------------------------------------------------
||| Extract the primitive array value.
export
getPrim : Array s a -> PrimArray a
getPrim (MkArray _ _ _ arr) = arr
||| The order of the elements of the array
export
getOrder : Array s a -> Order
getOrder (MkArray ord _ _ _) = ord
||| The strides of the array, returned in the same axis order as in the shape.
export
strides : Array {rk} s a -> Vect rk Nat
strides (MkArray _ sts _ _) = sts
||| The total number of elements of the array
|||
||| This is equivalent to `product s`.
export
size : Array s a -> Nat
size = length . getPrim
||| The shape of the array
export
shape : Array {rk} s a -> Vect rk Nat
shape (MkArray _ _ s _) = s
shape (MkArray _ s _) = s
||| The rank of the array
export
rank : Array s a -> Nat
rank = length . shape
export
getRep : Array s a -> Rep
getRep (MkArray rep _ _) = rep
-- Get a list of all coordinates
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
getAllCoords' = traverse (\case Z => []; S n => [0..n])
getRepC : (arr : Array s a) -> RepConstraint (getRep arr) a
getRepC (MkArray _ @{rc} _ _) = rc
getAllCoords : (s : Vect rk Nat) -> List (Coords s)
getAllCoords [] = pure []
getAllCoords (Z :: s) = []
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
export
getPrim : (arr : Array s a) -> PrimArray (getRep arr) s a @{getRepC arr}
getPrim (MkArray _ _ pr) = pr
--------------------------------------------------------------------------------
-- Shape view
--------------------------------------------------------------------------------
export
shapeEq : (arr : Array s a) -> s = shape arr
shapeEq (MkArray _ _ _ _) = Refl
shapeEq (MkArray _ _ _) = Refl
||| A view for extracting the shape of an array.
@ -128,109 +105,70 @@ viewShape arr = rewrite shapeEq arr in
-- Array constructors
--------------------------------------------------------------------------------
||| Create an array by repeating a single value.
|||
||| @ s The shape of the constructed array
||| @ ord The order of the constructed array
||| @ rep The internal representation of the constructed array
export
repeat' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a
repeat' s ord x = MkArray ord (calcStrides ord s) s (constant (product s) x)
||| Create an array by repeating a single value.
||| To specify the order of the array, use `repeat'`.
|||
||| @ s The shape of the constructed array
export
repeat : (s : Vect rk Nat) -> a -> Array s a
repeat s = repeat' s COrder
repeat : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> a -> Array s a
repeat s x = MkArray rep s (constant s x)
||| Create an array filled with zeros.
|||
||| @ s The shape of the constructed array
||| @ rep The internal representation of the constructed array
export
zeros : Num a => (s : Vect rk Nat) -> Array s a
zeros s = repeat s 0
zeros : {default B rep : Rep} -> RepConstraint rep a =>
Num a => (s : Vect rk Nat) -> Array s a
zeros {rep} s = repeat {rep} s 0
||| Create an array filled with ones.
|||
||| @ s The shape of the constructed array
||| @ rep The internal representation of the constructed array
export
ones : Num a => (s : Vect rk Nat) -> Array s a
ones s = repeat s 1
ones : {default B rep : Rep} -> RepConstraint rep a =>
Num a => (s : Vect rk Nat) -> Array s a
ones {rep} s = repeat {rep} s 1
||| Create an array given a vector of its elements. The elements of the vector
||| are arranged into the provided shape using the provided order.
|||
||| @ s The shape of the constructed array
||| @ ord The order to interpret the elements
||| @ rep The internal representation of the constructed array
export
fromVect' : (s : Vect rk Nat) -> (ord : Order) -> Vect (product s) a -> Array s a
fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
fromVect : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> Vect (product s) a -> Array s a
fromVect s v = ?fv
||| Create an array given a vector of its elements. The elements of the vector
||| are arranged into the provided shape using row-major order (the last axis is the
||| least significant).
||| To specify the order of the array, use `fromVect'`.
|||
||| @ s The shape of the constructed array
export
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
fromVect s = fromVect' s COrder
||| Create an array by taking values from a stream.
|||
||| @ s The shape of the constructed array
||| @ ord The order to interpret the elements
||| @ rep The internal representation of the constructed array
export
fromStream' : (s : Vect rk Nat) -> (ord : Order) -> Stream a -> Array s a
fromStream' s ord st = MkArray ord (calcStrides ord s) s (fromList $ take (product s) st)
||| Create an array by taking values from a stream.
||| To specify the order of the array, use `fromStream'`.
|||
||| @ s The shape of the constructed array
export
fromStream : (s : Vect rk Nat) -> Stream a -> Array s a
fromStream s = fromStream' s COrder
fromStream : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> Stream a -> Array s a
fromStream {rep} s str = fromVect {rep} s (take _ str)
||| Create an array given a function to generate its elements.
|||
||| @ s The shape of the constructed array
||| @ ord The order to interpret the elements
||| @ rep The internal representation of the constructed array
export
fromFunctionNB' : (s : Vect rk Nat) -> (ord : Order) -> (Vect rk Nat -> a) -> Array s a
fromFunctionNB' s ord f = let sts = calcStrides ord s
in MkArray ord sts s (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, f is)) $ getAllCoords' s)
||| Create an array given a function to generate its elements.
||| To specify the order of the array, use `fromFunctionNB'`.
|||
||| @ s The shape of the constructed array
||| @ ord The order to interpret the elements
export
fromFunctionNB : (s : Vect rk Nat) -> (Vect rk Nat -> a) -> Array s a
fromFunctionNB s = fromFunctionNB' s COrder
fromFunctionNB : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> (Vect rk Nat -> a) -> Array s a
fromFunctionNB s f = MkArray rep s (PrimArray.fromFunctionNB s f)
||| Create an array given a function to generate its elements.
|||
||| @ s The shape of the constructed array
||| @ ord The order to interpret the elements
||| @ rep The internal representation of the constructed array
export
fromFunction' : (s : Vect rk Nat) -> (ord : Order) -> (Coords s -> a) -> Array s a
fromFunction' s ord f = let sts = calcStrides ord s
in MkArray ord sts s (unsafeFromIns (product s) $
map (\is => (getLocation sts is, f is)) $ getAllCoords s)
||| Create an array given a function to generate its elements.
||| To specify the order of the array, use `fromFunction'`.
|||
||| @ s The shape of the constructed array
export
fromFunction : (s : Vect rk Nat) -> (Coords s -> a) -> Array s a
fromFunction s = fromFunction' s COrder
fromFunction : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> (Coords s -> a) -> Array s a
fromFunction s f = MkArray rep s (PrimArray.fromFunction s f)
||| Construct an array using a structure of nested vectors. The elements are arranged
||| to the specified order before being written.
@ -238,27 +176,22 @@ fromFunction s = fromFunction' s COrder
||| @ s The shape of the constructed array
||| @ ord The order of the constructed array
export
array' : (s : Vect rk Nat) -> (ord : Order) -> Vects s a -> Array s a
array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
where
sts : Vect rk Nat
sts = calcStrides ord s
ins : List (Nat, a)
ins = collapse $ mapWithIndex (MkPair . getLocation' sts) v
array' : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> Vects s a -> Array s a
array' s v = MkArray rep s (fromVects s v)
||| Construct an array using a structure of nested vectors.
||| To explicitly specify the shape and order of the array, use `array'`.
export
array : {s : Vect rk Nat} -> Vects s a -> Array s a
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
array : {default B rep : Rep} -> RepConstraint rep a =>
{s : Vect rk Nat} -> Vects s a -> Array s a
array {rep} = array' {rep} _
--------------------------------------------------------------------------------
-- Indexing
--------------------------------------------------------------------------------
infixl 10 !!
infixl 10 !?
infixl 10 !#
@ -270,7 +203,7 @@ infixl 11 !#..
||| Index the array using the given coordinates.
export
index : Coords s -> Array s a -> a
index is arr = index (getLocation (strides arr) is) (getPrim arr)
index is (MkArray _ _ arr) = PrimArray.index is arr
||| Index the array using the given coordinates.
|||
@ -278,7 +211,7 @@ index is arr = index (getLocation (strides arr) is) (getPrim arr)
export %inline
(!!) : Array s a -> Coords s -> a
arr !! is = index is arr
{-
||| Update the entry at the given coordinates using the function.
export
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
@ -333,15 +266,13 @@ indexUpdateRange : (rs : CoordsRange s) ->
(Array (newShape rs) a -> Array (newShape rs) a) ->
Array s a -> Array s a
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
-}
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is arr = if all id $ zipWith (<) is (shape arr)
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
else Nothing
indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
@ -350,7 +281,7 @@ indexNB is arr = if all id $ zipWith (<) is (shape arr)
export %inline
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
arr !? is = indexNB is arr
{-
||| Update the entry at the given coordinates using the function. `Nothing` is
||| returned if the coordinates are out of bounds.
export
@ -393,14 +324,14 @@ indexRangeNB {s} rs arr with (viewShape arr)
export %inline
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
arr !?.. rs = indexRangeNB rs arr
-}
||| Index the array using the given coordinates.
||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety.
export
indexUnsafe : Vect rk Nat -> Array {rk} s a -> a
indexUnsafe is arr = index (getLocation' (strides arr) is) (getPrim arr)
indexUnsafe is (MkArray _ _ arr) = PrimArray.indexUnsafe is arr
||| Index the array using the given coordinates.
||| WARNING: This function does not perform any bounds check on its inputs.
@ -411,7 +342,7 @@ export %inline
(!#) : Array {rk} s a -> Vect rk Nat -> a
arr !# is = indexUnsafe is arr
{-
||| Index the array using the given range of coordinates, returning a new array.
||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety.
@ -784,3 +715,4 @@ normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
export
pnorm : (p : Double) -> Array s Double -> Double
pnorm p = (`pow` recip p) . sum . map (`pow` p)
-}