Revamp all of the standard array API

This commit is contained in:
Kiana Sheibani 2023-09-25 22:55:33 -04:00
parent ccb689af42
commit a0068469c5
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
10 changed files with 655 additions and 364 deletions

View file

@ -26,6 +26,11 @@ public export
elems {all=[_]} [x] = show x elems {all=[_]} [x] = show x
elems {all=_::_} (x :: xs) = show x ++ ", " ++ elems xs elems {all=_::_} (x :: xs) = show x ++ ", " ++ elems xs
-- use this when idris can't figure out the constraint above
public export
eqNP : forall f. (forall t. Eq (f t)) => (x, y : NP f ts) -> Bool
eqNP [] [] = True
eqNP (x :: xs) (y :: ys) = x == y && eqNP xs ys
public export public export

View file

@ -3,4 +3,4 @@ module Data.NumIdr.Array
import public Data.NP import public Data.NP
import public Data.NumIdr.Array.Array import public Data.NumIdr.Array.Array
import public Data.NumIdr.Array.Coords import public Data.NumIdr.Array.Coords
import public Data.NumIdr.Array.Order import public Data.NumIdr.Array.Rep

View file

@ -112,7 +112,7 @@ viewShape arr = rewrite shapeEq arr in
export export
repeat : {default B rep : Rep} -> RepConstraint rep a => repeat : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> a -> Array s a (s : Vect rk Nat) -> a -> Array s a
repeat s x = MkArray rep s (constant s x) repeat s x = MkArray rep s (PrimArray.constant s x)
||| Create an array filled with zeros. ||| Create an array filled with zeros.
||| |||
@ -138,9 +138,9 @@ ones {rep} s = repeat {rep} s 1
||| @ s The shape of the constructed array ||| @ s The shape of the constructed array
||| @ rep The internal representation of the constructed array ||| @ rep The internal representation of the constructed array
export export
fromVect : {default B rep : Rep} -> RepConstraint rep a => fromVect : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
(s : Vect rk Nat) -> Vect (product s) a -> Array s a (s : Vect rk Nat) -> Vect (product s) a -> Array s a
fromVect s v = ?fv fromVect {rep} s v = MkArray rep s (PrimArray.fromList {rep} s $ toList v)
||| Create an array by taking values from a stream. ||| Create an array by taking values from a stream.
@ -148,7 +148,7 @@ fromVect s v = ?fv
||| @ s The shape of the constructed array ||| @ s The shape of the constructed array
||| @ rep The internal representation of the constructed array ||| @ rep The internal representation of the constructed array
export export
fromStream : {default B rep : Rep} -> RepConstraint rep a => fromStream : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
(s : Vect rk Nat) -> Stream a -> Array s a (s : Vect rk Nat) -> Stream a -> Array s a
fromStream {rep} s str = fromVect {rep} s (take _ str) fromStream {rep} s str = fromVect {rep} s (take _ str)
@ -211,12 +211,12 @@ index is (MkArray _ _ arr) = PrimArray.index is arr
export %inline export %inline
(!!) : Array s a -> Coords s -> a (!!) : Array s a -> Coords s -> a
arr !! is = index is arr arr !! is = index is arr
{-
||| Update the entry at the given coordinates using the function. ||| Update the entry at the given coordinates using the function.
export export
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
indexUpdate is f (MkArray ord sts s arr) = indexUpdate is f (MkArray rep @{rc} s arr) =
MkArray ord sts s (updateAt (getLocation sts is) f arr) MkArray rep @{rc} s (indexUpdate @{rc} is f arr)
||| Set the entry at the given coordinates to the given value. ||| Set the entry at the given coordinates to the given value.
export export
@ -227,17 +227,8 @@ indexSet is = indexUpdate is . const
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
export export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
indexRange {s} rs arr with (viewShape arr) indexRange rs (MkArray rep @{rc} s arr) =
_ | Shape s = MkArray rep @{rc} _ (PrimArray.indexRange @{rc} rs arr)
let ord = getOrder arr
sts = calcStrides ord s'
in MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList rs)
where s' : Vect ? Nat
s' = newShape rs
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
||| |||
@ -250,13 +241,8 @@ arr !!.. rs = indexRange rs arr
export export
indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a -> indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
Array s a -> Array s a Array s a -> Array s a
indexSetRange rs rpl (MkArray ord sts s arr) = indexSetRange rs (MkArray _ _ rpl) (MkArray rep s arr) =
MkArray ord sts s (unsafePerformIO $ do MkArray rep s (PrimArray.indexSetRange {rep} rs (convertRep rpl) arr)
let arr' = copy arr
unsafeDoIns (map (\(is,is') =>
(getLocation' sts is, index (getLocation' (strides rpl) is') (getPrim rpl)))
$ getCoordsList rs) arr'
pure arr')
||| Update the sub-array at the given range of coordinates by applying ||| Update the sub-array at the given range of coordinates by applying
@ -266,13 +252,14 @@ indexUpdateRange : (rs : CoordsRange s) ->
(Array (newShape rs) a -> Array (newShape rs) a) -> (Array (newShape rs) a -> Array (newShape rs) a) ->
Array s a -> Array s a Array s a -> Array s a
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
-}
||| Index the array using the given coordinates, returning `Nothing` if the ||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds. ||| coordinates are out of bounds.
export export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr indexNB is (MkArray rep @{rc} s arr) =
map (\is => index is (MkArray rep @{rc} s arr)) (validateCoords s is)
||| Index the array using the given coordinates, returning `Nothing` if the ||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds. ||| coordinates are out of bounds.
@ -281,15 +268,13 @@ indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
export %inline export %inline
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
arr !? is = indexNB is arr arr !? is = indexNB is arr
{-
||| Update the entry at the given coordinates using the function. `Nothing` is ||| Update the entry at the given coordinates using the function. `Nothing` is
||| returned if the coordinates are out of bounds. ||| returned if the coordinates are out of bounds.
export export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a) indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a)
indexUpdateNB is f (MkArray ord sts s arr) = indexUpdateNB is f (MkArray rep @{rc} s arr) =
if all id $ zipWith (<) is s map (\is => indexUpdate is f (MkArray rep @{rc} s arr)) (validateCoords s is)
then Just $ MkArray ord sts s (updateAt (getLocation' sts is) f arr)
else Nothing
||| Set the entry at the given coordinates using the function. `Nothing` is ||| Set the entry at the given coordinates using the function. `Nothing` is
||| returned if the coordinates are out of bounds. ||| returned if the coordinates are out of bounds.
@ -302,20 +287,8 @@ indexSetNB is = indexUpdateNB is . const
||| If any of the given indices are out of bounds, then `Nothing` is returned. ||| If any of the given indices are out of bounds, then `Nothing` is returned.
export export
indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a) indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a)
indexRangeNB {s} rs arr with (viewShape arr) indexRangeNB rs (MkArray rep @{rc} s arr) =
_ | Shape s = map (\rs => believe_me $ Array.indexRange rs (MkArray rep @{rc} s arr)) (validateCRange s rs)
if validateShape s rs
then
let ord = getOrder arr
sts = calcStrides ord s'
in Just $ MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList s rs)
else Nothing
where s' : Vect ? Nat
s' = newShape s rs
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
||| If any of the given indices are out of bounds, then `Nothing` is returned. ||| If any of the given indices are out of bounds, then `Nothing` is returned.
@ -324,7 +297,7 @@ indexRangeNB {s} rs arr with (viewShape arr)
export %inline export %inline
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a) (!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
arr !?.. rs = indexRangeNB rs arr arr !?.. rs = indexRangeNB rs arr
-}
||| Index the array using the given coordinates. ||| Index the array using the given coordinates.
||| WARNING: This function does not perform any bounds check on its inputs. ||| WARNING: This function does not perform any bounds check on its inputs.
@ -342,30 +315,21 @@ export %inline
(!#) : Array {rk} s a -> Vect rk Nat -> a (!#) : Array {rk} s a -> Vect rk Nat -> a
arr !# is = indexUnsafe is arr arr !# is = indexUnsafe is arr
{-
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
||| WARNING: This function does not perform any bounds check on its inputs. ||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety. ||| Misuse of this function can easily break memory safety.
export export %unsafe
indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
indexRangeUnsafe {s} rs arr with (viewShape arr) indexRangeUnsafe rs (MkArray rep @{rc} s arr) =
_ | Shape s = believe_me $ Array.indexRange (assertCRange s rs) (MkArray rep @{rc} s arr)
let ord = getOrder arr
sts = calcStrides ord s'
in MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList s rs)
where s' : Vect ? Nat
s' = newShape s rs
||| Index the array using the given range of coordinates, returning a new array. ||| Index the array using the given range of coordinates, returning a new array.
||| WARNING: This function does not perform any bounds check on its inputs. ||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety. ||| Misuse of this function can easily break memory safety.
||| |||
||| This is the operator form of `indexRangeUnsafe`. ||| This is the operator form of `indexRangeUnsafe`.
export %inline export %inline %unsafe
(!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a (!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
arr !#.. is = indexRangeUnsafe is arr arr !#.. is = indexRangeUnsafe is arr
@ -375,35 +339,42 @@ arr !#.. is = indexRangeUnsafe is arr
-- Operations on arrays -- Operations on arrays
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
||| Reshape the array into the given shape and reinterpret it according to
||| the given order.
|||
||| @ s' The shape to convert the array to
||| @ ord The order to reinterpret the array by
export export
reshape' : (s' : Vect rk' Nat) -> (ord : Order) -> Array {rk} s a -> mapArray' : (a -> a) -> Array s a -> Array s a
(0 ok : product s = product s') => Array s' a mapArray' f (MkArray rep _ arr) = MkArray rep _ (mapPrim f arr)
reshape' s' ord' arr = MkArray ord' (calcStrides ord' s') s' (getPrim arr)
export
mapArray : (a -> b) -> (arr : Array s a) -> RepConstraint (getRep arr) b => Array s b
mapArray f (MkArray rep _ arr) @{rc} = MkArray rep @{rc} _ (mapPrim f arr)
export
zipWithArray' : (a -> a -> a) -> Array s a -> Array s a -> Array s a
zipWithArray' {s} f a b with (viewShape a)
_ | Shape s = MkArray (mergeRep (getRep a) (getRep b))
@{mergeRepConstraint (getRepC a) (getRepC b)} s
$ PrimArray.fromFunctionNB @{mergeRepConstraint (getRepC a) (getRepC b)} _
(\is => f (a !# is) (b !# is))
export
zipWithArray : (a -> b -> c) -> (arr : Array s a) -> (arr' : Array s b) ->
RepConstraint (mergeRep (getRep arr) (getRep arr')) c => Array s c
zipWithArray {s} f a b @{rc} with (viewShape a)
_ | Shape s = MkArray (mergeRep (getRep a) (getRep b)) @{rc} s
$ PrimArray.fromFunctionNB _ (\is => f (a !# is) (b !# is))
||| Reshape the array into the given shape. ||| Reshape the array into the given shape.
||| |||
||| @ s' The shape to convert the array to ||| @ s' The shape to convert the array to
export export
reshape : (s' : Vect rk' Nat) -> Array {rk} s a -> reshape : (s' : Vect rk' Nat) -> (arr : Array {rk} s a) -> LinearRep (getRep arr) =>
(0 ok : product s = product s') => Array s' a (0 ok : product s = product s') => Array s' a
reshape s' arr = reshape' s' (getOrder arr) arr reshape s' (MkArray rep _ arr) = MkArray rep s' (PrimArray.reshape s' arr)
||| Change the internal order of the array's elements. ||| Change the internal order of the array's elements.
export export
reorder : Order -> Array s a -> Array s a convertRep : (rep : Rep) -> RepConstraint rep a => Array s a -> Array s a
reorder {s} ord' arr with (viewShape arr) convertRep rep (MkArray _ s arr) = MkArray rep s (PrimArray.convertRep arr)
_ | Shape s = let sts = calcStrides ord' s
in MkArray ord' sts _ (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
||| Resize the array to a new shape, preserving the coordinates of the original ||| Resize the array to a new shape, preserving the coordinates of the original
||| elements. New coordinates are filled with a default value. ||| elements. New coordinates are filled with a default value.
@ -412,7 +383,7 @@ reorder {s} ord' arr with (viewShape arr)
||| @ def The default value to fill the array with ||| @ def The default value to fill the array with
export export
resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
resize s' def arr = fromFunction' s' (getOrder arr) (fromMaybe def . (arr !?) . toNB) resize s' def arr = fromFunction {rep=getRep arr} @{getRepC arr} s' (fromMaybe def . (arr !?) . toNB)
||| Resize the array to a new shape, preserving the coordinates of the original ||| Resize the array to a new shape, preserving the coordinates of the original
||| elements. This function requires a proof that the new shape is strictly ||| elements. This function requires a proof that the new shape is strictly
@ -427,20 +398,11 @@ resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) =>
resizeLTE s' arr = resize s' (believe_me ()) arr resizeLTE s' arr = resize s' (believe_me ()) arr
||| List all of the values in an array in row-major order.
export
elements : Array {rk} s a -> Vect (product s) a
elements (MkArray _ sts sh p) =
let elems = map (flip index p . getLocation' sts) (getAllCoords' sh)
-- TODO: prove that the number of elements is `product s`
in assert_total $ case toVect (product sh) elems of Just v => v
||| List all of the values in an array along with their coordinates. ||| List all of the values in an array along with their coordinates.
export export
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a) enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
enumerateNB (MkArray _ sts sh p) = enumerateNB (MkArray _ s arr) =
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh) map (\is => (is, PrimArray.indexUnsafe is arr)) (getAllCoords' s)
||| List all of the values in an array along with their coordinates. ||| List all of the values in an array along with their coordinates.
export export
@ -448,6 +410,12 @@ enumerate : Array s a -> List (Coords s, a)
enumerate {s} arr with (viewShape arr) enumerate {s} arr with (viewShape arr)
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s) _ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
||| List all of the values in an array in row-major order.
export
elements : Array {rk} s a -> Vect (product s) a
elements (MkArray _ s arr) =
believe_me $ Vect.fromList $
map (flip PrimArray.indexUnsafe arr) (getAllCoords' s)
||| Join two arrays along a particular axis, e.g. combining two matrices ||| Join two arrays along a particular axis, e.g. combining two matrices
||| vertically or horizontally. All other axes of the arrays must have the ||| vertically or horizontally. All other axes of the arrays must have the
@ -457,16 +425,12 @@ enumerate {s} arr with (viewShape arr)
export export
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a -> concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
Array (updateAt axis (+d) s) a Array (updateAt axis (+d) s) a
concat axis a b = let sA = shape a concat {s,d} axis a b with (viewShape a, viewShape b)
sB = shape b _ | (Shape s, Shape (replaceAt axis d s)) =
dA = index axis sA believe_me $ Array.fromFunctionNB {rep=mergeRep (getRep a) (getRep b)}
dB = index axis sB @{mergeRepConstraint (getRepC a) (getRepC b)} (updateAt axis (+ index axis (shape b)) s)
s = replaceAt axis (dA + dB) sA (\is => let limit = index axis s
sts = calcStrides COrder s in if index axis is < limit then a !# is else b !# updateAt axis (`minus` limit) is)
ins = map (mapFst $ getLocation' sts . toNB) (enumerate a)
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNB) (enumerate b)
-- TODO: prove that the type-level shape and `s` are equivalent
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix. ||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix.
||| |||
@ -500,12 +464,11 @@ splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} ->
Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a) Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a)
splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is'))) splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is')))
||| Construct the transpose of an array by reversing the order of its axes. ||| Construct the transpose of an array by reversing the order of its axes.
export export
transpose : Array s a -> Array (reverse s) a transpose : Array s a -> Array (reverse s) a
transpose {s} arr with (viewShape arr) transpose {s} arr with (viewShape arr)
_ | Shape s = fromFunctionNB (reverse s) (\is => arr !# reverse is) _ | Shape s = fromFunctionNB _ (\is => arr !# reverse is)
||| Construct the transpose of an array by reversing the order of its axes. ||| Construct the transpose of an array by reversing the order of its axes.
||| |||
@ -551,41 +514,56 @@ permuteInAxis {s} axis p arr with (viewShape arr)
-- Implementations -- Implementations
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Most of these implementations apply the operation pointwise. If there are
-- multiple arrays involved with different orders, then all of the arrays are
-- reordered to match one.
export export
Zippable (Array s) where Zippable (Array s) where
zipWith {s} f a b with (viewShape a) zipWith {s} f a b with (viewShape a)
_ | Shape s = MkArray (getOrder a) (strides a) s $ _ | Shape s = MkArray (mergeRepNC (getRep a) (getRep b))
if getOrder a == getOrder b @{mergeNCRepConstraint} s
then unsafeZipWith f (getPrim a) (getPrim b) $ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b) (\is => f (a !# is) (b !# is))
zipWith3 {s} f a b c with (viewShape a) zipWith3 {s} f a b c with (viewShape a)
_ | Shape s = MkArray (getOrder a) (strides a) s $ _ | Shape s = MkArray (mergeRepNC (mergeRep (getRep a) (getRep b)) (getRep c))
if (getOrder a == getOrder b) && (getOrder b == getOrder c) @{mergeNCRepConstraint} s
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c) $ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b) (\is => f (a !# is) (b !# is) (c !# is))
(getPrim $ reorder (getOrder a) c)
unzipWith {s} f arr with (viewShape arr) unzipWith {s} f arr with (viewShape arr)
_ | Shape s = case unzipWith f (getPrim arr) of _ | Shape s =
(a, b) => (MkArray (getOrder arr) (strides arr) s a, let rep : Rep
MkArray (getOrder arr) (strides arr) s b) rep = forceRepNC $ getRep arr
in (MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => snd $ f (arr !# is)))
unzipWith3 {s} f arr with (viewShape arr) unzipWith3 {s} f arr with (viewShape arr)
_ | Shape s = case unzipWith3 f (getPrim arr) of _ | Shape s =
(a, b, c) => (MkArray (getOrder arr) (strides arr) s a, let rep : Rep
MkArray (getOrder arr) (strides arr) s b, rep = forceRepNC $ getRep arr
MkArray (getOrder arr) (strides arr) s c) in (MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ snd $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => snd $ snd $ f (arr !# is)))
export export
Functor (Array s) where Functor (Array s) where
map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr) map f (MkArray rep @{rc} s arr) = MkArray (forceRepNC rep) @{forceRepConstraint} s
(mapPrim @{forceRepConstraint} @{forceRepConstraint} f
$ convertRep @{rc} @{forceRepConstraint} arr)
export export
{s : _} -> Applicative (Array s) where {s : _} -> Applicative (Array s) where
@ -602,15 +580,18 @@ export
export export
Foldable (Array s) where Foldable (Array s) where
foldl f z = foldl f z . getPrim foldl f z (MkArray _ _ arr) = PrimArray.foldl f z arr
foldr f z = foldr f z . getPrim foldr f z (MkArray _ _ arr) = PrimArray.foldr f z arr
null arr = size arr == Z null (MkArray _ s _) = isZero (product s)
toList = toList . getPrim
export export
Traversable (Array s) where Traversable (Array s) where
traverse f (MkArray ord sts s arr) = traverse f (MkArray rep @{rc} s arr) =
map (MkArray ord sts s) (traverse f arr) map (MkArray (forceRepNC rep) @{forceRepConstraint} s)
(PrimArray.traverse {rep=forceRepNC rep}
@{%search} @{forceRepConstraint} @{forceRepConstraint} f
(PrimArray.convertRep @{rc} @{forceRepConstraint} arr))
export export
@ -619,13 +600,11 @@ Cast a b => Cast (Array s a) (Array s b) where
export export
Eq a => Eq (Array s a) where Eq a => Eq (Array s a) where
a == b = if getOrder a == getOrder b a == b = and $ zipWith (delay .: (==)) (convertRep D a) (convertRep D b)
then unsafeEq (getPrim a) (getPrim b)
else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b)
export export
Semigroup a => Semigroup (Array s a) where Semigroup a => Semigroup (Array s a) where
(<+>) = zipWith (<+>) (<+>) = zipWithArray' (<+>)
export export
{s : _} -> Monoid a => Monoid (Array s a) where {s : _} -> Monoid a => Monoid (Array s a) where
@ -635,36 +614,34 @@ export
-- were moved into its own interface, this constraint could be removed. -- were moved into its own interface, this constraint could be removed.
export export
{s : _} -> Num a => Num (Array s a) where {s : _} -> Num a => Num (Array s a) where
(+) = zipWith (+) (+) = zipWithArray' (+)
(*) = zipWith (*) (*) = zipWithArray' (*)
fromInteger = repeat s . fromInteger fromInteger = repeat s . fromInteger
export export
{s : _} -> Neg a => Neg (Array s a) where {s : _} -> Neg a => Neg (Array s a) where
negate = map negate negate = mapArray' negate
(-) = zipWith (-) (-) = zipWithArray' (-)
export export
{s : _} -> Fractional a => Fractional (Array s a) where {s : _} -> Fractional a => Fractional (Array s a) where
recip = map recip recip = mapArray' recip
(/) = zipWith (/) (/) = zipWithArray' (/)
export export
Num a => Mult a (Array {rk} s a) (Array s a) where Num a => Mult a (Array {rk} s a) (Array s a) where
(*.) x = map (*x) (*.) x = mapArray' (*x)
export export
Num a => Mult (Array {rk} s a) a (Array s a) where Num a => Mult (Array {rk} s a) a (Array s a) where
(*.) = flip (*.) (*.) = flip (*.)
export export
Show a => Show (Array s a) where Show a => Show (Array s a) where
showPrec d arr = let orderedElems = PrimArray.toList $ getPrim $ showPrec d arr = let orderedElems = toList $ elements arr
if getOrder arr == COrder then arr else reorder COrder arr
in showCon d "array " $ concat $ insertPunct (shape arr) $ map show orderedElems in showCon d "array " $ concat $ insertPunct (shape arr) $ map show orderedElems
where where
splitWindow : Nat -> List String -> List (List String) splitWindow : Nat -> List String -> List (List String)
@ -691,7 +668,7 @@ Show a => Show (Array s a) where
||| Linearly interpolate between two arrays. ||| Linearly interpolate between two arrays.
export export
lerp : Neg a => a -> Array s a -> Array s a -> Array s a lerp : Neg a => a -> Array s a -> Array s a -> Array s a
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t) lerp t a b = zipWithArray' (+) (a *. (1 - t)) (b *. t)
||| Calculate the square of an array's Eulidean norm. ||| Calculate the square of an array's Eulidean norm.
@ -715,4 +692,3 @@ normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
export export
pnorm : (p : Double) -> Array s Double -> Double pnorm : (p : Double) -> Array s Double -> Double
pnorm p = (`pow` recip p) . sum . map (`pow` p) pnorm p = (`pow` recip p) . sum . map (`pow` p)
-}

View file

@ -43,6 +43,11 @@ toNB : Coords {rk} s -> Vect rk Nat
toNB [] = [] toNB [] = []
toNB (i :: is) = finToNat i :: toNB is toNB (i :: is) = finToNat i :: toNB is
export
validateCoords : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s)
validateCoords [] [] = Just []
validateCoords (d :: s) (i :: is) = (::) <$> natToFin i d <*> validateCoords s is
namespace Strict namespace Strict
public export public export
@ -155,6 +160,44 @@ namespace Strict
namespace NB namespace NB
export export
validateCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> Maybe (CoordsRange s)
validateCRange [] [] = Just []
validateCRange (d :: s) (r :: rs) = [| validate' d r :: validateCRange s rs |]
where
validate' : (n : Nat) -> CRangeNB -> Maybe (CRange n)
validate' n All = Just All
validate' n (StartBound x) =
case isLTE x n of
Yes _ => Just (StartBound (natToFinLT x))
_ => Nothing
validate' n (EndBound x) =
case isLTE x n of
Yes _ => Just (EndBound (natToFinLT x))
_ => Nothing
validate' n (Bounds x y) =
case (isLTE x n, isLTE y n) of
(Yes _, Yes _) => Just (Bounds (natToFinLT x) (natToFinLT y))
_ => Nothing
validate' n (Indices xs) = Indices <$> traverse
(\x => case isLT x n of
Yes _ => Just (natToFinLT x)
No _ => Nothing) xs
validate' n (Filter f) = Just (Filter (f . finToNat))
export %unsafe
assertCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> CoordsRange s
assertCRange [] [] = []
assertCRange (d :: s) (r :: rs) = assert' d r :: assertCRange s rs
where
assert' : (n : Nat) -> CRangeNB -> CRange n
assert' n All = All
assert' n (StartBound x) = StartBound (believe_me x)
assert' n (EndBound x) = EndBound (believe_me x)
assert' n (Bounds x y) = Bounds (believe_me x) (believe_me y)
assert' n (Indices xs) = Indices (believe_me <$> xs)
assert' n (Filter f) = Filter (f . finToNat)
public export
cRangeNBToList : Nat -> CRangeNB -> List Nat cRangeNBToList : Nat -> CRangeNB -> List Nat
cRangeNBToList s All = range 0 s cRangeNBToList s All = range 0 s
cRangeNBToList s (StartBound x) = range x s cRangeNBToList s (StartBound x) = range x s
@ -163,38 +206,11 @@ namespace NB
cRangeNBToList s (Indices xs) = nub xs cRangeNBToList s (Indices xs) = nub xs
cRangeNBToList s (Filter p) = filter p $ range 0 s cRangeNBToList s (Filter p) = filter p $ range 0 s
export
validateCRangeNB : Nat -> CRangeNB -> Bool
validateCRangeNB s All = True
validateCRangeNB s (StartBound x) = x < s
validateCRangeNB s (EndBound x) = x <= s
validateCRangeNB s (Bounds x y) = x < s && y <= s
validateCRangeNB s (Indices xs) = all (<s) xs
validateCRangeNB s (Filter p) = True
||| Calculate the new shape given by a coordinate range. ||| Calculate the new shape given by a coordinate range.
export public export
newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat
newShape = zipWith (length .: cRangeNBToList) newShape = zipWith (length .: cRangeNBToList)
export
validateShape : Vect rk Nat -> Vect rk CRangeNB -> Bool
validateShape = all id .: zipWith validateCRangeNB
export
getNewPos : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat -> Vect rk Nat
getNewPos = zipWith3 (\d,r,i => assert_total $
case findIndex (==i) (cRangeNBToList d r) of Just x => cast x)
export
getCoordsList : Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat, Vect rk Nat)
getCoordsList s rs = map (\is => (is, getNewPos s rs is)) $ go s rs
where
go : {0 rk : _} -> Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat)
go [] [] = [[]]
go (d :: s) (r :: rs) = [| cRangeNBToList d r :: go s rs |]
export export
getAllCoords' : Vect rk Nat -> List (Vect rk Nat) getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
getAllCoords' = traverse (\case Z => []; S n => [0..n]) getAllCoords' = traverse (\case Z => []; S n => [0..n])

View file

@ -1,6 +1,8 @@
module Data.NumIdr.Array.Rep module Data.NumIdr.Array.Rep
import Data.Vect import Data.Vect
import Data.Bits
import Data.Buffer
%default total %default total
@ -23,6 +25,8 @@ data Order : Type where
||| fashion (the first axis is the least significant). ||| fashion (the first axis is the least significant).
FOrder : Order FOrder : Order
%name Order o
public export public export
Eq Order where Eq Order where
@ -52,6 +56,9 @@ data Rep : Type where
Linked : Rep Linked : Rep
Delayed : Rep Delayed : Rep
%name Rep rep
public export public export
B : Rep B : Rep
B = Boxed COrder B = Boxed COrder
@ -65,7 +72,231 @@ D : Rep
D = Delayed D = Delayed
public export
Eq Rep where
Bytes o == Bytes o' = o == o'
Boxed o == Boxed o' = o == o'
Linked == Linked = True
Delayed == Delayed = True
_ == _ = False
public export public export
data LinearRep : Rep -> Type where data LinearRep : Rep -> Type where
BytesIsL : LinearRep (Bytes o) BytesIsL : LinearRep (Bytes o)
BoxedIsL : LinearRep (Boxed o) BoxedIsL : LinearRep (Boxed o)
public export
forceRepNC : Rep -> Rep
forceRepNC (Bytes o) = Boxed o
forceRepNC r = r
public export
mergeRep : Rep -> Rep -> Rep
mergeRep r r' = if r == Delayed || r' == Delayed then Delayed else r
public export
mergeRepNC : Rep -> Rep -> Rep
mergeRepNC r r' =
if r == Delayed || r' == Delayed then Delayed
else case r of
Bytes _ => case r' of
Bytes o => Boxed o
_ => r'
_ => r
public export
data NoConstraintRep : Rep -> Type where
BoxedNC : NoConstraintRep (Boxed o)
LinkedNC : NoConstraintRep Linked
DelayedNC : NoConstraintRep Delayed
public export
mergeRepNCCorrect : (r, r' : Rep) -> NoConstraintRep (mergeRepNC r r')
mergeRepNCCorrect Delayed _ = DelayedNC
mergeRepNCCorrect (Bytes y) (Bytes x) = BoxedNC
mergeRepNCCorrect (Boxed y) (Bytes x) = BoxedNC
mergeRepNCCorrect Linked (Bytes x) = LinkedNC
mergeRepNCCorrect (Bytes y) (Boxed x) = BoxedNC
mergeRepNCCorrect (Boxed y) (Boxed x) = BoxedNC
mergeRepNCCorrect Linked (Boxed x) = LinkedNC
mergeRepNCCorrect (Bytes x) Linked = LinkedNC
mergeRepNCCorrect (Boxed x) Linked = BoxedNC
mergeRepNCCorrect Linked Linked = LinkedNC
mergeRepNCCorrect (Bytes x) Delayed = DelayedNC
mergeRepNCCorrect (Boxed x) Delayed = DelayedNC
mergeRepNCCorrect Linked Delayed = DelayedNC
public export
forceRepNCCorrect : (r : Rep) -> NoConstraintRep (forceRepNC r)
forceRepNCCorrect (Bytes x) = BoxedNC
forceRepNCCorrect (Boxed x) = BoxedNC
forceRepNCCorrect Linked = LinkedNC
forceRepNCCorrect Delayed = DelayedNC
--------------------------------------------------------------------------------
-- Byte representations of elements
--------------------------------------------------------------------------------
public export
interface ByteRep a where
bytes : Nat
toBytes : a -> Vect bytes Bits8
fromBytes : Vect bytes Bits8 -> a
export
ByteRep Bits8 where
bytes = 1
toBytes x = [x]
fromBytes [x] = x
export
ByteRep Bits16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Bits32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Bits64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int8 where
bytes = 1
toBytes x = [cast x]
fromBytes [x] = cast x
export
ByteRep Int16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Int32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Int64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Bool where
bytes = 1
toBytes b = [if b then 1 else 0]
fromBytes [x] = x /= 0
export
ByteRep a => ByteRep b => ByteRep (a, b) where
bytes = bytes {a} + bytes {a=b}
toBytes (x,y) = toBytes x ++ toBytes y
fromBytes = bimap fromBytes fromBytes . splitAt _
export
{n : _} -> ByteRep a => ByteRep (Vect n a) where
bytes = n * bytes {a}
toBytes xs = concat $ map toBytes xs
fromBytes {n = 0} bs = []
fromBytes {n = S n} bs =
let (bs1, bs2) = splitAt _ bs
in fromBytes bs1 :: fromBytes bs2
public export
RepConstraint : Rep -> Type -> Type
RepConstraint (Bytes _) a = ByteRep a
RepConstraint (Boxed _) a = ()
RepConstraint Linked a = ()
RepConstraint Delayed a = ()

View file

@ -2,23 +2,39 @@ module Data.NumIdr.PrimArray
import Data.Buffer import Data.Buffer
import Data.Vect import Data.Vect
import Data.Fin
import Data.NP import Data.NP
import Data.NumIdr.Array.Rep import Data.NumIdr.Array.Rep
import Data.NumIdr.Array.Coords import Data.NumIdr.Array.Coords
import Data.NumIdr.PrimArray.Bytes import public Data.NumIdr.PrimArray.Bytes
import Data.NumIdr.PrimArray.Boxed import public Data.NumIdr.PrimArray.Boxed
import Data.NumIdr.PrimArray.Linked import public Data.NumIdr.PrimArray.Linked
import Data.NumIdr.PrimArray.Delayed import public Data.NumIdr.PrimArray.Delayed
%default total %default total
public export
RepConstraint : Rep -> Type -> Type noRCCorrect : (rep : Rep) -> NoConstraintRep rep -> RepConstraint rep a
RepConstraint (Bytes _) a = ByteRep a noRCCorrect (Boxed _) BoxedNC = ()
RepConstraint (Boxed _) a = () noRCCorrect Linked LinkedNC = ()
RepConstraint Linked a = () noRCCorrect Delayed DelayedNC = ()
RepConstraint Delayed a = ()
export
mergeRepConstraint : {r, r' : Rep} -> RepConstraint r a -> RepConstraint r' a ->
RepConstraint (mergeRep r r') a
mergeRepConstraint {r,r'} a b with (r == Delayed || r' == Delayed)
_ | True = ()
_ | False = a
export
mergeNCRepConstraint : {r, r' : Rep} -> RepConstraint (mergeRepNC r r') a
mergeNCRepConstraint = noRCCorrect _ (mergeRepNCCorrect r r')
export
forceRepConstraint : {r : Rep} -> RepConstraint (forceRepNC r) a
forceRepConstraint = noRCCorrect _ (forceRepNCCorrect r)
export export
@ -29,6 +45,11 @@ PrimArray Linked s a = Vects s a
PrimArray Delayed s a = Coords s -> a PrimArray Delayed s a = Coords s -> a
export
reshape : {rep : Rep} -> LinearRep rep => RepConstraint rep a => (s' : Vect rk Nat) -> PrimArray rep s a -> PrimArray rep s' a
reshape {rep = Bytes o} s' (MkPABytes _ arr) = MkPABytes (calcStrides o s') arr
reshape {rep = Boxed o} s' (MkPABoxed _ arr) = MkPABoxed (calcStrides o s') arr
export export
constant : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> a -> PrimArray rep s a constant : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> a -> PrimArray rep s a
constant {rep = Bytes o} = Bytes.constant constant {rep = Bytes o} = Bytes.constant
@ -66,25 +87,68 @@ index {rep = Linked} is arr = Linked.index is arr
index {rep = Delayed} is arr = arr is index {rep = Delayed} is arr = arr is
export export
indexNB : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> Maybe a indexUpdate : {rep,s : _} -> RepConstraint rep a => Coords s -> (a -> a) -> PrimArray rep s a -> PrimArray rep s a
indexNB {rep = Bytes o} is arr@(MkPABytes sts _) = indexUpdate {rep = Bytes o} @{rc} is f arr@(MkPABytes sts _) =
if and (zipWith (delay .: (<)) is s) let i = getLocation sts is
then Just $ index (getLocation' sts is) arr in create @{rc} s (\n => let x = index n arr in if n == i then f x else x)
else Nothing indexUpdate {rep = Boxed o} is f arr@(MkPABoxed sts _) =
indexNB {rep = Boxed o} is arr@(MkPABoxed sts _) = let i = getLocation sts is
if and (zipWith (delay .: (<)) is s) in create s (\n => let x = index n arr in if n == i then f x else x)
then Just $ index (getLocation' sts is) arr indexUpdate {rep = Linked} is f arr = update is f arr
else Nothing indexUpdate {rep = Delayed} is f arr =
indexNB {rep = Linked} is arr = (`Linked.index` arr) <$> checkRange s is \is' =>
indexNB {rep = Delayed} is arr = arr <$> checkRange s is let x = arr is'
in if eqNP is is' then f x else x
export
indexRange : {rep,s : _} -> RepConstraint rep a => (rs : CoordsRange s) -> PrimArray rep s a
-> PrimArray rep (newShape rs) a
indexRange {rep = Bytes o} @{rc} rs arr@(MkPABytes sts _) =
let s' : Vect ? Nat
s' = newShape rs
sts' := calcStrides o s'
in unsafeFromIns @{rc} (newShape rs) $
map (\(is,is') => (getLocation' sts' is',
index (getLocation' sts is) arr))
$ getCoordsList rs
indexRange {rep = Boxed o} rs arr@(MkPABoxed sts _) =
let s' : Vect ? Nat
s' = newShape rs
sts' := calcStrides o s'
in unsafeFromIns (newShape rs) $
map (\(is,is') => (getLocation' sts' is',
index (getLocation' sts is) arr))
$ getCoordsList rs
indexRange {rep = Linked} rs arr = Linked.indexRange rs arr
indexRange {rep = Delayed} rs arr = Delayed.indexRange rs arr
export
indexSetRange : {rep,s : _} -> RepConstraint rep a => (rs : CoordsRange s)
-> PrimArray rep (newShape rs) a -> PrimArray rep s a -> PrimArray rep s a
indexSetRange {rep = Bytes o} @{rc} rs rpl@(MkPABytes sts' _) arr@(MkPABytes sts _) =
unsafePerformIO $ do
let arr' = copy arr
unsafeDoIns @{rc} (map (\(is,is') =>
(getLocation' sts is, index (getLocation' sts' is') rpl))
$ getCoordsList rs) arr'
pure arr'
indexSetRange {rep = Boxed o} rs rpl@(MkPABoxed sts' _) arr@(MkPABoxed sts _) =
unsafePerformIO $ do
let arr' = copy arr
unsafeDoIns (map (\(is,is') =>
(getLocation' sts is, index (getLocation' sts' is') rpl))
$ getCoordsList rs) arr'
pure arr'
indexSetRange {rep = Linked} rs rpl arr = Linked.indexSetRange rs rpl arr
indexSetRange {rep = Delayed} rs rpl arr = Delayed.indexSetRange rs rpl arr
export export
indexUnsafe : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> a indexUnsafe : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> a
indexUnsafe {rep = Bytes o} is arr@(MkPABytes sts _) = index (getLocation' sts is) arr indexUnsafe {rep = Bytes o} is arr@(MkPABytes sts _) = index (getLocation' sts is) arr
indexUnsafe {rep = Boxed o} is arr@(MkPABoxed sts _) = index (getLocation' sts is) arr indexUnsafe {rep = Boxed o} is arr@(MkPABoxed sts _) = index (getLocation' sts is) arr
indexUnsafe {rep = Linked} is arr = assert_total $ case checkRange s is of indexUnsafe {rep = Linked} is arr = assert_total $ case validateCoords s is of
Just is' => Linked.index is' arr Just is' => Linked.index is' arr
indexUnsafe {rep = Delayed} is arr = assert_total $ case checkRange s is of indexUnsafe {rep = Delayed} is arr = assert_total $ case validateCoords s is of
Just is' => arr is' Just is' => arr is'
export export
@ -100,3 +164,41 @@ convertRep {r1, r2} arr = fromFunction s (\is => PrimArray.index is arr)
export export
fromVects : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> Vects s a -> PrimArray rep s a fromVects : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> Vects s a -> PrimArray rep s a
fromVects s v = convertRep {r1=Linked} v fromVects s v = convertRep {r1=Linked} v
export
fromList : {rep : Rep} -> LinearRep rep => RepConstraint rep a => (s : Vect rk Nat) -> List a -> PrimArray rep s a
fromList {rep = (Boxed x)} = Boxed.fromList
fromList {rep = (Bytes x)} = Bytes.fromList
export
mapPrim : {rep,s : _} -> RepConstraint rep a => RepConstraint rep b =>
(a -> b) -> PrimArray rep s a -> PrimArray rep s b
mapPrim {rep = (Bytes x)} @{_} @{rc} f arr = Bytes.create @{rc} _ (\i => f $ Bytes.index i arr)
mapPrim {rep = (Boxed x)} f arr = Boxed.create _ (\i => f $ Boxed.index i arr)
mapPrim {rep = Linked} f arr = Linked.mapVects f arr
mapPrim {rep = Delayed} f arr = f . arr
export
foldl : {rep,s : _} -> RepConstraint rep a => (b -> a -> b) -> b -> PrimArray rep s a -> b
foldl {rep = Bytes _} = Bytes.foldl
foldl {rep = Boxed _} = Boxed.foldl
foldl {rep = Linked} = Linked.foldl
foldl {rep = Delayed} = \f,z => Boxed.foldl f z . convertRep {r1=Delayed,r2=B}
export
foldr : {rep,s : _} -> RepConstraint rep a => (a -> b -> b) -> b -> PrimArray rep s a -> b
foldr {rep = Bytes _} = Bytes.foldr
foldr {rep = Boxed _} = Boxed.foldr
foldr {rep = Linked} = Linked.foldr
foldr {rep = Delayed} = \f,z => Boxed.foldr f z . convertRep {r1=Delayed,r2=B}
export
traverse : {rep,s : _} -> Applicative f => RepConstraint rep a => RepConstraint rep b =>
(a -> f b) -> PrimArray rep s a -> f (PrimArray rep s b)
traverse {rep = Bytes o} = Bytes.traverse
traverse {rep = Boxed o} = Boxed.traverse
traverse {rep = Linked} = Linked.traverse
traverse {rep = Delayed} = \f => map (convertRep @{()} @{()} {r1=B,r2=Delayed}) .
Boxed.traverse f .
convertRep @{()} @{()} {r1=Delayed,r2=B}

View file

@ -2,6 +2,7 @@ module Data.NumIdr.PrimArray.Boxed
import Data.Buffer import Data.Buffer
import System import System
import Data.IORef
import Data.IOArray.Prims import Data.IOArray.Prims
import Data.Vect import Data.Vect
import Data.NumIdr.Array.Rep import Data.NumIdr.Array.Rep
@ -108,6 +109,34 @@ fromList s xs = arrayAction s $ go xs 0
go (x :: xs) i buf = arrayDataSet i x buf >> go xs (S i) buf go (x :: xs) i buf = arrayDataSet i x buf >> go xs (S i) buf
export
foldl : {s : _} -> (b -> a -> b) -> b -> PrimArrayBoxed o s a -> b
foldl f z (MkPABoxed sts arr) =
if product s == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [0..pred $ product s] $ \n => do
x <- readIORef ref
y <- arrayDataGet n arr
writeIORef ref (f x y)
readIORef ref
export
foldr : {s : _} -> (a -> b -> b) -> b -> PrimArrayBoxed o s a -> b
foldr f z (MkPABoxed sts arr) =
if product s == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [pred $ product s..0] $ \n => do
x <- arrayDataGet n arr
y <- readIORef ref
writeIORef ref (f x y)
readIORef ref
export
traverse : {o,s : _} -> Applicative f => (a -> f b) -> PrimArrayBoxed o s a -> f (PrimArrayBoxed o s b)
traverse f = map (Boxed.fromList _) . traverse f . foldr (::) []
{- {-
export export
updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a
@ -145,10 +174,6 @@ fromList xs = create (length xs)
fromJust : Maybe a -> a fromJust : Maybe a -> a
fromJust (Just x) = x fromJust (Just x) = x
||| Map a function over a primitive array.
export
map : (a -> b) -> PrimArrayBoxed a -> PrimArrayBoxed b
map f arr = create (length arr) (\n => f $ index n arr)
export export

View file

@ -2,8 +2,6 @@ module Data.NumIdr.PrimArray.Bytes
import System import System
import Data.IORef import Data.IORef
import Data.Bits
import Data.So
import Data.Buffer import Data.Buffer
import Data.Vect import Data.Vect
import Data.NumIdr.Array.Coords import Data.NumIdr.Array.Coords
@ -12,160 +10,6 @@ import Data.NumIdr.Array.Rep
%default total %default total
public export
interface ByteRep a where
bytes : Nat
toBytes : a -> Vect bytes Bits8
fromBytes : Vect bytes Bits8 -> a
export
ByteRep Bits8 where
bytes = 1
toBytes x = [x]
fromBytes [x] = x
export
ByteRep Bits16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Bits32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Bits64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int8 where
bytes = 1
toBytes x = [cast x]
fromBytes [x] = cast x
export
ByteRep Int16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Int32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Int64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Bool where
bytes = 1
toBytes b = [if b then 1 else 0]
fromBytes [x] = x /= 0
export
ByteRep a => ByteRep b => ByteRep (a, b) where
bytes = bytes {a} + bytes {a=b}
toBytes (x,y) = toBytes x ++ toBytes y
fromBytes = bimap fromBytes fromBytes . splitAt _
export
{n : _} -> ByteRep a => ByteRep (Vect n a) where
bytes = n * bytes {a}
toBytes xs = concat $ map toBytes xs
fromBytes {n = 0} bs = []
fromBytes {n = S n} bs =
let (bs1, bs2) = splitAt _ bs
in fromBytes bs1 :: fromBytes bs2
export export
readBuffer : ByteRep a => (i : Nat) -> Buffer -> IO a readBuffer : ByteRep a => (i : Nat) -> Buffer -> IO a
readBuffer i buf = do readBuffer i buf = do
@ -226,6 +70,15 @@ export
index : ByteRep a => Nat -> PrimArrayBytes o s -> a index : ByteRep a => Nat -> PrimArrayBytes o s -> a
index i (MkPABytes _ buf) = unsafePerformIO $ readBuffer i buf index i (MkPABytes _ buf) = unsafePerformIO $ readBuffer i buf
export
copy : {o,s : _} -> PrimArrayBytes o s -> PrimArrayBytes o s
copy (MkPABytes sts buf) =
MkPABytes sts (unsafePerformIO $ do
Just buf' <- newBuffer !(rawSize buf)
| Nothing => die "Cannot create array"
copyData buf 0 !(rawSize buf) buf' 0
pure buf)
export export
reorder : {o,o',s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBytes o' s reorder : {o,o',s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBytes o' s
reorder @{br} arr@(MkPABytes sts buf) = reorder @{br} arr@(MkPABytes sts buf) =
@ -242,3 +95,31 @@ fromList @{br} s xs = bufferAction @{br} s $ go xs 0
go : List a -> Nat -> Buffer -> IO () go : List a -> Nat -> Buffer -> IO ()
go [] _ buf = pure () go [] _ buf = pure ()
go (x :: xs) i buf = writeBuffer i x buf >> go xs (S i) buf go (x :: xs) i buf = writeBuffer i x buf >> go xs (S i) buf
export
foldl : {s : _} -> ByteRep a => (b -> a -> b) -> b -> PrimArrayBytes o s -> b
foldl f z (MkPABytes sts buf) =
if product s == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [0..pred $ product s] $ \n => do
x <- readIORef ref
y <- readBuffer n buf
writeIORef ref (f x y)
readIORef ref
export
foldr : {s : _} -> ByteRep a => (a -> b -> b) -> b -> PrimArrayBytes o s -> b
foldr f z (MkPABytes sts buf) =
if product s == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [pred $ product s..0] $ \n => do
x <- readBuffer n buf
y <- readIORef ref
writeIORef ref (f x y)
readIORef ref
export
traverse : {o,s : _} -> ByteRep a => ByteRep b => Applicative f => (a -> f b) -> PrimArrayBytes o s -> f (PrimArrayBytes o s)
traverse f = map (Bytes.fromList _) . traverse f . foldr (::) []

View file

@ -1,5 +1,6 @@
module Data.NumIdr.PrimArray.Delayed module Data.NumIdr.PrimArray.Delayed
import Data.List
import Data.Vect import Data.Vect
import Data.NP import Data.NP
import Data.NumIdr.Array.Rep import Data.NumIdr.Array.Rep
@ -18,6 +19,20 @@ constant s x _ = x
export export
checkRange : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s) indexRange : {s : _} -> (rs : CoordsRange s) -> PrimArrayDelayed s a -> PrimArrayDelayed (newShape rs) a
checkRange [] [] = Just [] indexRange [] v = v
checkRange (d :: s) (i :: is) = (::) <$> natToFin i d <*> checkRange s is indexRange (r :: rs) v with (cRangeToList r)
_ | Left i = indexRange rs (\is' => v (believe_me i :: is'))
_ | Right is = \(i::is') => indexRange rs (\is'' => v (believe_me (Vect.index i (fromList is)) :: is'')) is'
export
indexSetRange : {s : _} -> (rs : CoordsRange s) -> PrimArrayDelayed (newShape rs) a
-> PrimArrayDelayed s a -> PrimArrayDelayed s a
indexSetRange {s=[]} [] rv _ = rv
indexSetRange {s=_::_} (r :: rs) rv v with (cRangeToList r)
_ | Left i = \(i'::is) => if i == finToNat i'
then indexSetRange rs rv (v . (i'::)) is
else v (i'::is)
_ | Right is = \(i'::is') => case findIndex (== finToNat i') is of
Just x => indexSetRange rs (rv . (x::)) (v . (i'::)) is'
Nothing => v (i'::is')

View file

@ -1,6 +1,7 @@
module Data.NumIdr.PrimArray.Linked module Data.NumIdr.PrimArray.Linked
import Data.Vect import Data.Vect
import Data.DPair
import Data.NP import Data.NP
import Data.NumIdr.Array.Rep import Data.NumIdr.Array.Rep
import Data.NumIdr.Array.Coords import Data.NumIdr.Array.Coords
@ -18,6 +19,25 @@ index : Coords s -> Vects s a -> a
index [] x = x index [] x = x
index (c :: cs) xs = index cs (index c xs) index (c :: cs) xs = index cs (index c xs)
export
update : Coords s -> (a -> a) -> Vects s a -> Vects s a
update [] f v = f v
update (i :: is) f v = updateAt i (update is f) v
export
indexRange : {s : _} -> (rs : CoordsRange s) -> Vects s a -> Vects (newShape rs) a
indexRange [] v = v
indexRange (r :: rs) v with (cRangeToList r)
_ | Left i = indexRange rs (Vect.index (believe_me i) v)
_ | Right is = believe_me $ map (\i => indexRange rs (Vect.index (believe_me i) v)) is
export
indexSetRange : {s : _} -> (rs : CoordsRange s) -> Vects (newShape rs) a -> Vects s a -> Vects s a
indexSetRange {s=[]} [] rv _ = rv
indexSetRange {s=_::_} (r :: rs) rv v with (cRangeToList r)
_ | Left i = updateAt (believe_me i) (indexSetRange rs rv) v
_ | Right is = foldl (\v,i => updateAt (believe_me i) (indexSetRange rs (Vect.index (believe_me i) rv)) v) v is
export export
fromFunction : {s : _} -> (Coords s -> a) -> Vects s a fromFunction : {s : _} -> (Coords s -> a) -> Vects s a
@ -32,3 +52,23 @@ fromFunctionNB {s = d :: s} f = tabulate' {n=d} (\i => fromFunctionNB (f . (i::)
tabulate' : forall a. {n : _} -> (Nat -> a) -> Vect n a tabulate' : forall a. {n : _} -> (Nat -> a) -> Vect n a
tabulate' {n=Z} f = [] tabulate' {n=Z} f = []
tabulate' {n=S _} f = f Z :: tabulate' (f . S) tabulate' {n=S _} f = f Z :: tabulate' (f . S)
export
mapVects : {s : _} -> (a -> b) -> Vects s a -> Vects s b
mapVects {s = []} f x = f x
mapVects {s = _ :: _} f v = Prelude.map (mapVects f) v
export
foldl : {s : _} -> (b -> a -> b) -> b -> Vects s a -> b
foldl {s=[]} f z v = f z v
foldl {s=_::_} f z v = Prelude.foldl (Linked.foldl f) z v
export
foldr : {s : _} -> (a -> b -> b) -> b -> Vects s a -> b
foldr {s=[]} f z v = f v z
foldr {s=_::_} f z v = Prelude.foldr (flip $ Linked.foldr f) z v
export
traverse : {s : _} -> Applicative f => (a -> f b) -> Vects s a -> f (Vects s b)
traverse {s=[]} f v = f v
traverse {s=_::_} f v = Prelude.traverse (Linked.traverse f) v