Revamp all of the standard array API

This commit is contained in:
Kiana Sheibani 2023-09-25 22:55:33 -04:00
parent ccb689af42
commit a0068469c5
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
10 changed files with 655 additions and 364 deletions

View file

@ -26,6 +26,11 @@ public export
elems {all=[_]} [x] = show x
elems {all=_::_} (x :: xs) = show x ++ ", " ++ elems xs
-- use this when idris can't figure out the constraint above
public export
eqNP : forall f. (forall t. Eq (f t)) => (x, y : NP f ts) -> Bool
eqNP [] [] = True
eqNP (x :: xs) (y :: ys) = x == y && eqNP xs ys
public export

View file

@ -3,4 +3,4 @@ module Data.NumIdr.Array
import public Data.NP
import public Data.NumIdr.Array.Array
import public Data.NumIdr.Array.Coords
import public Data.NumIdr.Array.Order
import public Data.NumIdr.Array.Rep

View file

@ -112,7 +112,7 @@ viewShape arr = rewrite shapeEq arr in
export
repeat : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> a -> Array s a
repeat s x = MkArray rep s (constant s x)
repeat s x = MkArray rep s (PrimArray.constant s x)
||| Create an array filled with zeros.
|||
@ -138,9 +138,9 @@ ones {rep} s = repeat {rep} s 1
||| @ s The shape of the constructed array
||| @ rep The internal representation of the constructed array
export
fromVect : {default B rep : Rep} -> RepConstraint rep a =>
fromVect : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
(s : Vect rk Nat) -> Vect (product s) a -> Array s a
fromVect s v = ?fv
fromVect {rep} s v = MkArray rep s (PrimArray.fromList {rep} s $ toList v)
||| Create an array by taking values from a stream.
@ -148,7 +148,7 @@ fromVect s v = ?fv
||| @ s The shape of the constructed array
||| @ rep The internal representation of the constructed array
export
fromStream : {default B rep : Rep} -> RepConstraint rep a =>
fromStream : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
(s : Vect rk Nat) -> Stream a -> Array s a
fromStream {rep} s str = fromVect {rep} s (take _ str)
@ -211,12 +211,12 @@ index is (MkArray _ _ arr) = PrimArray.index is arr
export %inline
(!!) : Array s a -> Coords s -> a
arr !! is = index is arr
{-
||| Update the entry at the given coordinates using the function.
export
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
indexUpdate is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation sts is) f arr)
indexUpdate is f (MkArray rep @{rc} s arr) =
MkArray rep @{rc} s (indexUpdate @{rc} is f arr)
||| Set the entry at the given coordinates to the given value.
export
@ -227,17 +227,8 @@ indexSet is = indexUpdate is . const
||| Index the array using the given range of coordinates, returning a new array.
export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
indexRange {s} rs arr with (viewShape arr)
_ | Shape s =
let ord = getOrder arr
sts = calcStrides ord s'
in MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList rs)
where s' : Vect ? Nat
s' = newShape rs
indexRange rs (MkArray rep @{rc} s arr) =
MkArray rep @{rc} _ (PrimArray.indexRange @{rc} rs arr)
||| Index the array using the given range of coordinates, returning a new array.
|||
@ -250,13 +241,8 @@ arr !!.. rs = indexRange rs arr
export
indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
Array s a -> Array s a
indexSetRange rs rpl (MkArray ord sts s arr) =
MkArray ord sts s (unsafePerformIO $ do
let arr' = copy arr
unsafeDoIns (map (\(is,is') =>
(getLocation' sts is, index (getLocation' (strides rpl) is') (getPrim rpl)))
$ getCoordsList rs) arr'
pure arr')
indexSetRange rs (MkArray _ _ rpl) (MkArray rep s arr) =
MkArray rep s (PrimArray.indexSetRange {rep} rs (convertRep rpl) arr)
||| Update the sub-array at the given range of coordinates by applying
@ -266,13 +252,14 @@ indexUpdateRange : (rs : CoordsRange s) ->
(Array (newShape rs) a -> Array (newShape rs) a) ->
Array s a -> Array s a
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
-}
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
indexNB is (MkArray rep @{rc} s arr) =
map (\is => index is (MkArray rep @{rc} s arr)) (validateCoords s is)
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
@ -281,15 +268,13 @@ indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
export %inline
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
arr !? is = indexNB is arr
{-
||| Update the entry at the given coordinates using the function. `Nothing` is
||| returned if the coordinates are out of bounds.
export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a)
indexUpdateNB is f (MkArray ord sts s arr) =
if all id $ zipWith (<) is s
then Just $ MkArray ord sts s (updateAt (getLocation' sts is) f arr)
else Nothing
indexUpdateNB is f (MkArray rep @{rc} s arr) =
map (\is => indexUpdate is f (MkArray rep @{rc} s arr)) (validateCoords s is)
||| Set the entry at the given coordinates using the function. `Nothing` is
||| returned if the coordinates are out of bounds.
@ -302,20 +287,8 @@ indexSetNB is = indexUpdateNB is . const
||| If any of the given indices are out of bounds, then `Nothing` is returned.
export
indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a)
indexRangeNB {s} rs arr with (viewShape arr)
_ | Shape s =
if validateShape s rs
then
let ord = getOrder arr
sts = calcStrides ord s'
in Just $ MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList s rs)
else Nothing
where s' : Vect ? Nat
s' = newShape s rs
indexRangeNB rs (MkArray rep @{rc} s arr) =
map (\rs => believe_me $ Array.indexRange rs (MkArray rep @{rc} s arr)) (validateCRange s rs)
||| Index the array using the given range of coordinates, returning a new array.
||| If any of the given indices are out of bounds, then `Nothing` is returned.
@ -324,7 +297,7 @@ indexRangeNB {s} rs arr with (viewShape arr)
export %inline
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
arr !?.. rs = indexRangeNB rs arr
-}
||| Index the array using the given coordinates.
||| WARNING: This function does not perform any bounds check on its inputs.
@ -342,30 +315,21 @@ export %inline
(!#) : Array {rk} s a -> Vect rk Nat -> a
arr !# is = indexUnsafe is arr
{-
||| Index the array using the given range of coordinates, returning a new array.
||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety.
export
export %unsafe
indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
indexRangeUnsafe {s} rs arr with (viewShape arr)
_ | Shape s =
let ord = getOrder arr
sts = calcStrides ord s'
in MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList s rs)
where s' : Vect ? Nat
s' = newShape s rs
indexRangeUnsafe rs (MkArray rep @{rc} s arr) =
believe_me $ Array.indexRange (assertCRange s rs) (MkArray rep @{rc} s arr)
||| Index the array using the given range of coordinates, returning a new array.
||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety.
|||
||| This is the operator form of `indexRangeUnsafe`.
export %inline
export %inline %unsafe
(!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
arr !#.. is = indexRangeUnsafe is arr
@ -375,35 +339,42 @@ arr !#.. is = indexRangeUnsafe is arr
-- Operations on arrays
--------------------------------------------------------------------------------
||| Reshape the array into the given shape and reinterpret it according to
||| the given order.
|||
||| @ s' The shape to convert the array to
||| @ ord The order to reinterpret the array by
export
reshape' : (s' : Vect rk' Nat) -> (ord : Order) -> Array {rk} s a ->
(0 ok : product s = product s') => Array s' a
reshape' s' ord' arr = MkArray ord' (calcStrides ord' s') s' (getPrim arr)
mapArray' : (a -> a) -> Array s a -> Array s a
mapArray' f (MkArray rep _ arr) = MkArray rep _ (mapPrim f arr)
export
mapArray : (a -> b) -> (arr : Array s a) -> RepConstraint (getRep arr) b => Array s b
mapArray f (MkArray rep _ arr) @{rc} = MkArray rep @{rc} _ (mapPrim f arr)
export
zipWithArray' : (a -> a -> a) -> Array s a -> Array s a -> Array s a
zipWithArray' {s} f a b with (viewShape a)
_ | Shape s = MkArray (mergeRep (getRep a) (getRep b))
@{mergeRepConstraint (getRepC a) (getRepC b)} s
$ PrimArray.fromFunctionNB @{mergeRepConstraint (getRepC a) (getRepC b)} _
(\is => f (a !# is) (b !# is))
export
zipWithArray : (a -> b -> c) -> (arr : Array s a) -> (arr' : Array s b) ->
RepConstraint (mergeRep (getRep arr) (getRep arr')) c => Array s c
zipWithArray {s} f a b @{rc} with (viewShape a)
_ | Shape s = MkArray (mergeRep (getRep a) (getRep b)) @{rc} s
$ PrimArray.fromFunctionNB _ (\is => f (a !# is) (b !# is))
||| Reshape the array into the given shape.
|||
||| @ s' The shape to convert the array to
export
reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
reshape : (s' : Vect rk' Nat) -> (arr : Array {rk} s a) -> LinearRep (getRep arr) =>
(0 ok : product s = product s') => Array s' a
reshape s' arr = reshape' s' (getOrder arr) arr
reshape s' (MkArray rep _ arr) = MkArray rep s' (PrimArray.reshape s' arr)
||| Change the internal order of the array's elements.
export
reorder : Order -> Array s a -> Array s a
reorder {s} ord' arr with (viewShape arr)
_ | Shape s = let sts = calcStrides ord' s
in MkArray ord' sts _ (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
convertRep : (rep : Rep) -> RepConstraint rep a => Array s a -> Array s a
convertRep rep (MkArray _ s arr) = MkArray rep s (PrimArray.convertRep arr)
||| Resize the array to a new shape, preserving the coordinates of the original
||| elements. New coordinates are filled with a default value.
@ -412,7 +383,7 @@ reorder {s} ord' arr with (viewShape arr)
||| @ def The default value to fill the array with
export
resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
resize s' def arr = fromFunction' s' (getOrder arr) (fromMaybe def . (arr !?) . toNB)
resize s' def arr = fromFunction {rep=getRep arr} @{getRepC arr} s' (fromMaybe def . (arr !?) . toNB)
||| Resize the array to a new shape, preserving the coordinates of the original
||| elements. This function requires a proof that the new shape is strictly
@ -427,20 +398,11 @@ resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) =>
resizeLTE s' arr = resize s' (believe_me ()) arr
||| List all of the values in an array in row-major order.
export
elements : Array {rk} s a -> Vect (product s) a
elements (MkArray _ sts sh p) =
let elems = map (flip index p . getLocation' sts) (getAllCoords' sh)
-- TODO: prove that the number of elements is `product s`
in assert_total $ case toVect (product sh) elems of Just v => v
||| List all of the values in an array along with their coordinates.
export
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
enumerateNB (MkArray _ sts sh p) =
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
enumerateNB (MkArray _ s arr) =
map (\is => (is, PrimArray.indexUnsafe is arr)) (getAllCoords' s)
||| List all of the values in an array along with their coordinates.
export
@ -448,6 +410,12 @@ enumerate : Array s a -> List (Coords s, a)
enumerate {s} arr with (viewShape arr)
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
||| List all of the values in an array in row-major order.
export
elements : Array {rk} s a -> Vect (product s) a
elements (MkArray _ s arr) =
believe_me $ Vect.fromList $
map (flip PrimArray.indexUnsafe arr) (getAllCoords' s)
||| Join two arrays along a particular axis, e.g. combining two matrices
||| vertically or horizontally. All other axes of the arrays must have the
@ -457,16 +425,12 @@ enumerate {s} arr with (viewShape arr)
export
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
Array (updateAt axis (+d) s) a
concat axis a b = let sA = shape a
sB = shape b
dA = index axis sA
dB = index axis sB
s = replaceAt axis (dA + dB) sA
sts = calcStrides COrder s
ins = map (mapFst $ getLocation' sts . toNB) (enumerate a)
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNB) (enumerate b)
-- TODO: prove that the type-level shape and `s` are equivalent
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
concat {s,d} axis a b with (viewShape a, viewShape b)
_ | (Shape s, Shape (replaceAt axis d s)) =
believe_me $ Array.fromFunctionNB {rep=mergeRep (getRep a) (getRep b)}
@{mergeRepConstraint (getRepC a) (getRepC b)} (updateAt axis (+ index axis (shape b)) s)
(\is => let limit = index axis s
in if index axis is < limit then a !# is else b !# updateAt axis (`minus` limit) is)
||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix.
|||
@ -500,12 +464,11 @@ splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} ->
Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a)
splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is')))
||| Construct the transpose of an array by reversing the order of its axes.
export
transpose : Array s a -> Array (reverse s) a
transpose {s} arr with (viewShape arr)
_ | Shape s = fromFunctionNB (reverse s) (\is => arr !# reverse is)
_ | Shape s = fromFunctionNB _ (\is => arr !# reverse is)
||| Construct the transpose of an array by reversing the order of its axes.
|||
@ -551,41 +514,56 @@ permuteInAxis {s} axis p arr with (viewShape arr)
-- Implementations
--------------------------------------------------------------------------------
-- Most of these implementations apply the operation pointwise. If there are
-- multiple arrays involved with different orders, then all of the arrays are
-- reordered to match one.
export
Zippable (Array s) where
zipWith {s} f a b with (viewShape a)
_ | Shape s = MkArray (getOrder a) (strides a) s $
if getOrder a == getOrder b
then unsafeZipWith f (getPrim a) (getPrim b)
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
_ | Shape s = MkArray (mergeRepNC (getRep a) (getRep b))
@{mergeNCRepConstraint} s
$ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
(\is => f (a !# is) (b !# is))
zipWith3 {s} f a b c with (viewShape a)
_ | Shape s = MkArray (getOrder a) (strides a) s $
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
(getPrim $ reorder (getOrder a) c)
_ | Shape s = MkArray (mergeRepNC (mergeRep (getRep a) (getRep b)) (getRep c))
@{mergeNCRepConstraint} s
$ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
(\is => f (a !# is) (b !# is) (c !# is))
unzipWith {s} f arr with (viewShape arr)
_ | Shape s = case unzipWith f (getPrim arr) of
(a, b) => (MkArray (getOrder arr) (strides arr) s a,
MkArray (getOrder arr) (strides arr) s b)
_ | Shape s =
let rep : Rep
rep = forceRepNC $ getRep arr
in (MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => snd $ f (arr !# is)))
unzipWith3 {s} f arr with (viewShape arr)
_ | Shape s = case unzipWith3 f (getPrim arr) of
(a, b, c) => (MkArray (getOrder arr) (strides arr) s a,
MkArray (getOrder arr) (strides arr) s b,
MkArray (getOrder arr) (strides arr) s c)
_ | Shape s =
let rep : Rep
rep = forceRepNC $ getRep arr
in (MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ snd $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => snd $ snd $ f (arr !# is)))
export
Functor (Array s) where
map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr)
map f (MkArray rep @{rc} s arr) = MkArray (forceRepNC rep) @{forceRepConstraint} s
(mapPrim @{forceRepConstraint} @{forceRepConstraint} f
$ convertRep @{rc} @{forceRepConstraint} arr)
export
{s : _} -> Applicative (Array s) where
@ -602,15 +580,18 @@ export
export
Foldable (Array s) where
foldl f z = foldl f z . getPrim
foldr f z = foldr f z . getPrim
null arr = size arr == Z
toList = toList . getPrim
foldl f z (MkArray _ _ arr) = PrimArray.foldl f z arr
foldr f z (MkArray _ _ arr) = PrimArray.foldr f z arr
null (MkArray _ s _) = isZero (product s)
export
Traversable (Array s) where
traverse f (MkArray ord sts s arr) =
map (MkArray ord sts s) (traverse f arr)
traverse f (MkArray rep @{rc} s arr) =
map (MkArray (forceRepNC rep) @{forceRepConstraint} s)
(PrimArray.traverse {rep=forceRepNC rep}
@{%search} @{forceRepConstraint} @{forceRepConstraint} f
(PrimArray.convertRep @{rc} @{forceRepConstraint} arr))
export
@ -619,13 +600,11 @@ Cast a b => Cast (Array s a) (Array s b) where
export
Eq a => Eq (Array s a) where
a == b = if getOrder a == getOrder b
then unsafeEq (getPrim a) (getPrim b)
else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b)
a == b = and $ zipWith (delay .: (==)) (convertRep D a) (convertRep D b)
export
Semigroup a => Semigroup (Array s a) where
(<+>) = zipWith (<+>)
(<+>) = zipWithArray' (<+>)
export
{s : _} -> Monoid a => Monoid (Array s a) where
@ -635,36 +614,34 @@ export
-- were moved into its own interface, this constraint could be removed.
export
{s : _} -> Num a => Num (Array s a) where
(+) = zipWith (+)
(*) = zipWith (*)
(+) = zipWithArray' (+)
(*) = zipWithArray' (*)
fromInteger = repeat s . fromInteger
export
{s : _} -> Neg a => Neg (Array s a) where
negate = map negate
(-) = zipWith (-)
negate = mapArray' negate
(-) = zipWithArray' (-)
export
{s : _} -> Fractional a => Fractional (Array s a) where
recip = map recip
(/) = zipWith (/)
recip = mapArray' recip
(/) = zipWithArray' (/)
export
Num a => Mult a (Array {rk} s a) (Array s a) where
(*.) x = map (*x)
(*.) x = mapArray' (*x)
export
Num a => Mult (Array {rk} s a) a (Array s a) where
(*.) = flip (*.)
export
Show a => Show (Array s a) where
showPrec d arr = let orderedElems = PrimArray.toList $ getPrim $
if getOrder arr == COrder then arr else reorder COrder arr
showPrec d arr = let orderedElems = toList $ elements arr
in showCon d "array " $ concat $ insertPunct (shape arr) $ map show orderedElems
where
splitWindow : Nat -> List String -> List (List String)
@ -691,7 +668,7 @@ Show a => Show (Array s a) where
||| Linearly interpolate between two arrays.
export
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
lerp t a b = zipWithArray' (+) (a *. (1 - t)) (b *. t)
||| Calculate the square of an array's Eulidean norm.
@ -715,4 +692,3 @@ normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
export
pnorm : (p : Double) -> Array s Double -> Double
pnorm p = (`pow` recip p) . sum . map (`pow` p)
-}

View file

@ -43,6 +43,11 @@ toNB : Coords {rk} s -> Vect rk Nat
toNB [] = []
toNB (i :: is) = finToNat i :: toNB is
export
validateCoords : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s)
validateCoords [] [] = Just []
validateCoords (d :: s) (i :: is) = (::) <$> natToFin i d <*> validateCoords s is
namespace Strict
public export
@ -155,6 +160,44 @@ namespace Strict
namespace NB
export
validateCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> Maybe (CoordsRange s)
validateCRange [] [] = Just []
validateCRange (d :: s) (r :: rs) = [| validate' d r :: validateCRange s rs |]
where
validate' : (n : Nat) -> CRangeNB -> Maybe (CRange n)
validate' n All = Just All
validate' n (StartBound x) =
case isLTE x n of
Yes _ => Just (StartBound (natToFinLT x))
_ => Nothing
validate' n (EndBound x) =
case isLTE x n of
Yes _ => Just (EndBound (natToFinLT x))
_ => Nothing
validate' n (Bounds x y) =
case (isLTE x n, isLTE y n) of
(Yes _, Yes _) => Just (Bounds (natToFinLT x) (natToFinLT y))
_ => Nothing
validate' n (Indices xs) = Indices <$> traverse
(\x => case isLT x n of
Yes _ => Just (natToFinLT x)
No _ => Nothing) xs
validate' n (Filter f) = Just (Filter (f . finToNat))
export %unsafe
assertCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> CoordsRange s
assertCRange [] [] = []
assertCRange (d :: s) (r :: rs) = assert' d r :: assertCRange s rs
where
assert' : (n : Nat) -> CRangeNB -> CRange n
assert' n All = All
assert' n (StartBound x) = StartBound (believe_me x)
assert' n (EndBound x) = EndBound (believe_me x)
assert' n (Bounds x y) = Bounds (believe_me x) (believe_me y)
assert' n (Indices xs) = Indices (believe_me <$> xs)
assert' n (Filter f) = Filter (f . finToNat)
public export
cRangeNBToList : Nat -> CRangeNB -> List Nat
cRangeNBToList s All = range 0 s
cRangeNBToList s (StartBound x) = range x s
@ -163,38 +206,11 @@ namespace NB
cRangeNBToList s (Indices xs) = nub xs
cRangeNBToList s (Filter p) = filter p $ range 0 s
export
validateCRangeNB : Nat -> CRangeNB -> Bool
validateCRangeNB s All = True
validateCRangeNB s (StartBound x) = x < s
validateCRangeNB s (EndBound x) = x <= s
validateCRangeNB s (Bounds x y) = x < s && y <= s
validateCRangeNB s (Indices xs) = all (<s) xs
validateCRangeNB s (Filter p) = True
||| Calculate the new shape given by a coordinate range.
export
public export
newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat
newShape = zipWith (length .: cRangeNBToList)
export
validateShape : Vect rk Nat -> Vect rk CRangeNB -> Bool
validateShape = all id .: zipWith validateCRangeNB
export
getNewPos : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat -> Vect rk Nat
getNewPos = zipWith3 (\d,r,i => assert_total $
case findIndex (==i) (cRangeNBToList d r) of Just x => cast x)
export
getCoordsList : Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat, Vect rk Nat)
getCoordsList s rs = map (\is => (is, getNewPos s rs is)) $ go s rs
where
go : {0 rk : _} -> Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat)
go [] [] = [[]]
go (d :: s) (r :: rs) = [| cRangeNBToList d r :: go s rs |]
export
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
getAllCoords' = traverse (\case Z => []; S n => [0..n])

View file

@ -1,6 +1,8 @@
module Data.NumIdr.Array.Rep
import Data.Vect
import Data.Bits
import Data.Buffer
%default total
@ -23,6 +25,8 @@ data Order : Type where
||| fashion (the first axis is the least significant).
FOrder : Order
%name Order o
public export
Eq Order where
@ -52,6 +56,9 @@ data Rep : Type where
Linked : Rep
Delayed : Rep
%name Rep rep
public export
B : Rep
B = Boxed COrder
@ -65,7 +72,231 @@ D : Rep
D = Delayed
public export
Eq Rep where
Bytes o == Bytes o' = o == o'
Boxed o == Boxed o' = o == o'
Linked == Linked = True
Delayed == Delayed = True
_ == _ = False
public export
data LinearRep : Rep -> Type where
BytesIsL : LinearRep (Bytes o)
BoxedIsL : LinearRep (Boxed o)
public export
forceRepNC : Rep -> Rep
forceRepNC (Bytes o) = Boxed o
forceRepNC r = r
public export
mergeRep : Rep -> Rep -> Rep
mergeRep r r' = if r == Delayed || r' == Delayed then Delayed else r
public export
mergeRepNC : Rep -> Rep -> Rep
mergeRepNC r r' =
if r == Delayed || r' == Delayed then Delayed
else case r of
Bytes _ => case r' of
Bytes o => Boxed o
_ => r'
_ => r
public export
data NoConstraintRep : Rep -> Type where
BoxedNC : NoConstraintRep (Boxed o)
LinkedNC : NoConstraintRep Linked
DelayedNC : NoConstraintRep Delayed
public export
mergeRepNCCorrect : (r, r' : Rep) -> NoConstraintRep (mergeRepNC r r')
mergeRepNCCorrect Delayed _ = DelayedNC
mergeRepNCCorrect (Bytes y) (Bytes x) = BoxedNC
mergeRepNCCorrect (Boxed y) (Bytes x) = BoxedNC
mergeRepNCCorrect Linked (Bytes x) = LinkedNC
mergeRepNCCorrect (Bytes y) (Boxed x) = BoxedNC
mergeRepNCCorrect (Boxed y) (Boxed x) = BoxedNC
mergeRepNCCorrect Linked (Boxed x) = LinkedNC
mergeRepNCCorrect (Bytes x) Linked = LinkedNC
mergeRepNCCorrect (Boxed x) Linked = BoxedNC
mergeRepNCCorrect Linked Linked = LinkedNC
mergeRepNCCorrect (Bytes x) Delayed = DelayedNC
mergeRepNCCorrect (Boxed x) Delayed = DelayedNC
mergeRepNCCorrect Linked Delayed = DelayedNC
public export
forceRepNCCorrect : (r : Rep) -> NoConstraintRep (forceRepNC r)
forceRepNCCorrect (Bytes x) = BoxedNC
forceRepNCCorrect (Boxed x) = BoxedNC
forceRepNCCorrect Linked = LinkedNC
forceRepNCCorrect Delayed = DelayedNC
--------------------------------------------------------------------------------
-- Byte representations of elements
--------------------------------------------------------------------------------
public export
interface ByteRep a where
bytes : Nat
toBytes : a -> Vect bytes Bits8
fromBytes : Vect bytes Bits8 -> a
export
ByteRep Bits8 where
bytes = 1
toBytes x = [x]
fromBytes [x] = x
export
ByteRep Bits16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Bits32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Bits64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int8 where
bytes = 1
toBytes x = [cast x]
fromBytes [x] = cast x
export
ByteRep Int16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Int32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Int64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Bool where
bytes = 1
toBytes b = [if b then 1 else 0]
fromBytes [x] = x /= 0
export
ByteRep a => ByteRep b => ByteRep (a, b) where
bytes = bytes {a} + bytes {a=b}
toBytes (x,y) = toBytes x ++ toBytes y
fromBytes = bimap fromBytes fromBytes . splitAt _
export
{n : _} -> ByteRep a => ByteRep (Vect n a) where
bytes = n * bytes {a}
toBytes xs = concat $ map toBytes xs
fromBytes {n = 0} bs = []
fromBytes {n = S n} bs =
let (bs1, bs2) = splitAt _ bs
in fromBytes bs1 :: fromBytes bs2
public export
RepConstraint : Rep -> Type -> Type
RepConstraint (Bytes _) a = ByteRep a
RepConstraint (Boxed _) a = ()
RepConstraint Linked a = ()
RepConstraint Delayed a = ()

View file

@ -2,23 +2,39 @@ module Data.NumIdr.PrimArray
import Data.Buffer
import Data.Vect
import Data.Fin
import Data.NP
import Data.NumIdr.Array.Rep
import Data.NumIdr.Array.Coords
import Data.NumIdr.PrimArray.Bytes
import Data.NumIdr.PrimArray.Boxed
import Data.NumIdr.PrimArray.Linked
import Data.NumIdr.PrimArray.Delayed
import public Data.NumIdr.PrimArray.Bytes
import public Data.NumIdr.PrimArray.Boxed
import public Data.NumIdr.PrimArray.Linked
import public Data.NumIdr.PrimArray.Delayed
%default total
public export
RepConstraint : Rep -> Type -> Type
RepConstraint (Bytes _) a = ByteRep a
RepConstraint (Boxed _) a = ()
RepConstraint Linked a = ()
RepConstraint Delayed a = ()
noRCCorrect : (rep : Rep) -> NoConstraintRep rep -> RepConstraint rep a
noRCCorrect (Boxed _) BoxedNC = ()
noRCCorrect Linked LinkedNC = ()
noRCCorrect Delayed DelayedNC = ()
export
mergeRepConstraint : {r, r' : Rep} -> RepConstraint r a -> RepConstraint r' a ->
RepConstraint (mergeRep r r') a
mergeRepConstraint {r,r'} a b with (r == Delayed || r' == Delayed)
_ | True = ()
_ | False = a
export
mergeNCRepConstraint : {r, r' : Rep} -> RepConstraint (mergeRepNC r r') a
mergeNCRepConstraint = noRCCorrect _ (mergeRepNCCorrect r r')
export
forceRepConstraint : {r : Rep} -> RepConstraint (forceRepNC r) a
forceRepConstraint = noRCCorrect _ (forceRepNCCorrect r)
export
@ -29,6 +45,11 @@ PrimArray Linked s a = Vects s a
PrimArray Delayed s a = Coords s -> a
export
reshape : {rep : Rep} -> LinearRep rep => RepConstraint rep a => (s' : Vect rk Nat) -> PrimArray rep s a -> PrimArray rep s' a
reshape {rep = Bytes o} s' (MkPABytes _ arr) = MkPABytes (calcStrides o s') arr
reshape {rep = Boxed o} s' (MkPABoxed _ arr) = MkPABoxed (calcStrides o s') arr
export
constant : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> a -> PrimArray rep s a
constant {rep = Bytes o} = Bytes.constant
@ -66,25 +87,68 @@ index {rep = Linked} is arr = Linked.index is arr
index {rep = Delayed} is arr = arr is
export
indexNB : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> Maybe a
indexNB {rep = Bytes o} is arr@(MkPABytes sts _) =
if and (zipWith (delay .: (<)) is s)
then Just $ index (getLocation' sts is) arr
else Nothing
indexNB {rep = Boxed o} is arr@(MkPABoxed sts _) =
if and (zipWith (delay .: (<)) is s)
then Just $ index (getLocation' sts is) arr
else Nothing
indexNB {rep = Linked} is arr = (`Linked.index` arr) <$> checkRange s is
indexNB {rep = Delayed} is arr = arr <$> checkRange s is
indexUpdate : {rep,s : _} -> RepConstraint rep a => Coords s -> (a -> a) -> PrimArray rep s a -> PrimArray rep s a
indexUpdate {rep = Bytes o} @{rc} is f arr@(MkPABytes sts _) =
let i = getLocation sts is
in create @{rc} s (\n => let x = index n arr in if n == i then f x else x)
indexUpdate {rep = Boxed o} is f arr@(MkPABoxed sts _) =
let i = getLocation sts is
in create s (\n => let x = index n arr in if n == i then f x else x)
indexUpdate {rep = Linked} is f arr = update is f arr
indexUpdate {rep = Delayed} is f arr =
\is' =>
let x = arr is'
in if eqNP is is' then f x else x
export
indexRange : {rep,s : _} -> RepConstraint rep a => (rs : CoordsRange s) -> PrimArray rep s a
-> PrimArray rep (newShape rs) a
indexRange {rep = Bytes o} @{rc} rs arr@(MkPABytes sts _) =
let s' : Vect ? Nat
s' = newShape rs
sts' := calcStrides o s'
in unsafeFromIns @{rc} (newShape rs) $
map (\(is,is') => (getLocation' sts' is',
index (getLocation' sts is) arr))
$ getCoordsList rs
indexRange {rep = Boxed o} rs arr@(MkPABoxed sts _) =
let s' : Vect ? Nat
s' = newShape rs
sts' := calcStrides o s'
in unsafeFromIns (newShape rs) $
map (\(is,is') => (getLocation' sts' is',
index (getLocation' sts is) arr))
$ getCoordsList rs
indexRange {rep = Linked} rs arr = Linked.indexRange rs arr
indexRange {rep = Delayed} rs arr = Delayed.indexRange rs arr
export
indexSetRange : {rep,s : _} -> RepConstraint rep a => (rs : CoordsRange s)
-> PrimArray rep (newShape rs) a -> PrimArray rep s a -> PrimArray rep s a
indexSetRange {rep = Bytes o} @{rc} rs rpl@(MkPABytes sts' _) arr@(MkPABytes sts _) =
unsafePerformIO $ do
let arr' = copy arr
unsafeDoIns @{rc} (map (\(is,is') =>
(getLocation' sts is, index (getLocation' sts' is') rpl))
$ getCoordsList rs) arr'
pure arr'
indexSetRange {rep = Boxed o} rs rpl@(MkPABoxed sts' _) arr@(MkPABoxed sts _) =
unsafePerformIO $ do
let arr' = copy arr
unsafeDoIns (map (\(is,is') =>
(getLocation' sts is, index (getLocation' sts' is') rpl))
$ getCoordsList rs) arr'
pure arr'
indexSetRange {rep = Linked} rs rpl arr = Linked.indexSetRange rs rpl arr
indexSetRange {rep = Delayed} rs rpl arr = Delayed.indexSetRange rs rpl arr
export
indexUnsafe : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> a
indexUnsafe {rep = Bytes o} is arr@(MkPABytes sts _) = index (getLocation' sts is) arr
indexUnsafe {rep = Boxed o} is arr@(MkPABoxed sts _) = index (getLocation' sts is) arr
indexUnsafe {rep = Linked} is arr = assert_total $ case checkRange s is of
indexUnsafe {rep = Linked} is arr = assert_total $ case validateCoords s is of
Just is' => Linked.index is' arr
indexUnsafe {rep = Delayed} is arr = assert_total $ case checkRange s is of
indexUnsafe {rep = Delayed} is arr = assert_total $ case validateCoords s is of
Just is' => arr is'
export
@ -100,3 +164,41 @@ convertRep {r1, r2} arr = fromFunction s (\is => PrimArray.index is arr)
export
fromVects : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> Vects s a -> PrimArray rep s a
fromVects s v = convertRep {r1=Linked} v
export
fromList : {rep : Rep} -> LinearRep rep => RepConstraint rep a => (s : Vect rk Nat) -> List a -> PrimArray rep s a
fromList {rep = (Boxed x)} = Boxed.fromList
fromList {rep = (Bytes x)} = Bytes.fromList
export
mapPrim : {rep,s : _} -> RepConstraint rep a => RepConstraint rep b =>
(a -> b) -> PrimArray rep s a -> PrimArray rep s b
mapPrim {rep = (Bytes x)} @{_} @{rc} f arr = Bytes.create @{rc} _ (\i => f $ Bytes.index i arr)
mapPrim {rep = (Boxed x)} f arr = Boxed.create _ (\i => f $ Boxed.index i arr)
mapPrim {rep = Linked} f arr = Linked.mapVects f arr
mapPrim {rep = Delayed} f arr = f . arr
export
foldl : {rep,s : _} -> RepConstraint rep a => (b -> a -> b) -> b -> PrimArray rep s a -> b
foldl {rep = Bytes _} = Bytes.foldl
foldl {rep = Boxed _} = Boxed.foldl
foldl {rep = Linked} = Linked.foldl
foldl {rep = Delayed} = \f,z => Boxed.foldl f z . convertRep {r1=Delayed,r2=B}
export
foldr : {rep,s : _} -> RepConstraint rep a => (a -> b -> b) -> b -> PrimArray rep s a -> b
foldr {rep = Bytes _} = Bytes.foldr
foldr {rep = Boxed _} = Boxed.foldr
foldr {rep = Linked} = Linked.foldr
foldr {rep = Delayed} = \f,z => Boxed.foldr f z . convertRep {r1=Delayed,r2=B}
export
traverse : {rep,s : _} -> Applicative f => RepConstraint rep a => RepConstraint rep b =>
(a -> f b) -> PrimArray rep s a -> f (PrimArray rep s b)
traverse {rep = Bytes o} = Bytes.traverse
traverse {rep = Boxed o} = Boxed.traverse
traverse {rep = Linked} = Linked.traverse
traverse {rep = Delayed} = \f => map (convertRep @{()} @{()} {r1=B,r2=Delayed}) .
Boxed.traverse f .
convertRep @{()} @{()} {r1=Delayed,r2=B}

View file

@ -2,6 +2,7 @@ module Data.NumIdr.PrimArray.Boxed
import Data.Buffer
import System
import Data.IORef
import Data.IOArray.Prims
import Data.Vect
import Data.NumIdr.Array.Rep
@ -108,6 +109,34 @@ fromList s xs = arrayAction s $ go xs 0
go (x :: xs) i buf = arrayDataSet i x buf >> go xs (S i) buf
export
foldl : {s : _} -> (b -> a -> b) -> b -> PrimArrayBoxed o s a -> b
foldl f z (MkPABoxed sts arr) =
if product s == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [0..pred $ product s] $ \n => do
x <- readIORef ref
y <- arrayDataGet n arr
writeIORef ref (f x y)
readIORef ref
export
foldr : {s : _} -> (a -> b -> b) -> b -> PrimArrayBoxed o s a -> b
foldr f z (MkPABoxed sts arr) =
if product s == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [pred $ product s..0] $ \n => do
x <- arrayDataGet n arr
y <- readIORef ref
writeIORef ref (f x y)
readIORef ref
export
traverse : {o,s : _} -> Applicative f => (a -> f b) -> PrimArrayBoxed o s a -> f (PrimArrayBoxed o s b)
traverse f = map (Boxed.fromList _) . traverse f . foldr (::) []
{-
export
updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a
@ -145,10 +174,6 @@ fromList xs = create (length xs)
fromJust : Maybe a -> a
fromJust (Just x) = x
||| Map a function over a primitive array.
export
map : (a -> b) -> PrimArrayBoxed a -> PrimArrayBoxed b
map f arr = create (length arr) (\n => f $ index n arr)
export

View file

@ -2,8 +2,6 @@ module Data.NumIdr.PrimArray.Bytes
import System
import Data.IORef
import Data.Bits
import Data.So
import Data.Buffer
import Data.Vect
import Data.NumIdr.Array.Coords
@ -12,160 +10,6 @@ import Data.NumIdr.Array.Rep
%default total
public export
interface ByteRep a where
bytes : Nat
toBytes : a -> Vect bytes Bits8
fromBytes : Vect bytes Bits8 -> a
export
ByteRep Bits8 where
bytes = 1
toBytes x = [x]
fromBytes [x] = x
export
ByteRep Bits16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Bits32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Bits64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int8 where
bytes = 1
toBytes x = [cast x]
fromBytes [x] = cast x
export
ByteRep Int16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Int32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Int64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Bool where
bytes = 1
toBytes b = [if b then 1 else 0]
fromBytes [x] = x /= 0
export
ByteRep a => ByteRep b => ByteRep (a, b) where
bytes = bytes {a} + bytes {a=b}
toBytes (x,y) = toBytes x ++ toBytes y
fromBytes = bimap fromBytes fromBytes . splitAt _
export
{n : _} -> ByteRep a => ByteRep (Vect n a) where
bytes = n * bytes {a}
toBytes xs = concat $ map toBytes xs
fromBytes {n = 0} bs = []
fromBytes {n = S n} bs =
let (bs1, bs2) = splitAt _ bs
in fromBytes bs1 :: fromBytes bs2
export
readBuffer : ByteRep a => (i : Nat) -> Buffer -> IO a
readBuffer i buf = do
@ -226,6 +70,15 @@ export
index : ByteRep a => Nat -> PrimArrayBytes o s -> a
index i (MkPABytes _ buf) = unsafePerformIO $ readBuffer i buf
export
copy : {o,s : _} -> PrimArrayBytes o s -> PrimArrayBytes o s
copy (MkPABytes sts buf) =
MkPABytes sts (unsafePerformIO $ do
Just buf' <- newBuffer !(rawSize buf)
| Nothing => die "Cannot create array"
copyData buf 0 !(rawSize buf) buf' 0
pure buf)
export
reorder : {o,o',s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBytes o' s
reorder @{br} arr@(MkPABytes sts buf) =
@ -242,3 +95,31 @@ fromList @{br} s xs = bufferAction @{br} s $ go xs 0
go : List a -> Nat -> Buffer -> IO ()
go [] _ buf = pure ()
go (x :: xs) i buf = writeBuffer i x buf >> go xs (S i) buf
export
foldl : {s : _} -> ByteRep a => (b -> a -> b) -> b -> PrimArrayBytes o s -> b
foldl f z (MkPABytes sts buf) =
if product s == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [0..pred $ product s] $ \n => do
x <- readIORef ref
y <- readBuffer n buf
writeIORef ref (f x y)
readIORef ref
export
foldr : {s : _} -> ByteRep a => (a -> b -> b) -> b -> PrimArrayBytes o s -> b
foldr f z (MkPABytes sts buf) =
if product s == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [pred $ product s..0] $ \n => do
x <- readBuffer n buf
y <- readIORef ref
writeIORef ref (f x y)
readIORef ref
export
traverse : {o,s : _} -> ByteRep a => ByteRep b => Applicative f => (a -> f b) -> PrimArrayBytes o s -> f (PrimArrayBytes o s)
traverse f = map (Bytes.fromList _) . traverse f . foldr (::) []

View file

@ -1,5 +1,6 @@
module Data.NumIdr.PrimArray.Delayed
import Data.List
import Data.Vect
import Data.NP
import Data.NumIdr.Array.Rep
@ -18,6 +19,20 @@ constant s x _ = x
export
checkRange : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s)
checkRange [] [] = Just []
checkRange (d :: s) (i :: is) = (::) <$> natToFin i d <*> checkRange s is
indexRange : {s : _} -> (rs : CoordsRange s) -> PrimArrayDelayed s a -> PrimArrayDelayed (newShape rs) a
indexRange [] v = v
indexRange (r :: rs) v with (cRangeToList r)
_ | Left i = indexRange rs (\is' => v (believe_me i :: is'))
_ | Right is = \(i::is') => indexRange rs (\is'' => v (believe_me (Vect.index i (fromList is)) :: is'')) is'
export
indexSetRange : {s : _} -> (rs : CoordsRange s) -> PrimArrayDelayed (newShape rs) a
-> PrimArrayDelayed s a -> PrimArrayDelayed s a
indexSetRange {s=[]} [] rv _ = rv
indexSetRange {s=_::_} (r :: rs) rv v with (cRangeToList r)
_ | Left i = \(i'::is) => if i == finToNat i'
then indexSetRange rs rv (v . (i'::)) is
else v (i'::is)
_ | Right is = \(i'::is') => case findIndex (== finToNat i') is of
Just x => indexSetRange rs (rv . (x::)) (v . (i'::)) is'
Nothing => v (i'::is')

View file

@ -1,6 +1,7 @@
module Data.NumIdr.PrimArray.Linked
import Data.Vect
import Data.DPair
import Data.NP
import Data.NumIdr.Array.Rep
import Data.NumIdr.Array.Coords
@ -18,6 +19,25 @@ index : Coords s -> Vects s a -> a
index [] x = x
index (c :: cs) xs = index cs (index c xs)
export
update : Coords s -> (a -> a) -> Vects s a -> Vects s a
update [] f v = f v
update (i :: is) f v = updateAt i (update is f) v
export
indexRange : {s : _} -> (rs : CoordsRange s) -> Vects s a -> Vects (newShape rs) a
indexRange [] v = v
indexRange (r :: rs) v with (cRangeToList r)
_ | Left i = indexRange rs (Vect.index (believe_me i) v)
_ | Right is = believe_me $ map (\i => indexRange rs (Vect.index (believe_me i) v)) is
export
indexSetRange : {s : _} -> (rs : CoordsRange s) -> Vects (newShape rs) a -> Vects s a -> Vects s a
indexSetRange {s=[]} [] rv _ = rv
indexSetRange {s=_::_} (r :: rs) rv v with (cRangeToList r)
_ | Left i = updateAt (believe_me i) (indexSetRange rs rv) v
_ | Right is = foldl (\v,i => updateAt (believe_me i) (indexSetRange rs (Vect.index (believe_me i) rv)) v) v is
export
fromFunction : {s : _} -> (Coords s -> a) -> Vects s a
@ -32,3 +52,23 @@ fromFunctionNB {s = d :: s} f = tabulate' {n=d} (\i => fromFunctionNB (f . (i::)
tabulate' : forall a. {n : _} -> (Nat -> a) -> Vect n a
tabulate' {n=Z} f = []
tabulate' {n=S _} f = f Z :: tabulate' (f . S)
export
mapVects : {s : _} -> (a -> b) -> Vects s a -> Vects s b
mapVects {s = []} f x = f x
mapVects {s = _ :: _} f v = Prelude.map (mapVects f) v
export
foldl : {s : _} -> (b -> a -> b) -> b -> Vects s a -> b
foldl {s=[]} f z v = f z v
foldl {s=_::_} f z v = Prelude.foldl (Linked.foldl f) z v
export
foldr : {s : _} -> (a -> b -> b) -> b -> Vects s a -> b
foldr {s=[]} f z v = f v z
foldr {s=_::_} f z v = Prelude.foldr (flip $ Linked.foldr f) z v
export
traverse : {s : _} -> Applicative f => (a -> f b) -> Vects s a -> f (Vects s b)
traverse {s=[]} f v = f v
traverse {s=_::_} f v = Prelude.traverse (Linked.traverse f) v