Partially reimplement main API using reps
This commit is contained in:
parent
bd235754d2
commit
ccb689af42
|
@ -1,7 +1,6 @@
|
||||||
module Data.NumIdr.Array.Array
|
module Data.NumIdr.Array.Array
|
||||||
|
|
||||||
import Data.List
|
import Data.List
|
||||||
import Data.List1
|
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.Zippable
|
import Data.Zippable
|
||||||
import Data.NP
|
import Data.NP
|
||||||
|
@ -10,7 +9,6 @@ import Data.NumIdr.Interfaces
|
||||||
import Data.NumIdr.PrimArray
|
import Data.NumIdr.PrimArray
|
||||||
import Data.NumIdr.Array.Rep
|
import Data.NumIdr.Array.Rep
|
||||||
import Data.NumIdr.Array.Coords
|
import Data.NumIdr.Array.Coords
|
||||||
import Data.NumIdr.Array.Shape
|
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
@ -41,14 +39,15 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
|
||||||
||| @ ord The order of the elements of the array
|
||| @ ord The order of the elements of the array
|
||||||
||| @ sts The strides of the array
|
||| @ sts The strides of the array
|
||||||
||| @ s The shape of the array
|
||| @ s The shape of the array
|
||||||
MkArray : (ord : Order) -> (sts : Vect rk Nat) ->
|
MkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
|
||||||
(s : Vect rk Nat) -> PrimArray a -> Array s a
|
PrimArray rep s a @{rc} -> Array s a
|
||||||
|
|
||||||
%name Array arr
|
%name Array arr
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
unsafeMkArray : Order -> Vect rk Nat -> (s : Vect rk Nat) -> PrimArray a -> Array s a
|
unsafeMkArray : (rep : Rep) -> (rc : RepConstraint rep a) => (s : Vect rk Nat) ->
|
||||||
|
PrimArray rep s a @{rc} -> Array s a
|
||||||
unsafeMkArray = MkArray
|
unsafeMkArray = MkArray
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,57 +56,35 @@ unsafeMkArray = MkArray
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
||| Extract the primitive array value.
|
|
||||||
export
|
|
||||||
getPrim : Array s a -> PrimArray a
|
|
||||||
getPrim (MkArray _ _ _ arr) = arr
|
|
||||||
|
|
||||||
||| The order of the elements of the array
|
|
||||||
export
|
|
||||||
getOrder : Array s a -> Order
|
|
||||||
getOrder (MkArray ord _ _ _) = ord
|
|
||||||
|
|
||||||
||| The strides of the array, returned in the same axis order as in the shape.
|
|
||||||
export
|
|
||||||
strides : Array {rk} s a -> Vect rk Nat
|
|
||||||
strides (MkArray _ sts _ _) = sts
|
|
||||||
|
|
||||||
||| The total number of elements of the array
|
|
||||||
|||
|
|
||||||
||| This is equivalent to `product s`.
|
|
||||||
export
|
|
||||||
size : Array s a -> Nat
|
|
||||||
size = length . getPrim
|
|
||||||
|
|
||||||
||| The shape of the array
|
||| The shape of the array
|
||||||
export
|
export
|
||||||
shape : Array {rk} s a -> Vect rk Nat
|
shape : Array {rk} s a -> Vect rk Nat
|
||||||
shape (MkArray _ _ s _) = s
|
shape (MkArray _ s _) = s
|
||||||
|
|
||||||
||| The rank of the array
|
||| The rank of the array
|
||||||
export
|
export
|
||||||
rank : Array s a -> Nat
|
rank : Array s a -> Nat
|
||||||
rank = length . shape
|
rank = length . shape
|
||||||
|
|
||||||
|
export
|
||||||
|
getRep : Array s a -> Rep
|
||||||
|
getRep (MkArray rep _ _) = rep
|
||||||
|
|
||||||
-- Get a list of all coordinates
|
getRepC : (arr : Array s a) -> RepConstraint (getRep arr) a
|
||||||
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
getRepC (MkArray _ @{rc} _ _) = rc
|
||||||
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
|
||||||
|
|
||||||
getAllCoords : (s : Vect rk Nat) -> List (Coords s)
|
export
|
||||||
getAllCoords [] = pure []
|
getPrim : (arr : Array s a) -> PrimArray (getRep arr) s a @{getRepC arr}
|
||||||
getAllCoords (Z :: s) = []
|
getPrim (MkArray _ _ pr) = pr
|
||||||
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Shape view
|
-- Shape view
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
shapeEq : (arr : Array s a) -> s = shape arr
|
shapeEq : (arr : Array s a) -> s = shape arr
|
||||||
shapeEq (MkArray _ _ _ _) = Refl
|
shapeEq (MkArray _ _ _) = Refl
|
||||||
|
|
||||||
|
|
||||||
||| A view for extracting the shape of an array.
|
||| A view for extracting the shape of an array.
|
||||||
|
@ -128,109 +105,70 @@ viewShape arr = rewrite shapeEq arr in
|
||||||
-- Array constructors
|
-- Array constructors
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
||| Create an array by repeating a single value.
|
||| Create an array by repeating a single value.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ ord The order of the constructed array
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
repeat' : (s : Vect rk Nat) -> (ord : Order) -> a -> Array s a
|
repeat : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
repeat' s ord x = MkArray ord (calcStrides ord s) s (constant (product s) x)
|
(s : Vect rk Nat) -> a -> Array s a
|
||||||
|
repeat s x = MkArray rep s (constant s x)
|
||||||
||| Create an array by repeating a single value.
|
|
||||||
||| To specify the order of the array, use `repeat'`.
|
|
||||||
|||
|
|
||||||
||| @ s The shape of the constructed array
|
|
||||||
export
|
|
||||||
repeat : (s : Vect rk Nat) -> a -> Array s a
|
|
||||||
repeat s = repeat' s COrder
|
|
||||||
|
|
||||||
||| Create an array filled with zeros.
|
||| Create an array filled with zeros.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
zeros : Num a => (s : Vect rk Nat) -> Array s a
|
zeros : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
zeros s = repeat s 0
|
Num a => (s : Vect rk Nat) -> Array s a
|
||||||
|
zeros {rep} s = repeat {rep} s 0
|
||||||
|
|
||||||
||| Create an array filled with ones.
|
||| Create an array filled with ones.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
ones : Num a => (s : Vect rk Nat) -> Array s a
|
ones : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
ones s = repeat s 1
|
Num a => (s : Vect rk Nat) -> Array s a
|
||||||
|
ones {rep} s = repeat {rep} s 1
|
||||||
|
|
||||||
||| Create an array given a vector of its elements. The elements of the vector
|
||| Create an array given a vector of its elements. The elements of the vector
|
||||||
||| are arranged into the provided shape using the provided order.
|
||| are arranged into the provided shape using the provided order.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ ord The order to interpret the elements
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
fromVect' : (s : Vect rk Nat) -> (ord : Order) -> Vect (product s) a -> Array s a
|
fromVect : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
|
(s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
||||||
|
fromVect s v = ?fv
|
||||||
|
|
||||||
||| Create an array given a vector of its elements. The elements of the vector
|
|
||||||
||| are arranged into the provided shape using row-major order (the last axis is the
|
|
||||||
||| least significant).
|
|
||||||
||| To specify the order of the array, use `fromVect'`.
|
|
||||||
|||
|
|
||||||
||| @ s The shape of the constructed array
|
|
||||||
export
|
|
||||||
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
|
||||||
fromVect s = fromVect' s COrder
|
|
||||||
|
|
||||||
||| Create an array by taking values from a stream.
|
||| Create an array by taking values from a stream.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ ord The order to interpret the elements
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
fromStream' : (s : Vect rk Nat) -> (ord : Order) -> Stream a -> Array s a
|
fromStream : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
fromStream' s ord st = MkArray ord (calcStrides ord s) s (fromList $ take (product s) st)
|
(s : Vect rk Nat) -> Stream a -> Array s a
|
||||||
|
fromStream {rep} s str = fromVect {rep} s (take _ str)
|
||||||
||| Create an array by taking values from a stream.
|
|
||||||
||| To specify the order of the array, use `fromStream'`.
|
|
||||||
|||
|
|
||||||
||| @ s The shape of the constructed array
|
|
||||||
export
|
|
||||||
fromStream : (s : Vect rk Nat) -> Stream a -> Array s a
|
|
||||||
fromStream s = fromStream' s COrder
|
|
||||||
|
|
||||||
|
|
||||||
||| Create an array given a function to generate its elements.
|
||| Create an array given a function to generate its elements.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ ord The order to interpret the elements
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
fromFunctionNB' : (s : Vect rk Nat) -> (ord : Order) -> (Vect rk Nat -> a) -> Array s a
|
fromFunctionNB : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
fromFunctionNB' s ord f = let sts = calcStrides ord s
|
(s : Vect rk Nat) -> (Vect rk Nat -> a) -> Array s a
|
||||||
in MkArray ord sts s (unsafeFromIns (product s) $
|
fromFunctionNB s f = MkArray rep s (PrimArray.fromFunctionNB s f)
|
||||||
map (\is => (getLocation' sts is, f is)) $ getAllCoords' s)
|
|
||||||
|
|
||||||
||| Create an array given a function to generate its elements.
|
|
||||||
||| To specify the order of the array, use `fromFunctionNB'`.
|
|
||||||
|||
|
|
||||||
||| @ s The shape of the constructed array
|
|
||||||
||| @ ord The order to interpret the elements
|
|
||||||
export
|
|
||||||
fromFunctionNB : (s : Vect rk Nat) -> (Vect rk Nat -> a) -> Array s a
|
|
||||||
fromFunctionNB s = fromFunctionNB' s COrder
|
|
||||||
|
|
||||||
||| Create an array given a function to generate its elements.
|
||| Create an array given a function to generate its elements.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ ord The order to interpret the elements
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
fromFunction' : (s : Vect rk Nat) -> (ord : Order) -> (Coords s -> a) -> Array s a
|
fromFunction : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
fromFunction' s ord f = let sts = calcStrides ord s
|
(s : Vect rk Nat) -> (Coords s -> a) -> Array s a
|
||||||
in MkArray ord sts s (unsafeFromIns (product s) $
|
fromFunction s f = MkArray rep s (PrimArray.fromFunction s f)
|
||||||
map (\is => (getLocation sts is, f is)) $ getAllCoords s)
|
|
||||||
|
|
||||||
||| Create an array given a function to generate its elements.
|
|
||||||
||| To specify the order of the array, use `fromFunction'`.
|
|
||||||
|||
|
|
||||||
||| @ s The shape of the constructed array
|
|
||||||
export
|
|
||||||
fromFunction : (s : Vect rk Nat) -> (Coords s -> a) -> Array s a
|
|
||||||
fromFunction s = fromFunction' s COrder
|
|
||||||
|
|
||||||
||| Construct an array using a structure of nested vectors. The elements are arranged
|
||| Construct an array using a structure of nested vectors. The elements are arranged
|
||||||
||| to the specified order before being written.
|
||| to the specified order before being written.
|
||||||
|
@ -238,27 +176,22 @@ fromFunction s = fromFunction' s COrder
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ ord The order of the constructed array
|
||| @ ord The order of the constructed array
|
||||||
export
|
export
|
||||||
array' : (s : Vect rk Nat) -> (ord : Order) -> Vects s a -> Array s a
|
array' : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
|
(s : Vect rk Nat) -> Vects s a -> Array s a
|
||||||
where
|
array' s v = MkArray rep s (fromVects s v)
|
||||||
sts : Vect rk Nat
|
|
||||||
sts = calcStrides ord s
|
|
||||||
|
|
||||||
ins : List (Nat, a)
|
|
||||||
ins = collapse $ mapWithIndex (MkPair . getLocation' sts) v
|
|
||||||
|
|
||||||
||| Construct an array using a structure of nested vectors.
|
||| Construct an array using a structure of nested vectors.
|
||||||
||| To explicitly specify the shape and order of the array, use `array'`.
|
||| To explicitly specify the shape and order of the array, use `array'`.
|
||||||
export
|
export
|
||||||
array : {s : Vect rk Nat} -> Vects s a -> Array s a
|
array : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
{s : Vect rk Nat} -> Vects s a -> Array s a
|
||||||
|
array {rep} = array' {rep} _
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Indexing
|
-- Indexing
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
infixl 10 !!
|
infixl 10 !!
|
||||||
infixl 10 !?
|
infixl 10 !?
|
||||||
infixl 10 !#
|
infixl 10 !#
|
||||||
|
@ -270,7 +203,7 @@ infixl 11 !#..
|
||||||
||| Index the array using the given coordinates.
|
||| Index the array using the given coordinates.
|
||||||
export
|
export
|
||||||
index : Coords s -> Array s a -> a
|
index : Coords s -> Array s a -> a
|
||||||
index is arr = index (getLocation (strides arr) is) (getPrim arr)
|
index is (MkArray _ _ arr) = PrimArray.index is arr
|
||||||
|
|
||||||
||| Index the array using the given coordinates.
|
||| Index the array using the given coordinates.
|
||||||
|||
|
|||
|
||||||
|
@ -278,7 +211,7 @@ index is arr = index (getLocation (strides arr) is) (getPrim arr)
|
||||||
export %inline
|
export %inline
|
||||||
(!!) : Array s a -> Coords s -> a
|
(!!) : Array s a -> Coords s -> a
|
||||||
arr !! is = index is arr
|
arr !! is = index is arr
|
||||||
|
{-
|
||||||
||| Update the entry at the given coordinates using the function.
|
||| Update the entry at the given coordinates using the function.
|
||||||
export
|
export
|
||||||
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
||||||
|
@ -333,15 +266,13 @@ indexUpdateRange : (rs : CoordsRange s) ->
|
||||||
(Array (newShape rs) a -> Array (newShape rs) a) ->
|
(Array (newShape rs) a -> Array (newShape rs) a) ->
|
||||||
Array s a -> Array s a
|
Array s a -> Array s a
|
||||||
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
|
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
|
||||||
|
-}
|
||||||
|
|
||||||
||| Index the array using the given coordinates, returning `Nothing` if the
|
||| Index the array using the given coordinates, returning `Nothing` if the
|
||||||
||| coordinates are out of bounds.
|
||| coordinates are out of bounds.
|
||||||
export
|
export
|
||||||
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
|
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
|
||||||
indexNB is arr = if all id $ zipWith (<) is (shape arr)
|
indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
|
||||||
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
|
|
||||||
else Nothing
|
|
||||||
|
|
||||||
||| Index the array using the given coordinates, returning `Nothing` if the
|
||| Index the array using the given coordinates, returning `Nothing` if the
|
||||||
||| coordinates are out of bounds.
|
||| coordinates are out of bounds.
|
||||||
|
@ -350,7 +281,7 @@ indexNB is arr = if all id $ zipWith (<) is (shape arr)
|
||||||
export %inline
|
export %inline
|
||||||
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
|
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
|
||||||
arr !? is = indexNB is arr
|
arr !? is = indexNB is arr
|
||||||
|
{-
|
||||||
||| Update the entry at the given coordinates using the function. `Nothing` is
|
||| Update the entry at the given coordinates using the function. `Nothing` is
|
||||||
||| returned if the coordinates are out of bounds.
|
||| returned if the coordinates are out of bounds.
|
||||||
export
|
export
|
||||||
|
@ -393,14 +324,14 @@ indexRangeNB {s} rs arr with (viewShape arr)
|
||||||
export %inline
|
export %inline
|
||||||
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
|
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
|
||||||
arr !?.. rs = indexRangeNB rs arr
|
arr !?.. rs = indexRangeNB rs arr
|
||||||
|
-}
|
||||||
|
|
||||||
||| Index the array using the given coordinates.
|
||| Index the array using the given coordinates.
|
||||||
||| WARNING: This function does not perform any bounds check on its inputs.
|
||| WARNING: This function does not perform any bounds check on its inputs.
|
||||||
||| Misuse of this function can easily break memory safety.
|
||| Misuse of this function can easily break memory safety.
|
||||||
export
|
export
|
||||||
indexUnsafe : Vect rk Nat -> Array {rk} s a -> a
|
indexUnsafe : Vect rk Nat -> Array {rk} s a -> a
|
||||||
indexUnsafe is arr = index (getLocation' (strides arr) is) (getPrim arr)
|
indexUnsafe is (MkArray _ _ arr) = PrimArray.indexUnsafe is arr
|
||||||
|
|
||||||
||| Index the array using the given coordinates.
|
||| Index the array using the given coordinates.
|
||||||
||| WARNING: This function does not perform any bounds check on its inputs.
|
||| WARNING: This function does not perform any bounds check on its inputs.
|
||||||
|
@ -411,7 +342,7 @@ export %inline
|
||||||
(!#) : Array {rk} s a -> Vect rk Nat -> a
|
(!#) : Array {rk} s a -> Vect rk Nat -> a
|
||||||
arr !# is = indexUnsafe is arr
|
arr !# is = indexUnsafe is arr
|
||||||
|
|
||||||
|
{-
|
||||||
||| Index the array using the given range of coordinates, returning a new array.
|
||| Index the array using the given range of coordinates, returning a new array.
|
||||||
||| WARNING: This function does not perform any bounds check on its inputs.
|
||| WARNING: This function does not perform any bounds check on its inputs.
|
||||||
||| Misuse of this function can easily break memory safety.
|
||| Misuse of this function can easily break memory safety.
|
||||||
|
@ -784,3 +715,4 @@ normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
|
||||||
export
|
export
|
||||||
pnorm : (p : Double) -> Array s Double -> Double
|
pnorm : (p : Double) -> Array s Double -> Double
|
||||||
pnorm p = (`pow` recip p) . sum . map (`pow` p)
|
pnorm p = (`pow` recip p) . sum . map (`pow` p)
|
||||||
|
-}
|
||||||
|
|
|
@ -2,6 +2,7 @@ module Data.NumIdr.Array.Coords
|
||||||
|
|
||||||
import Data.Either
|
import Data.Either
|
||||||
import Data.List
|
import Data.List
|
||||||
|
import Data.List1
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.NP
|
import Data.NP
|
||||||
|
|
||||||
|
@ -54,11 +55,6 @@ namespace Strict
|
||||||
Indices : List (Fin n) -> CRange n
|
Indices : List (Fin n) -> CRange n
|
||||||
Filter : (Fin n -> Bool) -> CRange n
|
Filter : (Fin n -> Bool) -> CRange n
|
||||||
|
|
||||||
infix 0 ...
|
|
||||||
public export
|
|
||||||
(...) : Fin (S n) -> Fin (S n) -> CRange n
|
|
||||||
(...) = Bounds
|
|
||||||
|
|
||||||
|
|
||||||
public export
|
public export
|
||||||
CoordsRange : (s : Vect rk Nat) -> Type
|
CoordsRange : (s : Vect rk Nat) -> Type
|
||||||
|
@ -197,3 +193,14 @@ namespace NB
|
||||||
go : {0 rk : _} -> Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat)
|
go : {0 rk : _} -> Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat)
|
||||||
go [] [] = [[]]
|
go [] [] = [[]]
|
||||||
go (d :: s) (r :: rs) = [| cRangeNBToList d r :: go s rs |]
|
go (d :: s) (r :: rs) = [| cRangeNBToList d r :: go s rs |]
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||||
|
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
||||||
|
|
||||||
|
export
|
||||||
|
getAllCoords : (s : Vect rk Nat) -> List (Coords s)
|
||||||
|
getAllCoords [] = [[]]
|
||||||
|
getAllCoords (Z :: s) = []
|
||||||
|
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
module Data.NumIdr.Array.Order
|
|
||||||
|
|
||||||
import Data.Vect
|
|
||||||
|
|
||||||
%default total
|
|
||||||
|
|
||||||
|
|
||||||
||| An order is an abstract representation of the way in which array
|
|
||||||
||| elements are stored in memory. Orders are used to calculate strides,
|
|
||||||
||| which provide a method of converting an array coordinate into a linear
|
|
||||||
||| memory location.
|
|
||||||
public export
|
|
||||||
data Order : Type where
|
|
||||||
||| C-like order, or contiguous order. This order stores elements in a
|
|
||||||
||| row-major fashion (the last axis is the least significant).
|
|
||||||
COrder : Order
|
|
||||||
||| Fortran-like order. This order stores elements in a column-major
|
|
||||||
||| fashion (the first axis is the least significant).
|
|
||||||
FOrder : Order
|
|
||||||
|
|
||||||
|
|
||||||
public export
|
|
||||||
Eq Order where
|
|
||||||
COrder == COrder = True
|
|
||||||
FOrder == FOrder = True
|
|
||||||
COrder == FOrder = False
|
|
||||||
FOrder == COrder = False
|
|
||||||
|
|
||||||
|
|
||||||
||| Calculate an array's strides given its order and shape.
|
|
||||||
export
|
|
||||||
calcStrides : Order -> Vect rk Nat -> Vect rk Nat
|
|
||||||
calcStrides _ [] = []
|
|
||||||
calcStrides COrder v@(_::_) = scanr (*) 1 $ tail v
|
|
||||||
calcStrides FOrder v@(_::_) = scanl (*) 1 $ init v
|
|
|
@ -1,180 +1,102 @@
|
||||||
module Data.NumIdr.PrimArray
|
module Data.NumIdr.PrimArray
|
||||||
|
|
||||||
import Data.Nat
|
import Data.Buffer
|
||||||
import Data.IORef
|
import Data.Vect
|
||||||
import Data.IOArray.Prims
|
import Data.NP
|
||||||
|
import Data.NumIdr.Array.Rep
|
||||||
|
import Data.NumIdr.Array.Coords
|
||||||
|
import Data.NumIdr.PrimArray.Bytes
|
||||||
|
import Data.NumIdr.PrimArray.Boxed
|
||||||
|
import Data.NumIdr.PrimArray.Linked
|
||||||
|
import Data.NumIdr.PrimArray.Delayed
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
||| A wrapper for Idris's primitive array type.
|
public export
|
||||||
export
|
RepConstraint : Rep -> Type -> Type
|
||||||
record PrimArray a where
|
RepConstraint (Bytes _) a = ByteRep a
|
||||||
constructor MkPrimArray
|
RepConstraint (Boxed _) a = ()
|
||||||
arraySize : Nat
|
RepConstraint Linked a = ()
|
||||||
content : ArrayData a
|
RepConstraint Delayed a = ()
|
||||||
|
|
||||||
export
|
|
||||||
length : PrimArray a -> Nat
|
|
||||||
length = arraySize
|
|
||||||
|
|
||||||
|
|
||||||
-- Private helper functions for ArrayData primitives
|
|
||||||
newArrayData : Nat -> a -> IO (ArrayData a)
|
|
||||||
newArrayData n x = fromPrim $ prim__newArray (cast n) x
|
|
||||||
|
|
||||||
arrayDataGet : Nat -> ArrayData a -> IO a
|
|
||||||
arrayDataGet n arr = fromPrim $ prim__arrayGet arr (cast n)
|
|
||||||
|
|
||||||
arrayDataSet : Nat -> a -> ArrayData a -> IO ()
|
|
||||||
arrayDataSet n x arr = fromPrim $ prim__arraySet arr (cast n) x
|
|
||||||
|
|
||||||
|
|
||||||
||| Construct an array with a constant value.
|
|
||||||
export
|
|
||||||
constant : Nat -> a -> PrimArray a
|
|
||||||
constant size x = MkPrimArray size $ unsafePerformIO $ newArrayData size x
|
|
||||||
|
|
||||||
||| Construct an array from a list of "instructions" to write a value to a
|
|
||||||
||| particular index.
|
|
||||||
export
|
|
||||||
unsafeFromIns : Nat -> List (Nat, a) -> PrimArray a
|
|
||||||
unsafeFromIns size ins = unsafePerformIO $ do
|
|
||||||
arr <- newArrayData size (believe_me ())
|
|
||||||
for_ ins $ \(i,x) => arrayDataSet i x arr
|
|
||||||
pure $ MkPrimArray size arr
|
|
||||||
|
|
||||||
export
|
|
||||||
unsafeDoIns : List (Nat, a) -> PrimArray a -> IO ()
|
|
||||||
unsafeDoIns ins arr = for_ ins $ \(i,x) => arrayDataSet i x arr.content
|
|
||||||
|
|
||||||
||| Create an array given its size and a function to generate its elements by
|
|
||||||
||| its index.
|
|
||||||
export
|
|
||||||
create : Nat -> (Nat -> a) -> PrimArray a
|
|
||||||
create size f = unsafePerformIO $ do
|
|
||||||
arr <- newArrayData size (believe_me ())
|
|
||||||
addToArray Z size arr
|
|
||||||
pure $ MkPrimArray size arr
|
|
||||||
where
|
|
||||||
addToArray : Nat -> Nat -> ArrayData a -> IO ()
|
|
||||||
addToArray loc Z arr = pure ()
|
|
||||||
addToArray loc (S n) arr
|
|
||||||
= do arrayDataSet loc (f loc) arr
|
|
||||||
addToArray (S loc) n arr
|
|
||||||
|
|
||||||
||| Index into a primitive array. This function is unsafe, as it performs no
|
|
||||||
||| boundary check on the index given.
|
|
||||||
export
|
|
||||||
index : Nat -> PrimArray a -> a
|
|
||||||
index n arr = unsafePerformIO $ arrayDataGet n $ content arr
|
|
||||||
|
|
||||||
||| A safe version of `index` that ensures the index entered is valid.
|
|
||||||
export
|
|
||||||
safeIndex : Nat -> PrimArray a -> Maybe a
|
|
||||||
safeIndex n arr = if n < length arr
|
|
||||||
then Just $ index n arr
|
|
||||||
else Nothing
|
|
||||||
|
|
||||||
export
|
|
||||||
copy : PrimArray a -> PrimArray a
|
|
||||||
copy arr = create (length arr) (\n => index n arr)
|
|
||||||
|
|
||||||
export
|
|
||||||
updateAt : Nat -> (a -> a) -> PrimArray a -> PrimArray a
|
|
||||||
updateAt n f arr = if n >= length arr then arr else
|
|
||||||
unsafePerformIO $ do
|
|
||||||
let cpy = copy arr
|
|
||||||
x <- arrayDataGet n cpy.content
|
|
||||||
arrayDataSet n (f x) cpy.content
|
|
||||||
pure cpy
|
|
||||||
|
|
||||||
export
|
|
||||||
unsafeUpdateInPlace : Nat -> (a -> a) -> PrimArray a -> PrimArray a
|
|
||||||
unsafeUpdateInPlace n f arr = unsafePerformIO $ do
|
|
||||||
x <- arrayDataGet n arr.content
|
|
||||||
arrayDataSet n (f x) arr.content
|
|
||||||
pure arr
|
|
||||||
|
|
||||||
||| Convert a primitive array to a list.
|
|
||||||
export
|
|
||||||
toList : PrimArray a -> List a
|
|
||||||
toList arr = iter (length arr) []
|
|
||||||
where
|
|
||||||
iter : Nat -> List a -> List a
|
|
||||||
iter Z acc = acc
|
|
||||||
iter (S n) acc = let el = index n arr
|
|
||||||
in iter n (el :: acc)
|
|
||||||
|
|
||||||
||| Construct a primitive array from a list.
|
|
||||||
export
|
|
||||||
fromList : List a -> PrimArray a
|
|
||||||
fromList xs = create (length xs)
|
|
||||||
(\n => assert_total $ fromJust $ getAt n xs)
|
|
||||||
where
|
|
||||||
partial
|
|
||||||
fromJust : Maybe a -> a
|
|
||||||
fromJust (Just x) = x
|
|
||||||
|
|
||||||
||| Map a function over a primitive array.
|
|
||||||
export
|
|
||||||
map : (a -> b) -> PrimArray a -> PrimArray b
|
|
||||||
map f arr = create (length arr) (\n => f $ index n arr)
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
unsafeZipWith : (a -> b -> c) -> PrimArray a -> PrimArray b -> PrimArray c
|
PrimArray : (rep : Rep) -> Vect rk Nat -> (a : Type) -> RepConstraint rep a => Type
|
||||||
unsafeZipWith f a b = create (length a) (\n => f (index n a) (index n b))
|
PrimArray (Bytes o) s a = PrimArrayBytes o s
|
||||||
|
PrimArray (Boxed o) s a = PrimArrayBoxed o s a
|
||||||
export
|
PrimArray Linked s a = Vects s a
|
||||||
unsafeZipWith3 : (a -> b -> c -> d) ->
|
PrimArray Delayed s a = Coords s -> a
|
||||||
PrimArray a -> PrimArray b -> PrimArray c -> PrimArray d
|
|
||||||
unsafeZipWith3 f a b c = create (length a) (\n => f (index n a) (index n b) (index n c))
|
|
||||||
|
|
||||||
export
|
|
||||||
unzipWith : (a -> (b, c)) -> PrimArray a -> (PrimArray b, PrimArray c)
|
|
||||||
unzipWith f arr = (map (fst . f) arr, map (snd . f) arr)
|
|
||||||
|
|
||||||
export
|
|
||||||
unzipWith3 : (a -> (b, c, d)) -> PrimArray a -> (PrimArray b, PrimArray c, PrimArray d)
|
|
||||||
unzipWith3 f arr = (map ((\(x,_,_) => x) . f) arr,
|
|
||||||
map ((\(_,y,_) => y) . f) arr,
|
|
||||||
map ((\(_,_,z) => z) . f) arr)
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
foldl : (b -> a -> b) -> b -> PrimArray a -> b
|
constant : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> a -> PrimArray rep s a
|
||||||
foldl f z (MkPrimArray size arr) =
|
constant {rep = Bytes o} = Bytes.constant
|
||||||
if size == 0 then z
|
constant {rep = Boxed o} = Boxed.constant
|
||||||
else unsafePerformIO $ do
|
constant {rep = Linked} = Linked.constant
|
||||||
ref <- newIORef z
|
constant {rep = Delayed} = Delayed.constant
|
||||||
for_ [0..pred size] $ \n => do
|
|
||||||
x <- readIORef ref
|
|
||||||
y <- arrayDataGet n arr
|
|
||||||
writeIORef ref (f x y)
|
|
||||||
readIORef ref
|
|
||||||
|
|
||||||
export
|
export
|
||||||
foldr : (a -> b -> b) -> b -> PrimArray a -> b
|
fromFunctionNB : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> (Vect rk Nat -> a) -> PrimArray rep s a
|
||||||
foldr f z (MkPrimArray size arr) =
|
fromFunctionNB {rep = Bytes o} @{rc} s f =
|
||||||
if size == 0 then z
|
let sts = calcStrides o s
|
||||||
else unsafePerformIO $ do
|
in Bytes.unsafeFromIns @{rc} s ((\is => (getLocation' sts is, f is)) <$> getAllCoords' s)
|
||||||
ref <- newIORef z
|
fromFunctionNB {rep = Boxed o} s f =
|
||||||
for_ [pred size..0] $ \n => do
|
let sts = calcStrides o s
|
||||||
x <- arrayDataGet n arr
|
in Boxed.unsafeFromIns s ((\is => (getLocation' sts is, f is)) <$> getAllCoords' s)
|
||||||
y <- readIORef ref
|
fromFunctionNB {rep = Linked} s f = Linked.fromFunctionNB f
|
||||||
writeIORef ref (f x y)
|
fromFunctionNB {rep = Delayed} s f = f . toNB
|
||||||
readIORef ref
|
|
||||||
|
|
||||||
export
|
export
|
||||||
traverse : Applicative f => (a -> f b) -> PrimArray a -> f (PrimArray b)
|
fromFunction : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> (Coords s -> a) -> PrimArray rep s a
|
||||||
traverse f = map fromList . traverse f . toList
|
fromFunction {rep = Bytes o} @{rc} s f =
|
||||||
|
let sts = calcStrides o s
|
||||||
|
in Bytes.unsafeFromIns @{rc} s ((\is => (getLocation sts is, f is)) <$> getAllCoords s)
|
||||||
|
fromFunction {rep = Boxed o} s f =
|
||||||
|
let sts = calcStrides o s
|
||||||
|
in Boxed.unsafeFromIns s ((\is => (getLocation sts is, f is)) <$> getAllCoords s)
|
||||||
|
fromFunction {rep = Linked} s f = Linked.fromFunction f
|
||||||
|
fromFunction {rep = Delayed} s f = f
|
||||||
|
|
||||||
|
|
||||||
||| Compares two primitive arrays for equal elements. This function assumes the
|
|
||||||
||| arrays have the same length; it must not be used in any other case.
|
|
||||||
export
|
export
|
||||||
unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool
|
index : {rep,s : _} -> RepConstraint rep a => Coords s -> PrimArray rep s a -> a
|
||||||
unsafeEq a b = unsafePerformIO $
|
index {rep = Bytes o} is arr@(MkPABytes sts _) = index (getLocation sts is) arr
|
||||||
map (concat @{All}) $ for [0..pred (arraySize a)] $
|
index {rep = Boxed o} is arr@(MkPABoxed sts _) = index (getLocation sts is) arr
|
||||||
\n => (==) <$> arrayDataGet n (content a) <*> arrayDataGet n (content b)
|
index {rep = Linked} is arr = Linked.index is arr
|
||||||
|
index {rep = Delayed} is arr = arr is
|
||||||
|
|
||||||
|
export
|
||||||
|
indexNB : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> Maybe a
|
||||||
|
indexNB {rep = Bytes o} is arr@(MkPABytes sts _) =
|
||||||
|
if and (zipWith (delay .: (<)) is s)
|
||||||
|
then Just $ index (getLocation' sts is) arr
|
||||||
|
else Nothing
|
||||||
|
indexNB {rep = Boxed o} is arr@(MkPABoxed sts _) =
|
||||||
|
if and (zipWith (delay .: (<)) is s)
|
||||||
|
then Just $ index (getLocation' sts is) arr
|
||||||
|
else Nothing
|
||||||
|
indexNB {rep = Linked} is arr = (`Linked.index` arr) <$> checkRange s is
|
||||||
|
indexNB {rep = Delayed} is arr = arr <$> checkRange s is
|
||||||
|
|
||||||
|
export
|
||||||
|
indexUnsafe : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> a
|
||||||
|
indexUnsafe {rep = Bytes o} is arr@(MkPABytes sts _) = index (getLocation' sts is) arr
|
||||||
|
indexUnsafe {rep = Boxed o} is arr@(MkPABoxed sts _) = index (getLocation' sts is) arr
|
||||||
|
indexUnsafe {rep = Linked} is arr = assert_total $ case checkRange s is of
|
||||||
|
Just is' => Linked.index is' arr
|
||||||
|
indexUnsafe {rep = Delayed} is arr = assert_total $ case checkRange s is of
|
||||||
|
Just is' => arr is'
|
||||||
|
|
||||||
|
export
|
||||||
|
convertRep : {r1,r2,s : _} -> RepConstraint r1 a => RepConstraint r2 a => PrimArray r1 s a -> PrimArray r2 s a
|
||||||
|
convertRep {r1 = Bytes o, r2 = Bytes o'} @{rc} arr = reorder @{rc} arr
|
||||||
|
convertRep {r1 = Boxed o, r2 = Boxed o'} arr = reorder arr
|
||||||
|
convertRep {r1 = Linked, r2 = Linked} arr = arr
|
||||||
|
convertRep {r1 = Linked, r2 = Bytes COrder} @{_} @{rc} arr = fromList @{rc} s (collapse arr)
|
||||||
|
convertRep {r1 = Linked, r2 = Boxed COrder} arr = fromList s (collapse arr)
|
||||||
|
convertRep {r1 = Delayed, r2 = Delayed} arr = arr
|
||||||
|
convertRep {r1, r2} arr = fromFunction s (\is => PrimArray.index is arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
fromVects : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> Vects s a -> PrimArray rep s a
|
||||||
|
fromVects s v = convertRep {r1=Linked} v
|
||||||
|
|
Loading…
Reference in a new issue