Revamp all of the standard array API
This commit is contained in:
parent
ccb689af42
commit
a0068469c5
|
@ -26,6 +26,11 @@ public export
|
||||||
elems {all=[_]} [x] = show x
|
elems {all=[_]} [x] = show x
|
||||||
elems {all=_::_} (x :: xs) = show x ++ ", " ++ elems xs
|
elems {all=_::_} (x :: xs) = show x ++ ", " ++ elems xs
|
||||||
|
|
||||||
|
-- use this when idris can't figure out the constraint above
|
||||||
|
public export
|
||||||
|
eqNP : forall f. (forall t. Eq (f t)) => (x, y : NP f ts) -> Bool
|
||||||
|
eqNP [] [] = True
|
||||||
|
eqNP (x :: xs) (y :: ys) = x == y && eqNP xs ys
|
||||||
|
|
||||||
|
|
||||||
public export
|
public export
|
||||||
|
|
|
@ -3,4 +3,4 @@ module Data.NumIdr.Array
|
||||||
import public Data.NP
|
import public Data.NP
|
||||||
import public Data.NumIdr.Array.Array
|
import public Data.NumIdr.Array.Array
|
||||||
import public Data.NumIdr.Array.Coords
|
import public Data.NumIdr.Array.Coords
|
||||||
import public Data.NumIdr.Array.Order
|
import public Data.NumIdr.Array.Rep
|
||||||
|
|
|
@ -112,7 +112,7 @@ viewShape arr = rewrite shapeEq arr in
|
||||||
export
|
export
|
||||||
repeat : {default B rep : Rep} -> RepConstraint rep a =>
|
repeat : {default B rep : Rep} -> RepConstraint rep a =>
|
||||||
(s : Vect rk Nat) -> a -> Array s a
|
(s : Vect rk Nat) -> a -> Array s a
|
||||||
repeat s x = MkArray rep s (constant s x)
|
repeat s x = MkArray rep s (PrimArray.constant s x)
|
||||||
|
|
||||||
||| Create an array filled with zeros.
|
||| Create an array filled with zeros.
|
||||||
|||
|
|||
|
||||||
|
@ -138,9 +138,9 @@ ones {rep} s = repeat {rep} s 1
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ rep The internal representation of the constructed array
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
fromVect : {default B rep : Rep} -> RepConstraint rep a =>
|
fromVect : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
|
||||||
(s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
(s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
||||||
fromVect s v = ?fv
|
fromVect {rep} s v = MkArray rep s (PrimArray.fromList {rep} s $ toList v)
|
||||||
|
|
||||||
|
|
||||||
||| Create an array by taking values from a stream.
|
||| Create an array by taking values from a stream.
|
||||||
|
@ -148,7 +148,7 @@ fromVect s v = ?fv
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ rep The internal representation of the constructed array
|
||| @ rep The internal representation of the constructed array
|
||||||
export
|
export
|
||||||
fromStream : {default B rep : Rep} -> RepConstraint rep a =>
|
fromStream : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
|
||||||
(s : Vect rk Nat) -> Stream a -> Array s a
|
(s : Vect rk Nat) -> Stream a -> Array s a
|
||||||
fromStream {rep} s str = fromVect {rep} s (take _ str)
|
fromStream {rep} s str = fromVect {rep} s (take _ str)
|
||||||
|
|
||||||
|
@ -211,12 +211,12 @@ index is (MkArray _ _ arr) = PrimArray.index is arr
|
||||||
export %inline
|
export %inline
|
||||||
(!!) : Array s a -> Coords s -> a
|
(!!) : Array s a -> Coords s -> a
|
||||||
arr !! is = index is arr
|
arr !! is = index is arr
|
||||||
{-
|
|
||||||
||| Update the entry at the given coordinates using the function.
|
||| Update the entry at the given coordinates using the function.
|
||||||
export
|
export
|
||||||
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
||||||
indexUpdate is f (MkArray ord sts s arr) =
|
indexUpdate is f (MkArray rep @{rc} s arr) =
|
||||||
MkArray ord sts s (updateAt (getLocation sts is) f arr)
|
MkArray rep @{rc} s (indexUpdate @{rc} is f arr)
|
||||||
|
|
||||||
||| Set the entry at the given coordinates to the given value.
|
||| Set the entry at the given coordinates to the given value.
|
||||||
export
|
export
|
||||||
|
@ -227,17 +227,8 @@ indexSet is = indexUpdate is . const
|
||||||
||| Index the array using the given range of coordinates, returning a new array.
|
||| Index the array using the given range of coordinates, returning a new array.
|
||||||
export
|
export
|
||||||
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
||||||
indexRange {s} rs arr with (viewShape arr)
|
indexRange rs (MkArray rep @{rc} s arr) =
|
||||||
_ | Shape s =
|
MkArray rep @{rc} _ (PrimArray.indexRange @{rc} rs arr)
|
||||||
let ord = getOrder arr
|
|
||||||
sts = calcStrides ord s'
|
|
||||||
in MkArray ord sts s'
|
|
||||||
(unsafeFromIns (product s') $
|
|
||||||
map (\(is,is') => (getLocation' sts is',
|
|
||||||
index (getLocation' (strides arr) is) (getPrim arr)))
|
|
||||||
$ getCoordsList rs)
|
|
||||||
where s' : Vect ? Nat
|
|
||||||
s' = newShape rs
|
|
||||||
|
|
||||||
||| Index the array using the given range of coordinates, returning a new array.
|
||| Index the array using the given range of coordinates, returning a new array.
|
||||||
|||
|
|||
|
||||||
|
@ -250,13 +241,8 @@ arr !!.. rs = indexRange rs arr
|
||||||
export
|
export
|
||||||
indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
|
indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
|
||||||
Array s a -> Array s a
|
Array s a -> Array s a
|
||||||
indexSetRange rs rpl (MkArray ord sts s arr) =
|
indexSetRange rs (MkArray _ _ rpl) (MkArray rep s arr) =
|
||||||
MkArray ord sts s (unsafePerformIO $ do
|
MkArray rep s (PrimArray.indexSetRange {rep} rs (convertRep rpl) arr)
|
||||||
let arr' = copy arr
|
|
||||||
unsafeDoIns (map (\(is,is') =>
|
|
||||||
(getLocation' sts is, index (getLocation' (strides rpl) is') (getPrim rpl)))
|
|
||||||
$ getCoordsList rs) arr'
|
|
||||||
pure arr')
|
|
||||||
|
|
||||||
|
|
||||||
||| Update the sub-array at the given range of coordinates by applying
|
||| Update the sub-array at the given range of coordinates by applying
|
||||||
|
@ -266,13 +252,14 @@ indexUpdateRange : (rs : CoordsRange s) ->
|
||||||
(Array (newShape rs) a -> Array (newShape rs) a) ->
|
(Array (newShape rs) a -> Array (newShape rs) a) ->
|
||||||
Array s a -> Array s a
|
Array s a -> Array s a
|
||||||
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
|
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
|
||||||
-}
|
|
||||||
|
|
||||||
||| Index the array using the given coordinates, returning `Nothing` if the
|
||| Index the array using the given coordinates, returning `Nothing` if the
|
||||||
||| coordinates are out of bounds.
|
||| coordinates are out of bounds.
|
||||||
export
|
export
|
||||||
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
|
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
|
||||||
indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
|
indexNB is (MkArray rep @{rc} s arr) =
|
||||||
|
map (\is => index is (MkArray rep @{rc} s arr)) (validateCoords s is)
|
||||||
|
|
||||||
||| Index the array using the given coordinates, returning `Nothing` if the
|
||| Index the array using the given coordinates, returning `Nothing` if the
|
||||||
||| coordinates are out of bounds.
|
||| coordinates are out of bounds.
|
||||||
|
@ -281,15 +268,13 @@ indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
|
||||||
export %inline
|
export %inline
|
||||||
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
|
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
|
||||||
arr !? is = indexNB is arr
|
arr !? is = indexNB is arr
|
||||||
{-
|
|
||||||
||| Update the entry at the given coordinates using the function. `Nothing` is
|
||| Update the entry at the given coordinates using the function. `Nothing` is
|
||||||
||| returned if the coordinates are out of bounds.
|
||| returned if the coordinates are out of bounds.
|
||||||
export
|
export
|
||||||
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a)
|
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a)
|
||||||
indexUpdateNB is f (MkArray ord sts s arr) =
|
indexUpdateNB is f (MkArray rep @{rc} s arr) =
|
||||||
if all id $ zipWith (<) is s
|
map (\is => indexUpdate is f (MkArray rep @{rc} s arr)) (validateCoords s is)
|
||||||
then Just $ MkArray ord sts s (updateAt (getLocation' sts is) f arr)
|
|
||||||
else Nothing
|
|
||||||
|
|
||||||
||| Set the entry at the given coordinates using the function. `Nothing` is
|
||| Set the entry at the given coordinates using the function. `Nothing` is
|
||||||
||| returned if the coordinates are out of bounds.
|
||| returned if the coordinates are out of bounds.
|
||||||
|
@ -302,20 +287,8 @@ indexSetNB is = indexUpdateNB is . const
|
||||||
||| If any of the given indices are out of bounds, then `Nothing` is returned.
|
||| If any of the given indices are out of bounds, then `Nothing` is returned.
|
||||||
export
|
export
|
||||||
indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a)
|
indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a)
|
||||||
indexRangeNB {s} rs arr with (viewShape arr)
|
indexRangeNB rs (MkArray rep @{rc} s arr) =
|
||||||
_ | Shape s =
|
map (\rs => believe_me $ Array.indexRange rs (MkArray rep @{rc} s arr)) (validateCRange s rs)
|
||||||
if validateShape s rs
|
|
||||||
then
|
|
||||||
let ord = getOrder arr
|
|
||||||
sts = calcStrides ord s'
|
|
||||||
in Just $ MkArray ord sts s'
|
|
||||||
(unsafeFromIns (product s') $
|
|
||||||
map (\(is,is') => (getLocation' sts is',
|
|
||||||
index (getLocation' (strides arr) is) (getPrim arr)))
|
|
||||||
$ getCoordsList s rs)
|
|
||||||
else Nothing
|
|
||||||
where s' : Vect ? Nat
|
|
||||||
s' = newShape s rs
|
|
||||||
|
|
||||||
||| Index the array using the given range of coordinates, returning a new array.
|
||| Index the array using the given range of coordinates, returning a new array.
|
||||||
||| If any of the given indices are out of bounds, then `Nothing` is returned.
|
||| If any of the given indices are out of bounds, then `Nothing` is returned.
|
||||||
|
@ -324,7 +297,7 @@ indexRangeNB {s} rs arr with (viewShape arr)
|
||||||
export %inline
|
export %inline
|
||||||
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
|
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
|
||||||
arr !?.. rs = indexRangeNB rs arr
|
arr !?.. rs = indexRangeNB rs arr
|
||||||
-}
|
|
||||||
|
|
||||||
||| Index the array using the given coordinates.
|
||| Index the array using the given coordinates.
|
||||||
||| WARNING: This function does not perform any bounds check on its inputs.
|
||| WARNING: This function does not perform any bounds check on its inputs.
|
||||||
|
@ -342,30 +315,21 @@ export %inline
|
||||||
(!#) : Array {rk} s a -> Vect rk Nat -> a
|
(!#) : Array {rk} s a -> Vect rk Nat -> a
|
||||||
arr !# is = indexUnsafe is arr
|
arr !# is = indexUnsafe is arr
|
||||||
|
|
||||||
{-
|
|
||||||
||| Index the array using the given range of coordinates, returning a new array.
|
||| Index the array using the given range of coordinates, returning a new array.
|
||||||
||| WARNING: This function does not perform any bounds check on its inputs.
|
||| WARNING: This function does not perform any bounds check on its inputs.
|
||||||
||| Misuse of this function can easily break memory safety.
|
||| Misuse of this function can easily break memory safety.
|
||||||
export
|
export %unsafe
|
||||||
indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
|
indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
|
||||||
indexRangeUnsafe {s} rs arr with (viewShape arr)
|
indexRangeUnsafe rs (MkArray rep @{rc} s arr) =
|
||||||
_ | Shape s =
|
believe_me $ Array.indexRange (assertCRange s rs) (MkArray rep @{rc} s arr)
|
||||||
let ord = getOrder arr
|
|
||||||
sts = calcStrides ord s'
|
|
||||||
in MkArray ord sts s'
|
|
||||||
(unsafeFromIns (product s') $
|
|
||||||
map (\(is,is') => (getLocation' sts is',
|
|
||||||
index (getLocation' (strides arr) is) (getPrim arr)))
|
|
||||||
$ getCoordsList s rs)
|
|
||||||
where s' : Vect ? Nat
|
|
||||||
s' = newShape s rs
|
|
||||||
|
|
||||||
||| Index the array using the given range of coordinates, returning a new array.
|
||| Index the array using the given range of coordinates, returning a new array.
|
||||||
||| WARNING: This function does not perform any bounds check on its inputs.
|
||| WARNING: This function does not perform any bounds check on its inputs.
|
||||||
||| Misuse of this function can easily break memory safety.
|
||| Misuse of this function can easily break memory safety.
|
||||||
|||
|
|||
|
||||||
||| This is the operator form of `indexRangeUnsafe`.
|
||| This is the operator form of `indexRangeUnsafe`.
|
||||||
export %inline
|
export %inline %unsafe
|
||||||
(!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
|
(!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
|
||||||
arr !#.. is = indexRangeUnsafe is arr
|
arr !#.. is = indexRangeUnsafe is arr
|
||||||
|
|
||||||
|
@ -375,35 +339,42 @@ arr !#.. is = indexRangeUnsafe is arr
|
||||||
-- Operations on arrays
|
-- Operations on arrays
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
||| Reshape the array into the given shape and reinterpret it according to
|
|
||||||
||| the given order.
|
|
||||||
|||
|
|
||||||
||| @ s' The shape to convert the array to
|
|
||||||
||| @ ord The order to reinterpret the array by
|
|
||||||
export
|
export
|
||||||
reshape' : (s' : Vect rk' Nat) -> (ord : Order) -> Array {rk} s a ->
|
mapArray' : (a -> a) -> Array s a -> Array s a
|
||||||
(0 ok : product s = product s') => Array s' a
|
mapArray' f (MkArray rep _ arr) = MkArray rep _ (mapPrim f arr)
|
||||||
reshape' s' ord' arr = MkArray ord' (calcStrides ord' s') s' (getPrim arr)
|
|
||||||
|
export
|
||||||
|
mapArray : (a -> b) -> (arr : Array s a) -> RepConstraint (getRep arr) b => Array s b
|
||||||
|
mapArray f (MkArray rep _ arr) @{rc} = MkArray rep @{rc} _ (mapPrim f arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
zipWithArray' : (a -> a -> a) -> Array s a -> Array s a -> Array s a
|
||||||
|
zipWithArray' {s} f a b with (viewShape a)
|
||||||
|
_ | Shape s = MkArray (mergeRep (getRep a) (getRep b))
|
||||||
|
@{mergeRepConstraint (getRepC a) (getRepC b)} s
|
||||||
|
$ PrimArray.fromFunctionNB @{mergeRepConstraint (getRepC a) (getRepC b)} _
|
||||||
|
(\is => f (a !# is) (b !# is))
|
||||||
|
|
||||||
|
export
|
||||||
|
zipWithArray : (a -> b -> c) -> (arr : Array s a) -> (arr' : Array s b) ->
|
||||||
|
RepConstraint (mergeRep (getRep arr) (getRep arr')) c => Array s c
|
||||||
|
zipWithArray {s} f a b @{rc} with (viewShape a)
|
||||||
|
_ | Shape s = MkArray (mergeRep (getRep a) (getRep b)) @{rc} s
|
||||||
|
$ PrimArray.fromFunctionNB _ (\is => f (a !# is) (b !# is))
|
||||||
|
|
||||||
|
|
||||||
||| Reshape the array into the given shape.
|
||| Reshape the array into the given shape.
|
||||||
|||
|
|||
|
||||||
||| @ s' The shape to convert the array to
|
||| @ s' The shape to convert the array to
|
||||||
export
|
export
|
||||||
reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
reshape : (s' : Vect rk' Nat) -> (arr : Array {rk} s a) -> LinearRep (getRep arr) =>
|
||||||
(0 ok : product s = product s') => Array s' a
|
(0 ok : product s = product s') => Array s' a
|
||||||
reshape s' arr = reshape' s' (getOrder arr) arr
|
reshape s' (MkArray rep _ arr) = MkArray rep s' (PrimArray.reshape s' arr)
|
||||||
|
|
||||||
|
|
||||||
||| Change the internal order of the array's elements.
|
||| Change the internal order of the array's elements.
|
||||||
export
|
export
|
||||||
reorder : Order -> Array s a -> Array s a
|
convertRep : (rep : Rep) -> RepConstraint rep a => Array s a -> Array s a
|
||||||
reorder {s} ord' arr with (viewShape arr)
|
convertRep rep (MkArray _ s arr) = MkArray rep s (PrimArray.convertRep arr)
|
||||||
_ | Shape s = let sts = calcStrides ord' s
|
|
||||||
in MkArray ord' sts _ (unsafeFromIns (product s) $
|
|
||||||
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
|
|
||||||
getAllCoords' s)
|
|
||||||
|
|
||||||
|
|
||||||
||| Resize the array to a new shape, preserving the coordinates of the original
|
||| Resize the array to a new shape, preserving the coordinates of the original
|
||||||
||| elements. New coordinates are filled with a default value.
|
||| elements. New coordinates are filled with a default value.
|
||||||
|
@ -412,7 +383,7 @@ reorder {s} ord' arr with (viewShape arr)
|
||||||
||| @ def The default value to fill the array with
|
||| @ def The default value to fill the array with
|
||||||
export
|
export
|
||||||
resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
|
resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
|
||||||
resize s' def arr = fromFunction' s' (getOrder arr) (fromMaybe def . (arr !?) . toNB)
|
resize s' def arr = fromFunction {rep=getRep arr} @{getRepC arr} s' (fromMaybe def . (arr !?) . toNB)
|
||||||
|
|
||||||
||| Resize the array to a new shape, preserving the coordinates of the original
|
||| Resize the array to a new shape, preserving the coordinates of the original
|
||||||
||| elements. This function requires a proof that the new shape is strictly
|
||| elements. This function requires a proof that the new shape is strictly
|
||||||
|
@ -427,20 +398,11 @@ resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) =>
|
||||||
resizeLTE s' arr = resize s' (believe_me ()) arr
|
resizeLTE s' arr = resize s' (believe_me ()) arr
|
||||||
|
|
||||||
|
|
||||||
||| List all of the values in an array in row-major order.
|
|
||||||
export
|
|
||||||
elements : Array {rk} s a -> Vect (product s) a
|
|
||||||
elements (MkArray _ sts sh p) =
|
|
||||||
let elems = map (flip index p . getLocation' sts) (getAllCoords' sh)
|
|
||||||
-- TODO: prove that the number of elements is `product s`
|
|
||||||
in assert_total $ case toVect (product sh) elems of Just v => v
|
|
||||||
|
|
||||||
|
|
||||||
||| List all of the values in an array along with their coordinates.
|
||| List all of the values in an array along with their coordinates.
|
||||||
export
|
export
|
||||||
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
|
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
|
||||||
enumerateNB (MkArray _ sts sh p) =
|
enumerateNB (MkArray _ s arr) =
|
||||||
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
|
map (\is => (is, PrimArray.indexUnsafe is arr)) (getAllCoords' s)
|
||||||
|
|
||||||
||| List all of the values in an array along with their coordinates.
|
||| List all of the values in an array along with their coordinates.
|
||||||
export
|
export
|
||||||
|
@ -448,6 +410,12 @@ enumerate : Array s a -> List (Coords s, a)
|
||||||
enumerate {s} arr with (viewShape arr)
|
enumerate {s} arr with (viewShape arr)
|
||||||
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
|
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
|
||||||
|
|
||||||
|
||| List all of the values in an array in row-major order.
|
||||||
|
export
|
||||||
|
elements : Array {rk} s a -> Vect (product s) a
|
||||||
|
elements (MkArray _ s arr) =
|
||||||
|
believe_me $ Vect.fromList $
|
||||||
|
map (flip PrimArray.indexUnsafe arr) (getAllCoords' s)
|
||||||
|
|
||||||
||| Join two arrays along a particular axis, e.g. combining two matrices
|
||| Join two arrays along a particular axis, e.g. combining two matrices
|
||||||
||| vertically or horizontally. All other axes of the arrays must have the
|
||| vertically or horizontally. All other axes of the arrays must have the
|
||||||
|
@ -457,16 +425,12 @@ enumerate {s} arr with (viewShape arr)
|
||||||
export
|
export
|
||||||
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||||
Array (updateAt axis (+d) s) a
|
Array (updateAt axis (+d) s) a
|
||||||
concat axis a b = let sA = shape a
|
concat {s,d} axis a b with (viewShape a, viewShape b)
|
||||||
sB = shape b
|
_ | (Shape s, Shape (replaceAt axis d s)) =
|
||||||
dA = index axis sA
|
believe_me $ Array.fromFunctionNB {rep=mergeRep (getRep a) (getRep b)}
|
||||||
dB = index axis sB
|
@{mergeRepConstraint (getRepC a) (getRepC b)} (updateAt axis (+ index axis (shape b)) s)
|
||||||
s = replaceAt axis (dA + dB) sA
|
(\is => let limit = index axis s
|
||||||
sts = calcStrides COrder s
|
in if index axis is < limit then a !# is else b !# updateAt axis (`minus` limit) is)
|
||||||
ins = map (mapFst $ getLocation' sts . toNB) (enumerate a)
|
|
||||||
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNB) (enumerate b)
|
|
||||||
-- TODO: prove that the type-level shape and `s` are equivalent
|
|
||||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
|
||||||
|
|
||||||
||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix.
|
||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix.
|
||||||
|||
|
|||
|
||||||
|
@ -500,12 +464,11 @@ splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} ->
|
||||||
Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a)
|
Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a)
|
||||||
splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is')))
|
splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is')))
|
||||||
|
|
||||||
|
|
||||||
||| Construct the transpose of an array by reversing the order of its axes.
|
||| Construct the transpose of an array by reversing the order of its axes.
|
||||||
export
|
export
|
||||||
transpose : Array s a -> Array (reverse s) a
|
transpose : Array s a -> Array (reverse s) a
|
||||||
transpose {s} arr with (viewShape arr)
|
transpose {s} arr with (viewShape arr)
|
||||||
_ | Shape s = fromFunctionNB (reverse s) (\is => arr !# reverse is)
|
_ | Shape s = fromFunctionNB _ (\is => arr !# reverse is)
|
||||||
|
|
||||||
||| Construct the transpose of an array by reversing the order of its axes.
|
||| Construct the transpose of an array by reversing the order of its axes.
|
||||||
|||
|
|||
|
||||||
|
@ -551,41 +514,56 @@ permuteInAxis {s} axis p arr with (viewShape arr)
|
||||||
-- Implementations
|
-- Implementations
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
-- Most of these implementations apply the operation pointwise. If there are
|
|
||||||
-- multiple arrays involved with different orders, then all of the arrays are
|
|
||||||
-- reordered to match one.
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
Zippable (Array s) where
|
Zippable (Array s) where
|
||||||
zipWith {s} f a b with (viewShape a)
|
zipWith {s} f a b with (viewShape a)
|
||||||
_ | Shape s = MkArray (getOrder a) (strides a) s $
|
_ | Shape s = MkArray (mergeRepNC (getRep a) (getRep b))
|
||||||
if getOrder a == getOrder b
|
@{mergeNCRepConstraint} s
|
||||||
then unsafeZipWith f (getPrim a) (getPrim b)
|
$ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
|
||||||
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
(\is => f (a !# is) (b !# is))
|
||||||
|
|
||||||
zipWith3 {s} f a b c with (viewShape a)
|
zipWith3 {s} f a b c with (viewShape a)
|
||||||
_ | Shape s = MkArray (getOrder a) (strides a) s $
|
_ | Shape s = MkArray (mergeRepNC (mergeRep (getRep a) (getRep b)) (getRep c))
|
||||||
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
|
@{mergeNCRepConstraint} s
|
||||||
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
|
$ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
|
||||||
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
(\is => f (a !# is) (b !# is) (c !# is))
|
||||||
(getPrim $ reorder (getOrder a) c)
|
|
||||||
|
|
||||||
unzipWith {s} f arr with (viewShape arr)
|
unzipWith {s} f arr with (viewShape arr)
|
||||||
_ | Shape s = case unzipWith f (getPrim arr) of
|
_ | Shape s =
|
||||||
(a, b) => (MkArray (getOrder arr) (strides arr) s a,
|
let rep : Rep
|
||||||
MkArray (getOrder arr) (strides arr) s b)
|
rep = forceRepNC $ getRep arr
|
||||||
|
in (MkArray rep
|
||||||
|
@{forceRepConstraint} s
|
||||||
|
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
|
||||||
|
(\is => fst $ f (arr !# is)),
|
||||||
|
MkArray rep
|
||||||
|
@{forceRepConstraint} s
|
||||||
|
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
|
||||||
|
(\is => snd $ f (arr !# is)))
|
||||||
|
|
||||||
unzipWith3 {s} f arr with (viewShape arr)
|
unzipWith3 {s} f arr with (viewShape arr)
|
||||||
_ | Shape s = case unzipWith3 f (getPrim arr) of
|
_ | Shape s =
|
||||||
(a, b, c) => (MkArray (getOrder arr) (strides arr) s a,
|
let rep : Rep
|
||||||
MkArray (getOrder arr) (strides arr) s b,
|
rep = forceRepNC $ getRep arr
|
||||||
MkArray (getOrder arr) (strides arr) s c)
|
in (MkArray rep
|
||||||
|
@{forceRepConstraint} s
|
||||||
|
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
|
||||||
|
(\is => fst $ f (arr !# is)),
|
||||||
|
MkArray rep
|
||||||
|
@{forceRepConstraint} s
|
||||||
|
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
|
||||||
|
(\is => fst $ snd $ f (arr !# is)),
|
||||||
|
MkArray rep
|
||||||
|
@{forceRepConstraint} s
|
||||||
|
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
|
||||||
|
(\is => snd $ snd $ f (arr !# is)))
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
Functor (Array s) where
|
Functor (Array s) where
|
||||||
map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr)
|
map f (MkArray rep @{rc} s arr) = MkArray (forceRepNC rep) @{forceRepConstraint} s
|
||||||
|
(mapPrim @{forceRepConstraint} @{forceRepConstraint} f
|
||||||
|
$ convertRep @{rc} @{forceRepConstraint} arr)
|
||||||
|
|
||||||
export
|
export
|
||||||
{s : _} -> Applicative (Array s) where
|
{s : _} -> Applicative (Array s) where
|
||||||
|
@ -602,15 +580,18 @@ export
|
||||||
|
|
||||||
export
|
export
|
||||||
Foldable (Array s) where
|
Foldable (Array s) where
|
||||||
foldl f z = foldl f z . getPrim
|
foldl f z (MkArray _ _ arr) = PrimArray.foldl f z arr
|
||||||
foldr f z = foldr f z . getPrim
|
foldr f z (MkArray _ _ arr) = PrimArray.foldr f z arr
|
||||||
null arr = size arr == Z
|
null (MkArray _ s _) = isZero (product s)
|
||||||
toList = toList . getPrim
|
|
||||||
|
|
||||||
export
|
export
|
||||||
Traversable (Array s) where
|
Traversable (Array s) where
|
||||||
traverse f (MkArray ord sts s arr) =
|
traverse f (MkArray rep @{rc} s arr) =
|
||||||
map (MkArray ord sts s) (traverse f arr)
|
map (MkArray (forceRepNC rep) @{forceRepConstraint} s)
|
||||||
|
(PrimArray.traverse {rep=forceRepNC rep}
|
||||||
|
@{%search} @{forceRepConstraint} @{forceRepConstraint} f
|
||||||
|
(PrimArray.convertRep @{rc} @{forceRepConstraint} arr))
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
|
@ -619,13 +600,11 @@ Cast a b => Cast (Array s a) (Array s b) where
|
||||||
|
|
||||||
export
|
export
|
||||||
Eq a => Eq (Array s a) where
|
Eq a => Eq (Array s a) where
|
||||||
a == b = if getOrder a == getOrder b
|
a == b = and $ zipWith (delay .: (==)) (convertRep D a) (convertRep D b)
|
||||||
then unsafeEq (getPrim a) (getPrim b)
|
|
||||||
else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b)
|
|
||||||
|
|
||||||
export
|
export
|
||||||
Semigroup a => Semigroup (Array s a) where
|
Semigroup a => Semigroup (Array s a) where
|
||||||
(<+>) = zipWith (<+>)
|
(<+>) = zipWithArray' (<+>)
|
||||||
|
|
||||||
export
|
export
|
||||||
{s : _} -> Monoid a => Monoid (Array s a) where
|
{s : _} -> Monoid a => Monoid (Array s a) where
|
||||||
|
@ -635,36 +614,34 @@ export
|
||||||
-- were moved into its own interface, this constraint could be removed.
|
-- were moved into its own interface, this constraint could be removed.
|
||||||
export
|
export
|
||||||
{s : _} -> Num a => Num (Array s a) where
|
{s : _} -> Num a => Num (Array s a) where
|
||||||
(+) = zipWith (+)
|
(+) = zipWithArray' (+)
|
||||||
(*) = zipWith (*)
|
(*) = zipWithArray' (*)
|
||||||
|
|
||||||
fromInteger = repeat s . fromInteger
|
fromInteger = repeat s . fromInteger
|
||||||
|
|
||||||
export
|
export
|
||||||
{s : _} -> Neg a => Neg (Array s a) where
|
{s : _} -> Neg a => Neg (Array s a) where
|
||||||
negate = map negate
|
negate = mapArray' negate
|
||||||
(-) = zipWith (-)
|
(-) = zipWithArray' (-)
|
||||||
|
|
||||||
export
|
export
|
||||||
{s : _} -> Fractional a => Fractional (Array s a) where
|
{s : _} -> Fractional a => Fractional (Array s a) where
|
||||||
recip = map recip
|
recip = mapArray' recip
|
||||||
(/) = zipWith (/)
|
(/) = zipWithArray' (/)
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
Num a => Mult a (Array {rk} s a) (Array s a) where
|
Num a => Mult a (Array {rk} s a) (Array s a) where
|
||||||
(*.) x = map (*x)
|
(*.) x = mapArray' (*x)
|
||||||
|
|
||||||
export
|
export
|
||||||
Num a => Mult (Array {rk} s a) a (Array s a) where
|
Num a => Mult (Array {rk} s a) a (Array s a) where
|
||||||
(*.) = flip (*.)
|
(*.) = flip (*.)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
Show a => Show (Array s a) where
|
Show a => Show (Array s a) where
|
||||||
showPrec d arr = let orderedElems = PrimArray.toList $ getPrim $
|
showPrec d arr = let orderedElems = toList $ elements arr
|
||||||
if getOrder arr == COrder then arr else reorder COrder arr
|
|
||||||
in showCon d "array " $ concat $ insertPunct (shape arr) $ map show orderedElems
|
in showCon d "array " $ concat $ insertPunct (shape arr) $ map show orderedElems
|
||||||
where
|
where
|
||||||
splitWindow : Nat -> List String -> List (List String)
|
splitWindow : Nat -> List String -> List (List String)
|
||||||
|
@ -691,7 +668,7 @@ Show a => Show (Array s a) where
|
||||||
||| Linearly interpolate between two arrays.
|
||| Linearly interpolate between two arrays.
|
||||||
export
|
export
|
||||||
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
|
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
|
||||||
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
|
lerp t a b = zipWithArray' (+) (a *. (1 - t)) (b *. t)
|
||||||
|
|
||||||
|
|
||||||
||| Calculate the square of an array's Eulidean norm.
|
||| Calculate the square of an array's Eulidean norm.
|
||||||
|
@ -715,4 +692,3 @@ normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
|
||||||
export
|
export
|
||||||
pnorm : (p : Double) -> Array s Double -> Double
|
pnorm : (p : Double) -> Array s Double -> Double
|
||||||
pnorm p = (`pow` recip p) . sum . map (`pow` p)
|
pnorm p = (`pow` recip p) . sum . map (`pow` p)
|
||||||
-}
|
|
||||||
|
|
|
@ -43,6 +43,11 @@ toNB : Coords {rk} s -> Vect rk Nat
|
||||||
toNB [] = []
|
toNB [] = []
|
||||||
toNB (i :: is) = finToNat i :: toNB is
|
toNB (i :: is) = finToNat i :: toNB is
|
||||||
|
|
||||||
|
export
|
||||||
|
validateCoords : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s)
|
||||||
|
validateCoords [] [] = Just []
|
||||||
|
validateCoords (d :: s) (i :: is) = (::) <$> natToFin i d <*> validateCoords s is
|
||||||
|
|
||||||
|
|
||||||
namespace Strict
|
namespace Strict
|
||||||
public export
|
public export
|
||||||
|
@ -155,6 +160,44 @@ namespace Strict
|
||||||
|
|
||||||
namespace NB
|
namespace NB
|
||||||
export
|
export
|
||||||
|
validateCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> Maybe (CoordsRange s)
|
||||||
|
validateCRange [] [] = Just []
|
||||||
|
validateCRange (d :: s) (r :: rs) = [| validate' d r :: validateCRange s rs |]
|
||||||
|
where
|
||||||
|
validate' : (n : Nat) -> CRangeNB -> Maybe (CRange n)
|
||||||
|
validate' n All = Just All
|
||||||
|
validate' n (StartBound x) =
|
||||||
|
case isLTE x n of
|
||||||
|
Yes _ => Just (StartBound (natToFinLT x))
|
||||||
|
_ => Nothing
|
||||||
|
validate' n (EndBound x) =
|
||||||
|
case isLTE x n of
|
||||||
|
Yes _ => Just (EndBound (natToFinLT x))
|
||||||
|
_ => Nothing
|
||||||
|
validate' n (Bounds x y) =
|
||||||
|
case (isLTE x n, isLTE y n) of
|
||||||
|
(Yes _, Yes _) => Just (Bounds (natToFinLT x) (natToFinLT y))
|
||||||
|
_ => Nothing
|
||||||
|
validate' n (Indices xs) = Indices <$> traverse
|
||||||
|
(\x => case isLT x n of
|
||||||
|
Yes _ => Just (natToFinLT x)
|
||||||
|
No _ => Nothing) xs
|
||||||
|
validate' n (Filter f) = Just (Filter (f . finToNat))
|
||||||
|
|
||||||
|
export %unsafe
|
||||||
|
assertCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> CoordsRange s
|
||||||
|
assertCRange [] [] = []
|
||||||
|
assertCRange (d :: s) (r :: rs) = assert' d r :: assertCRange s rs
|
||||||
|
where
|
||||||
|
assert' : (n : Nat) -> CRangeNB -> CRange n
|
||||||
|
assert' n All = All
|
||||||
|
assert' n (StartBound x) = StartBound (believe_me x)
|
||||||
|
assert' n (EndBound x) = EndBound (believe_me x)
|
||||||
|
assert' n (Bounds x y) = Bounds (believe_me x) (believe_me y)
|
||||||
|
assert' n (Indices xs) = Indices (believe_me <$> xs)
|
||||||
|
assert' n (Filter f) = Filter (f . finToNat)
|
||||||
|
|
||||||
|
public export
|
||||||
cRangeNBToList : Nat -> CRangeNB -> List Nat
|
cRangeNBToList : Nat -> CRangeNB -> List Nat
|
||||||
cRangeNBToList s All = range 0 s
|
cRangeNBToList s All = range 0 s
|
||||||
cRangeNBToList s (StartBound x) = range x s
|
cRangeNBToList s (StartBound x) = range x s
|
||||||
|
@ -163,38 +206,11 @@ namespace NB
|
||||||
cRangeNBToList s (Indices xs) = nub xs
|
cRangeNBToList s (Indices xs) = nub xs
|
||||||
cRangeNBToList s (Filter p) = filter p $ range 0 s
|
cRangeNBToList s (Filter p) = filter p $ range 0 s
|
||||||
|
|
||||||
export
|
|
||||||
validateCRangeNB : Nat -> CRangeNB -> Bool
|
|
||||||
validateCRangeNB s All = True
|
|
||||||
validateCRangeNB s (StartBound x) = x < s
|
|
||||||
validateCRangeNB s (EndBound x) = x <= s
|
|
||||||
validateCRangeNB s (Bounds x y) = x < s && y <= s
|
|
||||||
validateCRangeNB s (Indices xs) = all (<s) xs
|
|
||||||
validateCRangeNB s (Filter p) = True
|
|
||||||
|
|
||||||
||| Calculate the new shape given by a coordinate range.
|
||| Calculate the new shape given by a coordinate range.
|
||||||
export
|
public export
|
||||||
newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat
|
newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat
|
||||||
newShape = zipWith (length .: cRangeNBToList)
|
newShape = zipWith (length .: cRangeNBToList)
|
||||||
|
|
||||||
export
|
|
||||||
validateShape : Vect rk Nat -> Vect rk CRangeNB -> Bool
|
|
||||||
validateShape = all id .: zipWith validateCRangeNB
|
|
||||||
|
|
||||||
export
|
|
||||||
getNewPos : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat -> Vect rk Nat
|
|
||||||
getNewPos = zipWith3 (\d,r,i => assert_total $
|
|
||||||
case findIndex (==i) (cRangeNBToList d r) of Just x => cast x)
|
|
||||||
|
|
||||||
export
|
|
||||||
getCoordsList : Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat, Vect rk Nat)
|
|
||||||
getCoordsList s rs = map (\is => (is, getNewPos s rs is)) $ go s rs
|
|
||||||
where
|
|
||||||
go : {0 rk : _} -> Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat)
|
|
||||||
go [] [] = [[]]
|
|
||||||
go (d :: s) (r :: rs) = [| cRangeNBToList d r :: go s rs |]
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||||
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
module Data.NumIdr.Array.Rep
|
module Data.NumIdr.Array.Rep
|
||||||
|
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
|
import Data.Bits
|
||||||
|
import Data.Buffer
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
@ -23,6 +25,8 @@ data Order : Type where
|
||||||
||| fashion (the first axis is the least significant).
|
||| fashion (the first axis is the least significant).
|
||||||
FOrder : Order
|
FOrder : Order
|
||||||
|
|
||||||
|
%name Order o
|
||||||
|
|
||||||
|
|
||||||
public export
|
public export
|
||||||
Eq Order where
|
Eq Order where
|
||||||
|
@ -52,6 +56,9 @@ data Rep : Type where
|
||||||
Linked : Rep
|
Linked : Rep
|
||||||
Delayed : Rep
|
Delayed : Rep
|
||||||
|
|
||||||
|
%name Rep rep
|
||||||
|
|
||||||
|
|
||||||
public export
|
public export
|
||||||
B : Rep
|
B : Rep
|
||||||
B = Boxed COrder
|
B = Boxed COrder
|
||||||
|
@ -65,7 +72,231 @@ D : Rep
|
||||||
D = Delayed
|
D = Delayed
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
Eq Rep where
|
||||||
|
Bytes o == Bytes o' = o == o'
|
||||||
|
Boxed o == Boxed o' = o == o'
|
||||||
|
Linked == Linked = True
|
||||||
|
Delayed == Delayed = True
|
||||||
|
_ == _ = False
|
||||||
|
|
||||||
public export
|
public export
|
||||||
data LinearRep : Rep -> Type where
|
data LinearRep : Rep -> Type where
|
||||||
BytesIsL : LinearRep (Bytes o)
|
BytesIsL : LinearRep (Bytes o)
|
||||||
BoxedIsL : LinearRep (Boxed o)
|
BoxedIsL : LinearRep (Boxed o)
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
forceRepNC : Rep -> Rep
|
||||||
|
forceRepNC (Bytes o) = Boxed o
|
||||||
|
forceRepNC r = r
|
||||||
|
|
||||||
|
public export
|
||||||
|
mergeRep : Rep -> Rep -> Rep
|
||||||
|
mergeRep r r' = if r == Delayed || r' == Delayed then Delayed else r
|
||||||
|
|
||||||
|
public export
|
||||||
|
mergeRepNC : Rep -> Rep -> Rep
|
||||||
|
mergeRepNC r r' =
|
||||||
|
if r == Delayed || r' == Delayed then Delayed
|
||||||
|
else case r of
|
||||||
|
Bytes _ => case r' of
|
||||||
|
Bytes o => Boxed o
|
||||||
|
_ => r'
|
||||||
|
_ => r
|
||||||
|
|
||||||
|
public export
|
||||||
|
data NoConstraintRep : Rep -> Type where
|
||||||
|
BoxedNC : NoConstraintRep (Boxed o)
|
||||||
|
LinkedNC : NoConstraintRep Linked
|
||||||
|
DelayedNC : NoConstraintRep Delayed
|
||||||
|
|
||||||
|
public export
|
||||||
|
mergeRepNCCorrect : (r, r' : Rep) -> NoConstraintRep (mergeRepNC r r')
|
||||||
|
mergeRepNCCorrect Delayed _ = DelayedNC
|
||||||
|
mergeRepNCCorrect (Bytes y) (Bytes x) = BoxedNC
|
||||||
|
mergeRepNCCorrect (Boxed y) (Bytes x) = BoxedNC
|
||||||
|
mergeRepNCCorrect Linked (Bytes x) = LinkedNC
|
||||||
|
mergeRepNCCorrect (Bytes y) (Boxed x) = BoxedNC
|
||||||
|
mergeRepNCCorrect (Boxed y) (Boxed x) = BoxedNC
|
||||||
|
mergeRepNCCorrect Linked (Boxed x) = LinkedNC
|
||||||
|
mergeRepNCCorrect (Bytes x) Linked = LinkedNC
|
||||||
|
mergeRepNCCorrect (Boxed x) Linked = BoxedNC
|
||||||
|
mergeRepNCCorrect Linked Linked = LinkedNC
|
||||||
|
mergeRepNCCorrect (Bytes x) Delayed = DelayedNC
|
||||||
|
mergeRepNCCorrect (Boxed x) Delayed = DelayedNC
|
||||||
|
mergeRepNCCorrect Linked Delayed = DelayedNC
|
||||||
|
|
||||||
|
public export
|
||||||
|
forceRepNCCorrect : (r : Rep) -> NoConstraintRep (forceRepNC r)
|
||||||
|
forceRepNCCorrect (Bytes x) = BoxedNC
|
||||||
|
forceRepNCCorrect (Boxed x) = BoxedNC
|
||||||
|
forceRepNCCorrect Linked = LinkedNC
|
||||||
|
forceRepNCCorrect Delayed = DelayedNC
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Byte representations of elements
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
interface ByteRep a where
|
||||||
|
bytes : Nat
|
||||||
|
|
||||||
|
toBytes : a -> Vect bytes Bits8
|
||||||
|
fromBytes : Vect bytes Bits8 -> a
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Bits8 where
|
||||||
|
bytes = 1
|
||||||
|
|
||||||
|
toBytes x = [x]
|
||||||
|
fromBytes [x] = x
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Bits16 where
|
||||||
|
bytes = 2
|
||||||
|
|
||||||
|
toBytes x = [cast (x `shiftR` 8), cast x]
|
||||||
|
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Bits32 where
|
||||||
|
bytes = 4
|
||||||
|
|
||||||
|
toBytes x = [cast (x `shiftR` 24),
|
||||||
|
cast (x `shiftR` 16),
|
||||||
|
cast (x `shiftR` 8),
|
||||||
|
cast x]
|
||||||
|
fromBytes [b1,b2,b3,b4] =
|
||||||
|
cast b1 `shiftL` 24 .|.
|
||||||
|
cast b2 `shiftL` 16 .|.
|
||||||
|
cast b3 `shiftL` 8 .|.
|
||||||
|
cast b4
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Bits64 where
|
||||||
|
bytes = 8
|
||||||
|
|
||||||
|
toBytes x = [cast (x `shiftR` 56),
|
||||||
|
cast (x `shiftR` 48),
|
||||||
|
cast (x `shiftR` 40),
|
||||||
|
cast (x `shiftR` 32),
|
||||||
|
cast (x `shiftR` 24),
|
||||||
|
cast (x `shiftR` 16),
|
||||||
|
cast (x `shiftR` 8),
|
||||||
|
cast x]
|
||||||
|
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
|
||||||
|
cast b1 `shiftL` 56 .|.
|
||||||
|
cast b2 `shiftL` 48 .|.
|
||||||
|
cast b3 `shiftL` 40 .|.
|
||||||
|
cast b4 `shiftL` 32 .|.
|
||||||
|
cast b5 `shiftL` 24 .|.
|
||||||
|
cast b6 `shiftL` 16 .|.
|
||||||
|
cast b7 `shiftL` 8 .|.
|
||||||
|
cast b8
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Int where
|
||||||
|
bytes = 8
|
||||||
|
|
||||||
|
toBytes x = [cast (x `shiftR` 56),
|
||||||
|
cast (x `shiftR` 48),
|
||||||
|
cast (x `shiftR` 40),
|
||||||
|
cast (x `shiftR` 32),
|
||||||
|
cast (x `shiftR` 24),
|
||||||
|
cast (x `shiftR` 16),
|
||||||
|
cast (x `shiftR` 8),
|
||||||
|
cast x]
|
||||||
|
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
|
||||||
|
cast b1 `shiftL` 56 .|.
|
||||||
|
cast b2 `shiftL` 48 .|.
|
||||||
|
cast b3 `shiftL` 40 .|.
|
||||||
|
cast b4 `shiftL` 32 .|.
|
||||||
|
cast b5 `shiftL` 24 .|.
|
||||||
|
cast b6 `shiftL` 16 .|.
|
||||||
|
cast b7 `shiftL` 8 .|.
|
||||||
|
cast b8
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Int8 where
|
||||||
|
bytes = 1
|
||||||
|
|
||||||
|
toBytes x = [cast x]
|
||||||
|
fromBytes [x] = cast x
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Int16 where
|
||||||
|
bytes = 2
|
||||||
|
|
||||||
|
toBytes x = [cast (x `shiftR` 8), cast x]
|
||||||
|
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Int32 where
|
||||||
|
bytes = 4
|
||||||
|
|
||||||
|
toBytes x = [cast (x `shiftR` 24),
|
||||||
|
cast (x `shiftR` 16),
|
||||||
|
cast (x `shiftR` 8),
|
||||||
|
cast x]
|
||||||
|
fromBytes [b1,b2,b3,b4] =
|
||||||
|
cast b1 `shiftL` 24 .|.
|
||||||
|
cast b2 `shiftL` 16 .|.
|
||||||
|
cast b3 `shiftL` 8 .|.
|
||||||
|
cast b4
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Int64 where
|
||||||
|
bytes = 8
|
||||||
|
|
||||||
|
toBytes x = [cast (x `shiftR` 56),
|
||||||
|
cast (x `shiftR` 48),
|
||||||
|
cast (x `shiftR` 40),
|
||||||
|
cast (x `shiftR` 32),
|
||||||
|
cast (x `shiftR` 24),
|
||||||
|
cast (x `shiftR` 16),
|
||||||
|
cast (x `shiftR` 8),
|
||||||
|
cast x]
|
||||||
|
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
|
||||||
|
cast b1 `shiftL` 56 .|.
|
||||||
|
cast b2 `shiftL` 48 .|.
|
||||||
|
cast b3 `shiftL` 40 .|.
|
||||||
|
cast b4 `shiftL` 32 .|.
|
||||||
|
cast b5 `shiftL` 24 .|.
|
||||||
|
cast b6 `shiftL` 16 .|.
|
||||||
|
cast b7 `shiftL` 8 .|.
|
||||||
|
cast b8
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep Bool where
|
||||||
|
bytes = 1
|
||||||
|
|
||||||
|
toBytes b = [if b then 1 else 0]
|
||||||
|
fromBytes [x] = x /= 0
|
||||||
|
|
||||||
|
export
|
||||||
|
ByteRep a => ByteRep b => ByteRep (a, b) where
|
||||||
|
bytes = bytes {a} + bytes {a=b}
|
||||||
|
|
||||||
|
toBytes (x,y) = toBytes x ++ toBytes y
|
||||||
|
fromBytes = bimap fromBytes fromBytes . splitAt _
|
||||||
|
|
||||||
|
export
|
||||||
|
{n : _} -> ByteRep a => ByteRep (Vect n a) where
|
||||||
|
bytes = n * bytes {a}
|
||||||
|
|
||||||
|
toBytes xs = concat $ map toBytes xs
|
||||||
|
fromBytes {n = 0} bs = []
|
||||||
|
fromBytes {n = S n} bs =
|
||||||
|
let (bs1, bs2) = splitAt _ bs
|
||||||
|
in fromBytes bs1 :: fromBytes bs2
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
RepConstraint : Rep -> Type -> Type
|
||||||
|
RepConstraint (Bytes _) a = ByteRep a
|
||||||
|
RepConstraint (Boxed _) a = ()
|
||||||
|
RepConstraint Linked a = ()
|
||||||
|
RepConstraint Delayed a = ()
|
||||||
|
|
|
@ -2,23 +2,39 @@ module Data.NumIdr.PrimArray
|
||||||
|
|
||||||
import Data.Buffer
|
import Data.Buffer
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
|
import Data.Fin
|
||||||
import Data.NP
|
import Data.NP
|
||||||
import Data.NumIdr.Array.Rep
|
import Data.NumIdr.Array.Rep
|
||||||
import Data.NumIdr.Array.Coords
|
import Data.NumIdr.Array.Coords
|
||||||
import Data.NumIdr.PrimArray.Bytes
|
import public Data.NumIdr.PrimArray.Bytes
|
||||||
import Data.NumIdr.PrimArray.Boxed
|
import public Data.NumIdr.PrimArray.Boxed
|
||||||
import Data.NumIdr.PrimArray.Linked
|
import public Data.NumIdr.PrimArray.Linked
|
||||||
import Data.NumIdr.PrimArray.Delayed
|
import public Data.NumIdr.PrimArray.Delayed
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
public export
|
|
||||||
RepConstraint : Rep -> Type -> Type
|
noRCCorrect : (rep : Rep) -> NoConstraintRep rep -> RepConstraint rep a
|
||||||
RepConstraint (Bytes _) a = ByteRep a
|
noRCCorrect (Boxed _) BoxedNC = ()
|
||||||
RepConstraint (Boxed _) a = ()
|
noRCCorrect Linked LinkedNC = ()
|
||||||
RepConstraint Linked a = ()
|
noRCCorrect Delayed DelayedNC = ()
|
||||||
RepConstraint Delayed a = ()
|
|
||||||
|
export
|
||||||
|
mergeRepConstraint : {r, r' : Rep} -> RepConstraint r a -> RepConstraint r' a ->
|
||||||
|
RepConstraint (mergeRep r r') a
|
||||||
|
mergeRepConstraint {r,r'} a b with (r == Delayed || r' == Delayed)
|
||||||
|
_ | True = ()
|
||||||
|
_ | False = a
|
||||||
|
|
||||||
|
export
|
||||||
|
mergeNCRepConstraint : {r, r' : Rep} -> RepConstraint (mergeRepNC r r') a
|
||||||
|
mergeNCRepConstraint = noRCCorrect _ (mergeRepNCCorrect r r')
|
||||||
|
|
||||||
|
export
|
||||||
|
forceRepConstraint : {r : Rep} -> RepConstraint (forceRepNC r) a
|
||||||
|
forceRepConstraint = noRCCorrect _ (forceRepNCCorrect r)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
|
@ -29,6 +45,11 @@ PrimArray Linked s a = Vects s a
|
||||||
PrimArray Delayed s a = Coords s -> a
|
PrimArray Delayed s a = Coords s -> a
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
reshape : {rep : Rep} -> LinearRep rep => RepConstraint rep a => (s' : Vect rk Nat) -> PrimArray rep s a -> PrimArray rep s' a
|
||||||
|
reshape {rep = Bytes o} s' (MkPABytes _ arr) = MkPABytes (calcStrides o s') arr
|
||||||
|
reshape {rep = Boxed o} s' (MkPABoxed _ arr) = MkPABoxed (calcStrides o s') arr
|
||||||
|
|
||||||
export
|
export
|
||||||
constant : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> a -> PrimArray rep s a
|
constant : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> a -> PrimArray rep s a
|
||||||
constant {rep = Bytes o} = Bytes.constant
|
constant {rep = Bytes o} = Bytes.constant
|
||||||
|
@ -66,25 +87,68 @@ index {rep = Linked} is arr = Linked.index is arr
|
||||||
index {rep = Delayed} is arr = arr is
|
index {rep = Delayed} is arr = arr is
|
||||||
|
|
||||||
export
|
export
|
||||||
indexNB : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> Maybe a
|
indexUpdate : {rep,s : _} -> RepConstraint rep a => Coords s -> (a -> a) -> PrimArray rep s a -> PrimArray rep s a
|
||||||
indexNB {rep = Bytes o} is arr@(MkPABytes sts _) =
|
indexUpdate {rep = Bytes o} @{rc} is f arr@(MkPABytes sts _) =
|
||||||
if and (zipWith (delay .: (<)) is s)
|
let i = getLocation sts is
|
||||||
then Just $ index (getLocation' sts is) arr
|
in create @{rc} s (\n => let x = index n arr in if n == i then f x else x)
|
||||||
else Nothing
|
indexUpdate {rep = Boxed o} is f arr@(MkPABoxed sts _) =
|
||||||
indexNB {rep = Boxed o} is arr@(MkPABoxed sts _) =
|
let i = getLocation sts is
|
||||||
if and (zipWith (delay .: (<)) is s)
|
in create s (\n => let x = index n arr in if n == i then f x else x)
|
||||||
then Just $ index (getLocation' sts is) arr
|
indexUpdate {rep = Linked} is f arr = update is f arr
|
||||||
else Nothing
|
indexUpdate {rep = Delayed} is f arr =
|
||||||
indexNB {rep = Linked} is arr = (`Linked.index` arr) <$> checkRange s is
|
\is' =>
|
||||||
indexNB {rep = Delayed} is arr = arr <$> checkRange s is
|
let x = arr is'
|
||||||
|
in if eqNP is is' then f x else x
|
||||||
|
|
||||||
|
export
|
||||||
|
indexRange : {rep,s : _} -> RepConstraint rep a => (rs : CoordsRange s) -> PrimArray rep s a
|
||||||
|
-> PrimArray rep (newShape rs) a
|
||||||
|
indexRange {rep = Bytes o} @{rc} rs arr@(MkPABytes sts _) =
|
||||||
|
let s' : Vect ? Nat
|
||||||
|
s' = newShape rs
|
||||||
|
sts' := calcStrides o s'
|
||||||
|
in unsafeFromIns @{rc} (newShape rs) $
|
||||||
|
map (\(is,is') => (getLocation' sts' is',
|
||||||
|
index (getLocation' sts is) arr))
|
||||||
|
$ getCoordsList rs
|
||||||
|
indexRange {rep = Boxed o} rs arr@(MkPABoxed sts _) =
|
||||||
|
let s' : Vect ? Nat
|
||||||
|
s' = newShape rs
|
||||||
|
sts' := calcStrides o s'
|
||||||
|
in unsafeFromIns (newShape rs) $
|
||||||
|
map (\(is,is') => (getLocation' sts' is',
|
||||||
|
index (getLocation' sts is) arr))
|
||||||
|
$ getCoordsList rs
|
||||||
|
indexRange {rep = Linked} rs arr = Linked.indexRange rs arr
|
||||||
|
indexRange {rep = Delayed} rs arr = Delayed.indexRange rs arr
|
||||||
|
|
||||||
|
export
|
||||||
|
indexSetRange : {rep,s : _} -> RepConstraint rep a => (rs : CoordsRange s)
|
||||||
|
-> PrimArray rep (newShape rs) a -> PrimArray rep s a -> PrimArray rep s a
|
||||||
|
indexSetRange {rep = Bytes o} @{rc} rs rpl@(MkPABytes sts' _) arr@(MkPABytes sts _) =
|
||||||
|
unsafePerformIO $ do
|
||||||
|
let arr' = copy arr
|
||||||
|
unsafeDoIns @{rc} (map (\(is,is') =>
|
||||||
|
(getLocation' sts is, index (getLocation' sts' is') rpl))
|
||||||
|
$ getCoordsList rs) arr'
|
||||||
|
pure arr'
|
||||||
|
indexSetRange {rep = Boxed o} rs rpl@(MkPABoxed sts' _) arr@(MkPABoxed sts _) =
|
||||||
|
unsafePerformIO $ do
|
||||||
|
let arr' = copy arr
|
||||||
|
unsafeDoIns (map (\(is,is') =>
|
||||||
|
(getLocation' sts is, index (getLocation' sts' is') rpl))
|
||||||
|
$ getCoordsList rs) arr'
|
||||||
|
pure arr'
|
||||||
|
indexSetRange {rep = Linked} rs rpl arr = Linked.indexSetRange rs rpl arr
|
||||||
|
indexSetRange {rep = Delayed} rs rpl arr = Delayed.indexSetRange rs rpl arr
|
||||||
|
|
||||||
export
|
export
|
||||||
indexUnsafe : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> a
|
indexUnsafe : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> a
|
||||||
indexUnsafe {rep = Bytes o} is arr@(MkPABytes sts _) = index (getLocation' sts is) arr
|
indexUnsafe {rep = Bytes o} is arr@(MkPABytes sts _) = index (getLocation' sts is) arr
|
||||||
indexUnsafe {rep = Boxed o} is arr@(MkPABoxed sts _) = index (getLocation' sts is) arr
|
indexUnsafe {rep = Boxed o} is arr@(MkPABoxed sts _) = index (getLocation' sts is) arr
|
||||||
indexUnsafe {rep = Linked} is arr = assert_total $ case checkRange s is of
|
indexUnsafe {rep = Linked} is arr = assert_total $ case validateCoords s is of
|
||||||
Just is' => Linked.index is' arr
|
Just is' => Linked.index is' arr
|
||||||
indexUnsafe {rep = Delayed} is arr = assert_total $ case checkRange s is of
|
indexUnsafe {rep = Delayed} is arr = assert_total $ case validateCoords s is of
|
||||||
Just is' => arr is'
|
Just is' => arr is'
|
||||||
|
|
||||||
export
|
export
|
||||||
|
@ -100,3 +164,41 @@ convertRep {r1, r2} arr = fromFunction s (\is => PrimArray.index is arr)
|
||||||
export
|
export
|
||||||
fromVects : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> Vects s a -> PrimArray rep s a
|
fromVects : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> Vects s a -> PrimArray rep s a
|
||||||
fromVects s v = convertRep {r1=Linked} v
|
fromVects s v = convertRep {r1=Linked} v
|
||||||
|
|
||||||
|
export
|
||||||
|
fromList : {rep : Rep} -> LinearRep rep => RepConstraint rep a => (s : Vect rk Nat) -> List a -> PrimArray rep s a
|
||||||
|
fromList {rep = (Boxed x)} = Boxed.fromList
|
||||||
|
fromList {rep = (Bytes x)} = Bytes.fromList
|
||||||
|
|
||||||
|
export
|
||||||
|
mapPrim : {rep,s : _} -> RepConstraint rep a => RepConstraint rep b =>
|
||||||
|
(a -> b) -> PrimArray rep s a -> PrimArray rep s b
|
||||||
|
mapPrim {rep = (Bytes x)} @{_} @{rc} f arr = Bytes.create @{rc} _ (\i => f $ Bytes.index i arr)
|
||||||
|
mapPrim {rep = (Boxed x)} f arr = Boxed.create _ (\i => f $ Boxed.index i arr)
|
||||||
|
mapPrim {rep = Linked} f arr = Linked.mapVects f arr
|
||||||
|
mapPrim {rep = Delayed} f arr = f . arr
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
foldl : {rep,s : _} -> RepConstraint rep a => (b -> a -> b) -> b -> PrimArray rep s a -> b
|
||||||
|
foldl {rep = Bytes _} = Bytes.foldl
|
||||||
|
foldl {rep = Boxed _} = Boxed.foldl
|
||||||
|
foldl {rep = Linked} = Linked.foldl
|
||||||
|
foldl {rep = Delayed} = \f,z => Boxed.foldl f z . convertRep {r1=Delayed,r2=B}
|
||||||
|
|
||||||
|
export
|
||||||
|
foldr : {rep,s : _} -> RepConstraint rep a => (a -> b -> b) -> b -> PrimArray rep s a -> b
|
||||||
|
foldr {rep = Bytes _} = Bytes.foldr
|
||||||
|
foldr {rep = Boxed _} = Boxed.foldr
|
||||||
|
foldr {rep = Linked} = Linked.foldr
|
||||||
|
foldr {rep = Delayed} = \f,z => Boxed.foldr f z . convertRep {r1=Delayed,r2=B}
|
||||||
|
|
||||||
|
export
|
||||||
|
traverse : {rep,s : _} -> Applicative f => RepConstraint rep a => RepConstraint rep b =>
|
||||||
|
(a -> f b) -> PrimArray rep s a -> f (PrimArray rep s b)
|
||||||
|
traverse {rep = Bytes o} = Bytes.traverse
|
||||||
|
traverse {rep = Boxed o} = Boxed.traverse
|
||||||
|
traverse {rep = Linked} = Linked.traverse
|
||||||
|
traverse {rep = Delayed} = \f => map (convertRep @{()} @{()} {r1=B,r2=Delayed}) .
|
||||||
|
Boxed.traverse f .
|
||||||
|
convertRep @{()} @{()} {r1=Delayed,r2=B}
|
||||||
|
|
|
@ -2,6 +2,7 @@ module Data.NumIdr.PrimArray.Boxed
|
||||||
|
|
||||||
import Data.Buffer
|
import Data.Buffer
|
||||||
import System
|
import System
|
||||||
|
import Data.IORef
|
||||||
import Data.IOArray.Prims
|
import Data.IOArray.Prims
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.NumIdr.Array.Rep
|
import Data.NumIdr.Array.Rep
|
||||||
|
@ -108,6 +109,34 @@ fromList s xs = arrayAction s $ go xs 0
|
||||||
go (x :: xs) i buf = arrayDataSet i x buf >> go xs (S i) buf
|
go (x :: xs) i buf = arrayDataSet i x buf >> go xs (S i) buf
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
foldl : {s : _} -> (b -> a -> b) -> b -> PrimArrayBoxed o s a -> b
|
||||||
|
foldl f z (MkPABoxed sts arr) =
|
||||||
|
if product s == 0 then z
|
||||||
|
else unsafePerformIO $ do
|
||||||
|
ref <- newIORef z
|
||||||
|
for_ [0..pred $ product s] $ \n => do
|
||||||
|
x <- readIORef ref
|
||||||
|
y <- arrayDataGet n arr
|
||||||
|
writeIORef ref (f x y)
|
||||||
|
readIORef ref
|
||||||
|
|
||||||
|
export
|
||||||
|
foldr : {s : _} -> (a -> b -> b) -> b -> PrimArrayBoxed o s a -> b
|
||||||
|
foldr f z (MkPABoxed sts arr) =
|
||||||
|
if product s == 0 then z
|
||||||
|
else unsafePerformIO $ do
|
||||||
|
ref <- newIORef z
|
||||||
|
for_ [pred $ product s..0] $ \n => do
|
||||||
|
x <- arrayDataGet n arr
|
||||||
|
y <- readIORef ref
|
||||||
|
writeIORef ref (f x y)
|
||||||
|
readIORef ref
|
||||||
|
|
||||||
|
export
|
||||||
|
traverse : {o,s : _} -> Applicative f => (a -> f b) -> PrimArrayBoxed o s a -> f (PrimArrayBoxed o s b)
|
||||||
|
traverse f = map (Boxed.fromList _) . traverse f . foldr (::) []
|
||||||
|
|
||||||
{-
|
{-
|
||||||
export
|
export
|
||||||
updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a
|
updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a
|
||||||
|
@ -145,10 +174,6 @@ fromList xs = create (length xs)
|
||||||
fromJust : Maybe a -> a
|
fromJust : Maybe a -> a
|
||||||
fromJust (Just x) = x
|
fromJust (Just x) = x
|
||||||
|
|
||||||
||| Map a function over a primitive array.
|
|
||||||
export
|
|
||||||
map : (a -> b) -> PrimArrayBoxed a -> PrimArrayBoxed b
|
|
||||||
map f arr = create (length arr) (\n => f $ index n arr)
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
|
|
|
@ -2,8 +2,6 @@ module Data.NumIdr.PrimArray.Bytes
|
||||||
|
|
||||||
import System
|
import System
|
||||||
import Data.IORef
|
import Data.IORef
|
||||||
import Data.Bits
|
|
||||||
import Data.So
|
|
||||||
import Data.Buffer
|
import Data.Buffer
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.NumIdr.Array.Coords
|
import Data.NumIdr.Array.Coords
|
||||||
|
@ -12,160 +10,6 @@ import Data.NumIdr.Array.Rep
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
public export
|
|
||||||
interface ByteRep a where
|
|
||||||
bytes : Nat
|
|
||||||
|
|
||||||
toBytes : a -> Vect bytes Bits8
|
|
||||||
fromBytes : Vect bytes Bits8 -> a
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Bits8 where
|
|
||||||
bytes = 1
|
|
||||||
|
|
||||||
toBytes x = [x]
|
|
||||||
fromBytes [x] = x
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Bits16 where
|
|
||||||
bytes = 2
|
|
||||||
|
|
||||||
toBytes x = [cast (x `shiftR` 8), cast x]
|
|
||||||
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Bits32 where
|
|
||||||
bytes = 4
|
|
||||||
|
|
||||||
toBytes x = [cast (x `shiftR` 24),
|
|
||||||
cast (x `shiftR` 16),
|
|
||||||
cast (x `shiftR` 8),
|
|
||||||
cast x]
|
|
||||||
fromBytes [b1,b2,b3,b4] =
|
|
||||||
cast b1 `shiftL` 24 .|.
|
|
||||||
cast b2 `shiftL` 16 .|.
|
|
||||||
cast b3 `shiftL` 8 .|.
|
|
||||||
cast b4
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Bits64 where
|
|
||||||
bytes = 8
|
|
||||||
|
|
||||||
toBytes x = [cast (x `shiftR` 56),
|
|
||||||
cast (x `shiftR` 48),
|
|
||||||
cast (x `shiftR` 40),
|
|
||||||
cast (x `shiftR` 32),
|
|
||||||
cast (x `shiftR` 24),
|
|
||||||
cast (x `shiftR` 16),
|
|
||||||
cast (x `shiftR` 8),
|
|
||||||
cast x]
|
|
||||||
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
|
|
||||||
cast b1 `shiftL` 56 .|.
|
|
||||||
cast b2 `shiftL` 48 .|.
|
|
||||||
cast b3 `shiftL` 40 .|.
|
|
||||||
cast b4 `shiftL` 32 .|.
|
|
||||||
cast b5 `shiftL` 24 .|.
|
|
||||||
cast b6 `shiftL` 16 .|.
|
|
||||||
cast b7 `shiftL` 8 .|.
|
|
||||||
cast b8
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Int where
|
|
||||||
bytes = 8
|
|
||||||
|
|
||||||
toBytes x = [cast (x `shiftR` 56),
|
|
||||||
cast (x `shiftR` 48),
|
|
||||||
cast (x `shiftR` 40),
|
|
||||||
cast (x `shiftR` 32),
|
|
||||||
cast (x `shiftR` 24),
|
|
||||||
cast (x `shiftR` 16),
|
|
||||||
cast (x `shiftR` 8),
|
|
||||||
cast x]
|
|
||||||
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
|
|
||||||
cast b1 `shiftL` 56 .|.
|
|
||||||
cast b2 `shiftL` 48 .|.
|
|
||||||
cast b3 `shiftL` 40 .|.
|
|
||||||
cast b4 `shiftL` 32 .|.
|
|
||||||
cast b5 `shiftL` 24 .|.
|
|
||||||
cast b6 `shiftL` 16 .|.
|
|
||||||
cast b7 `shiftL` 8 .|.
|
|
||||||
cast b8
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Int8 where
|
|
||||||
bytes = 1
|
|
||||||
|
|
||||||
toBytes x = [cast x]
|
|
||||||
fromBytes [x] = cast x
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Int16 where
|
|
||||||
bytes = 2
|
|
||||||
|
|
||||||
toBytes x = [cast (x `shiftR` 8), cast x]
|
|
||||||
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Int32 where
|
|
||||||
bytes = 4
|
|
||||||
|
|
||||||
toBytes x = [cast (x `shiftR` 24),
|
|
||||||
cast (x `shiftR` 16),
|
|
||||||
cast (x `shiftR` 8),
|
|
||||||
cast x]
|
|
||||||
fromBytes [b1,b2,b3,b4] =
|
|
||||||
cast b1 `shiftL` 24 .|.
|
|
||||||
cast b2 `shiftL` 16 .|.
|
|
||||||
cast b3 `shiftL` 8 .|.
|
|
||||||
cast b4
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Int64 where
|
|
||||||
bytes = 8
|
|
||||||
|
|
||||||
toBytes x = [cast (x `shiftR` 56),
|
|
||||||
cast (x `shiftR` 48),
|
|
||||||
cast (x `shiftR` 40),
|
|
||||||
cast (x `shiftR` 32),
|
|
||||||
cast (x `shiftR` 24),
|
|
||||||
cast (x `shiftR` 16),
|
|
||||||
cast (x `shiftR` 8),
|
|
||||||
cast x]
|
|
||||||
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
|
|
||||||
cast b1 `shiftL` 56 .|.
|
|
||||||
cast b2 `shiftL` 48 .|.
|
|
||||||
cast b3 `shiftL` 40 .|.
|
|
||||||
cast b4 `shiftL` 32 .|.
|
|
||||||
cast b5 `shiftL` 24 .|.
|
|
||||||
cast b6 `shiftL` 16 .|.
|
|
||||||
cast b7 `shiftL` 8 .|.
|
|
||||||
cast b8
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep Bool where
|
|
||||||
bytes = 1
|
|
||||||
|
|
||||||
toBytes b = [if b then 1 else 0]
|
|
||||||
fromBytes [x] = x /= 0
|
|
||||||
|
|
||||||
export
|
|
||||||
ByteRep a => ByteRep b => ByteRep (a, b) where
|
|
||||||
bytes = bytes {a} + bytes {a=b}
|
|
||||||
|
|
||||||
toBytes (x,y) = toBytes x ++ toBytes y
|
|
||||||
fromBytes = bimap fromBytes fromBytes . splitAt _
|
|
||||||
|
|
||||||
export
|
|
||||||
{n : _} -> ByteRep a => ByteRep (Vect n a) where
|
|
||||||
bytes = n * bytes {a}
|
|
||||||
|
|
||||||
toBytes xs = concat $ map toBytes xs
|
|
||||||
fromBytes {n = 0} bs = []
|
|
||||||
fromBytes {n = S n} bs =
|
|
||||||
let (bs1, bs2) = splitAt _ bs
|
|
||||||
in fromBytes bs1 :: fromBytes bs2
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
readBuffer : ByteRep a => (i : Nat) -> Buffer -> IO a
|
readBuffer : ByteRep a => (i : Nat) -> Buffer -> IO a
|
||||||
readBuffer i buf = do
|
readBuffer i buf = do
|
||||||
|
@ -226,6 +70,15 @@ export
|
||||||
index : ByteRep a => Nat -> PrimArrayBytes o s -> a
|
index : ByteRep a => Nat -> PrimArrayBytes o s -> a
|
||||||
index i (MkPABytes _ buf) = unsafePerformIO $ readBuffer i buf
|
index i (MkPABytes _ buf) = unsafePerformIO $ readBuffer i buf
|
||||||
|
|
||||||
|
export
|
||||||
|
copy : {o,s : _} -> PrimArrayBytes o s -> PrimArrayBytes o s
|
||||||
|
copy (MkPABytes sts buf) =
|
||||||
|
MkPABytes sts (unsafePerformIO $ do
|
||||||
|
Just buf' <- newBuffer !(rawSize buf)
|
||||||
|
| Nothing => die "Cannot create array"
|
||||||
|
copyData buf 0 !(rawSize buf) buf' 0
|
||||||
|
pure buf)
|
||||||
|
|
||||||
export
|
export
|
||||||
reorder : {o,o',s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBytes o' s
|
reorder : {o,o',s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBytes o' s
|
||||||
reorder @{br} arr@(MkPABytes sts buf) =
|
reorder @{br} arr@(MkPABytes sts buf) =
|
||||||
|
@ -242,3 +95,31 @@ fromList @{br} s xs = bufferAction @{br} s $ go xs 0
|
||||||
go : List a -> Nat -> Buffer -> IO ()
|
go : List a -> Nat -> Buffer -> IO ()
|
||||||
go [] _ buf = pure ()
|
go [] _ buf = pure ()
|
||||||
go (x :: xs) i buf = writeBuffer i x buf >> go xs (S i) buf
|
go (x :: xs) i buf = writeBuffer i x buf >> go xs (S i) buf
|
||||||
|
|
||||||
|
export
|
||||||
|
foldl : {s : _} -> ByteRep a => (b -> a -> b) -> b -> PrimArrayBytes o s -> b
|
||||||
|
foldl f z (MkPABytes sts buf) =
|
||||||
|
if product s == 0 then z
|
||||||
|
else unsafePerformIO $ do
|
||||||
|
ref <- newIORef z
|
||||||
|
for_ [0..pred $ product s] $ \n => do
|
||||||
|
x <- readIORef ref
|
||||||
|
y <- readBuffer n buf
|
||||||
|
writeIORef ref (f x y)
|
||||||
|
readIORef ref
|
||||||
|
|
||||||
|
export
|
||||||
|
foldr : {s : _} -> ByteRep a => (a -> b -> b) -> b -> PrimArrayBytes o s -> b
|
||||||
|
foldr f z (MkPABytes sts buf) =
|
||||||
|
if product s == 0 then z
|
||||||
|
else unsafePerformIO $ do
|
||||||
|
ref <- newIORef z
|
||||||
|
for_ [pred $ product s..0] $ \n => do
|
||||||
|
x <- readBuffer n buf
|
||||||
|
y <- readIORef ref
|
||||||
|
writeIORef ref (f x y)
|
||||||
|
readIORef ref
|
||||||
|
|
||||||
|
export
|
||||||
|
traverse : {o,s : _} -> ByteRep a => ByteRep b => Applicative f => (a -> f b) -> PrimArrayBytes o s -> f (PrimArrayBytes o s)
|
||||||
|
traverse f = map (Bytes.fromList _) . traverse f . foldr (::) []
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
module Data.NumIdr.PrimArray.Delayed
|
module Data.NumIdr.PrimArray.Delayed
|
||||||
|
|
||||||
|
import Data.List
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.NP
|
import Data.NP
|
||||||
import Data.NumIdr.Array.Rep
|
import Data.NumIdr.Array.Rep
|
||||||
|
@ -18,6 +19,20 @@ constant s x _ = x
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
checkRange : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s)
|
indexRange : {s : _} -> (rs : CoordsRange s) -> PrimArrayDelayed s a -> PrimArrayDelayed (newShape rs) a
|
||||||
checkRange [] [] = Just []
|
indexRange [] v = v
|
||||||
checkRange (d :: s) (i :: is) = (::) <$> natToFin i d <*> checkRange s is
|
indexRange (r :: rs) v with (cRangeToList r)
|
||||||
|
_ | Left i = indexRange rs (\is' => v (believe_me i :: is'))
|
||||||
|
_ | Right is = \(i::is') => indexRange rs (\is'' => v (believe_me (Vect.index i (fromList is)) :: is'')) is'
|
||||||
|
|
||||||
|
export
|
||||||
|
indexSetRange : {s : _} -> (rs : CoordsRange s) -> PrimArrayDelayed (newShape rs) a
|
||||||
|
-> PrimArrayDelayed s a -> PrimArrayDelayed s a
|
||||||
|
indexSetRange {s=[]} [] rv _ = rv
|
||||||
|
indexSetRange {s=_::_} (r :: rs) rv v with (cRangeToList r)
|
||||||
|
_ | Left i = \(i'::is) => if i == finToNat i'
|
||||||
|
then indexSetRange rs rv (v . (i'::)) is
|
||||||
|
else v (i'::is)
|
||||||
|
_ | Right is = \(i'::is') => case findIndex (== finToNat i') is of
|
||||||
|
Just x => indexSetRange rs (rv . (x::)) (v . (i'::)) is'
|
||||||
|
Nothing => v (i'::is')
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
module Data.NumIdr.PrimArray.Linked
|
module Data.NumIdr.PrimArray.Linked
|
||||||
|
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
|
import Data.DPair
|
||||||
import Data.NP
|
import Data.NP
|
||||||
import Data.NumIdr.Array.Rep
|
import Data.NumIdr.Array.Rep
|
||||||
import Data.NumIdr.Array.Coords
|
import Data.NumIdr.Array.Coords
|
||||||
|
@ -18,6 +19,25 @@ index : Coords s -> Vects s a -> a
|
||||||
index [] x = x
|
index [] x = x
|
||||||
index (c :: cs) xs = index cs (index c xs)
|
index (c :: cs) xs = index cs (index c xs)
|
||||||
|
|
||||||
|
export
|
||||||
|
update : Coords s -> (a -> a) -> Vects s a -> Vects s a
|
||||||
|
update [] f v = f v
|
||||||
|
update (i :: is) f v = updateAt i (update is f) v
|
||||||
|
|
||||||
|
export
|
||||||
|
indexRange : {s : _} -> (rs : CoordsRange s) -> Vects s a -> Vects (newShape rs) a
|
||||||
|
indexRange [] v = v
|
||||||
|
indexRange (r :: rs) v with (cRangeToList r)
|
||||||
|
_ | Left i = indexRange rs (Vect.index (believe_me i) v)
|
||||||
|
_ | Right is = believe_me $ map (\i => indexRange rs (Vect.index (believe_me i) v)) is
|
||||||
|
|
||||||
|
export
|
||||||
|
indexSetRange : {s : _} -> (rs : CoordsRange s) -> Vects (newShape rs) a -> Vects s a -> Vects s a
|
||||||
|
indexSetRange {s=[]} [] rv _ = rv
|
||||||
|
indexSetRange {s=_::_} (r :: rs) rv v with (cRangeToList r)
|
||||||
|
_ | Left i = updateAt (believe_me i) (indexSetRange rs rv) v
|
||||||
|
_ | Right is = foldl (\v,i => updateAt (believe_me i) (indexSetRange rs (Vect.index (believe_me i) rv)) v) v is
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
fromFunction : {s : _} -> (Coords s -> a) -> Vects s a
|
fromFunction : {s : _} -> (Coords s -> a) -> Vects s a
|
||||||
|
@ -32,3 +52,23 @@ fromFunctionNB {s = d :: s} f = tabulate' {n=d} (\i => fromFunctionNB (f . (i::)
|
||||||
tabulate' : forall a. {n : _} -> (Nat -> a) -> Vect n a
|
tabulate' : forall a. {n : _} -> (Nat -> a) -> Vect n a
|
||||||
tabulate' {n=Z} f = []
|
tabulate' {n=Z} f = []
|
||||||
tabulate' {n=S _} f = f Z :: tabulate' (f . S)
|
tabulate' {n=S _} f = f Z :: tabulate' (f . S)
|
||||||
|
|
||||||
|
export
|
||||||
|
mapVects : {s : _} -> (a -> b) -> Vects s a -> Vects s b
|
||||||
|
mapVects {s = []} f x = f x
|
||||||
|
mapVects {s = _ :: _} f v = Prelude.map (mapVects f) v
|
||||||
|
|
||||||
|
export
|
||||||
|
foldl : {s : _} -> (b -> a -> b) -> b -> Vects s a -> b
|
||||||
|
foldl {s=[]} f z v = f z v
|
||||||
|
foldl {s=_::_} f z v = Prelude.foldl (Linked.foldl f) z v
|
||||||
|
|
||||||
|
export
|
||||||
|
foldr : {s : _} -> (a -> b -> b) -> b -> Vects s a -> b
|
||||||
|
foldr {s=[]} f z v = f v z
|
||||||
|
foldr {s=_::_} f z v = Prelude.foldr (flip $ Linked.foldr f) z v
|
||||||
|
|
||||||
|
export
|
||||||
|
traverse : {s : _} -> Applicative f => (a -> f b) -> Vects s a -> f (Vects s b)
|
||||||
|
traverse {s=[]} f v = f v
|
||||||
|
traverse {s=_::_} f v = Prelude.traverse (Linked.traverse f) v
|
||||||
|
|
Loading…
Reference in a new issue