From a0068469c57efcb482c193022042113f570572b3 Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Mon, 25 Sep 2023 22:55:33 -0400 Subject: [PATCH] Revamp all of the standard array API --- src/Data/NP.idr | 5 + src/Data/NumIdr/Array.idr | 2 +- src/Data/NumIdr/Array/Array.idr | 274 ++++++++++++-------------- src/Data/NumIdr/Array/Coords.idr | 72 ++++--- src/Data/NumIdr/Array/Rep.idr | 231 ++++++++++++++++++++++ src/Data/NumIdr/PrimArray.idr | 148 +++++++++++--- src/Data/NumIdr/PrimArray/Boxed.idr | 33 +++- src/Data/NumIdr/PrimArray/Bytes.idr | 193 ++++-------------- src/Data/NumIdr/PrimArray/Delayed.idr | 21 +- src/Data/NumIdr/PrimArray/Linked.idr | 40 ++++ 10 files changed, 655 insertions(+), 364 deletions(-) diff --git a/src/Data/NP.idr b/src/Data/NP.idr index 678a29a..6ee76e4 100644 --- a/src/Data/NP.idr +++ b/src/Data/NP.idr @@ -26,6 +26,11 @@ public export elems {all=[_]} [x] = show x elems {all=_::_} (x :: xs) = show x ++ ", " ++ elems xs +-- use this when idris can't figure out the constraint above +public export +eqNP : forall f. (forall t. Eq (f t)) => (x, y : NP f ts) -> Bool +eqNP [] [] = True +eqNP (x :: xs) (y :: ys) = x == y && eqNP xs ys public export diff --git a/src/Data/NumIdr/Array.idr b/src/Data/NumIdr/Array.idr index fd81478..9efe967 100644 --- a/src/Data/NumIdr/Array.idr +++ b/src/Data/NumIdr/Array.idr @@ -3,4 +3,4 @@ module Data.NumIdr.Array import public Data.NP import public Data.NumIdr.Array.Array import public Data.NumIdr.Array.Coords -import public Data.NumIdr.Array.Order +import public Data.NumIdr.Array.Rep diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 7d9c404..4768843 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -112,7 +112,7 @@ viewShape arr = rewrite shapeEq arr in export repeat : {default B rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> a -> Array s a -repeat s x = MkArray rep s (constant s x) +repeat s x = MkArray rep s (PrimArray.constant s x) ||| Create an array filled with zeros. ||| @@ -138,9 +138,9 @@ ones {rep} s = repeat {rep} s 1 ||| @ s The shape of the constructed array ||| @ rep The internal representation of the constructed array export -fromVect : {default B rep : Rep} -> RepConstraint rep a => +fromVect : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a => (s : Vect rk Nat) -> Vect (product s) a -> Array s a -fromVect s v = ?fv +fromVect {rep} s v = MkArray rep s (PrimArray.fromList {rep} s $ toList v) ||| Create an array by taking values from a stream. @@ -148,7 +148,7 @@ fromVect s v = ?fv ||| @ s The shape of the constructed array ||| @ rep The internal representation of the constructed array export -fromStream : {default B rep : Rep} -> RepConstraint rep a => +fromStream : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a => (s : Vect rk Nat) -> Stream a -> Array s a fromStream {rep} s str = fromVect {rep} s (take _ str) @@ -211,12 +211,12 @@ index is (MkArray _ _ arr) = PrimArray.index is arr export %inline (!!) : Array s a -> Coords s -> a arr !! is = index is arr -{- + ||| Update the entry at the given coordinates using the function. export indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a -indexUpdate is f (MkArray ord sts s arr) = - MkArray ord sts s (updateAt (getLocation sts is) f arr) +indexUpdate is f (MkArray rep @{rc} s arr) = + MkArray rep @{rc} s (indexUpdate @{rc} is f arr) ||| Set the entry at the given coordinates to the given value. export @@ -227,17 +227,8 @@ indexSet is = indexUpdate is . const ||| Index the array using the given range of coordinates, returning a new array. export indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a -indexRange {s} rs arr with (viewShape arr) - _ | Shape s = - let ord = getOrder arr - sts = calcStrides ord s' - in MkArray ord sts s' - (unsafeFromIns (product s') $ - map (\(is,is') => (getLocation' sts is', - index (getLocation' (strides arr) is) (getPrim arr))) - $ getCoordsList rs) - where s' : Vect ? Nat - s' = newShape rs +indexRange rs (MkArray rep @{rc} s arr) = + MkArray rep @{rc} _ (PrimArray.indexRange @{rc} rs arr) ||| Index the array using the given range of coordinates, returning a new array. ||| @@ -250,13 +241,8 @@ arr !!.. rs = indexRange rs arr export indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a -> Array s a -> Array s a -indexSetRange rs rpl (MkArray ord sts s arr) = - MkArray ord sts s (unsafePerformIO $ do - let arr' = copy arr - unsafeDoIns (map (\(is,is') => - (getLocation' sts is, index (getLocation' (strides rpl) is') (getPrim rpl))) - $ getCoordsList rs) arr' - pure arr') +indexSetRange rs (MkArray _ _ rpl) (MkArray rep s arr) = + MkArray rep s (PrimArray.indexSetRange {rep} rs (convertRep rpl) arr) ||| Update the sub-array at the given range of coordinates by applying @@ -266,13 +252,14 @@ indexUpdateRange : (rs : CoordsRange s) -> (Array (newShape rs) a -> Array (newShape rs) a) -> Array s a -> Array s a indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr --} + ||| Index the array using the given coordinates, returning `Nothing` if the ||| coordinates are out of bounds. export indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a -indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr +indexNB is (MkArray rep @{rc} s arr) = + map (\is => index is (MkArray rep @{rc} s arr)) (validateCoords s is) ||| Index the array using the given coordinates, returning `Nothing` if the ||| coordinates are out of bounds. @@ -281,15 +268,13 @@ indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr export %inline (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a arr !? is = indexNB is arr -{- + ||| Update the entry at the given coordinates using the function. `Nothing` is ||| returned if the coordinates are out of bounds. export indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a) -indexUpdateNB is f (MkArray ord sts s arr) = - if all id $ zipWith (<) is s - then Just $ MkArray ord sts s (updateAt (getLocation' sts is) f arr) - else Nothing +indexUpdateNB is f (MkArray rep @{rc} s arr) = + map (\is => indexUpdate is f (MkArray rep @{rc} s arr)) (validateCoords s is) ||| Set the entry at the given coordinates using the function. `Nothing` is ||| returned if the coordinates are out of bounds. @@ -302,20 +287,8 @@ indexSetNB is = indexUpdateNB is . const ||| If any of the given indices are out of bounds, then `Nothing` is returned. export indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a) -indexRangeNB {s} rs arr with (viewShape arr) - _ | Shape s = - if validateShape s rs - then - let ord = getOrder arr - sts = calcStrides ord s' - in Just $ MkArray ord sts s' - (unsafeFromIns (product s') $ - map (\(is,is') => (getLocation' sts is', - index (getLocation' (strides arr) is) (getPrim arr))) - $ getCoordsList s rs) - else Nothing - where s' : Vect ? Nat - s' = newShape s rs +indexRangeNB rs (MkArray rep @{rc} s arr) = + map (\rs => believe_me $ Array.indexRange rs (MkArray rep @{rc} s arr)) (validateCRange s rs) ||| Index the array using the given range of coordinates, returning a new array. ||| If any of the given indices are out of bounds, then `Nothing` is returned. @@ -324,7 +297,7 @@ indexRangeNB {s} rs arr with (viewShape arr) export %inline (!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a) arr !?.. rs = indexRangeNB rs arr --} + ||| Index the array using the given coordinates. ||| WARNING: This function does not perform any bounds check on its inputs. @@ -342,30 +315,21 @@ export %inline (!#) : Array {rk} s a -> Vect rk Nat -> a arr !# is = indexUnsafe is arr -{- + ||| Index the array using the given range of coordinates, returning a new array. ||| WARNING: This function does not perform any bounds check on its inputs. ||| Misuse of this function can easily break memory safety. -export +export %unsafe indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a -indexRangeUnsafe {s} rs arr with (viewShape arr) - _ | Shape s = - let ord = getOrder arr - sts = calcStrides ord s' - in MkArray ord sts s' - (unsafeFromIns (product s') $ - map (\(is,is') => (getLocation' sts is', - index (getLocation' (strides arr) is) (getPrim arr))) - $ getCoordsList s rs) - where s' : Vect ? Nat - s' = newShape s rs +indexRangeUnsafe rs (MkArray rep @{rc} s arr) = + believe_me $ Array.indexRange (assertCRange s rs) (MkArray rep @{rc} s arr) ||| Index the array using the given range of coordinates, returning a new array. ||| WARNING: This function does not perform any bounds check on its inputs. ||| Misuse of this function can easily break memory safety. ||| ||| This is the operator form of `indexRangeUnsafe`. -export %inline +export %inline %unsafe (!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a arr !#.. is = indexRangeUnsafe is arr @@ -375,35 +339,42 @@ arr !#.. is = indexRangeUnsafe is arr -- Operations on arrays -------------------------------------------------------------------------------- - -||| Reshape the array into the given shape and reinterpret it according to -||| the given order. -||| -||| @ s' The shape to convert the array to -||| @ ord The order to reinterpret the array by export -reshape' : (s' : Vect rk' Nat) -> (ord : Order) -> Array {rk} s a -> - (0 ok : product s = product s') => Array s' a -reshape' s' ord' arr = MkArray ord' (calcStrides ord' s') s' (getPrim arr) +mapArray' : (a -> a) -> Array s a -> Array s a +mapArray' f (MkArray rep _ arr) = MkArray rep _ (mapPrim f arr) + +export +mapArray : (a -> b) -> (arr : Array s a) -> RepConstraint (getRep arr) b => Array s b +mapArray f (MkArray rep _ arr) @{rc} = MkArray rep @{rc} _ (mapPrim f arr) + +export +zipWithArray' : (a -> a -> a) -> Array s a -> Array s a -> Array s a +zipWithArray' {s} f a b with (viewShape a) + _ | Shape s = MkArray (mergeRep (getRep a) (getRep b)) + @{mergeRepConstraint (getRepC a) (getRepC b)} s + $ PrimArray.fromFunctionNB @{mergeRepConstraint (getRepC a) (getRepC b)} _ + (\is => f (a !# is) (b !# is)) + +export +zipWithArray : (a -> b -> c) -> (arr : Array s a) -> (arr' : Array s b) -> + RepConstraint (mergeRep (getRep arr) (getRep arr')) c => Array s c +zipWithArray {s} f a b @{rc} with (viewShape a) + _ | Shape s = MkArray (mergeRep (getRep a) (getRep b)) @{rc} s + $ PrimArray.fromFunctionNB _ (\is => f (a !# is) (b !# is)) + ||| Reshape the array into the given shape. ||| ||| @ s' The shape to convert the array to export -reshape : (s' : Vect rk' Nat) -> Array {rk} s a -> +reshape : (s' : Vect rk' Nat) -> (arr : Array {rk} s a) -> LinearRep (getRep arr) => (0 ok : product s = product s') => Array s' a -reshape s' arr = reshape' s' (getOrder arr) arr - +reshape s' (MkArray rep _ arr) = MkArray rep s' (PrimArray.reshape s' arr) ||| Change the internal order of the array's elements. export -reorder : Order -> Array s a -> Array s a -reorder {s} ord' arr with (viewShape arr) - _ | Shape s = let sts = calcStrides ord' s - in MkArray ord' sts _ (unsafeFromIns (product s) $ - map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $ - getAllCoords' s) - +convertRep : (rep : Rep) -> RepConstraint rep a => Array s a -> Array s a +convertRep rep (MkArray _ s arr) = MkArray rep s (PrimArray.convertRep arr) ||| Resize the array to a new shape, preserving the coordinates of the original ||| elements. New coordinates are filled with a default value. @@ -412,7 +383,7 @@ reorder {s} ord' arr with (viewShape arr) ||| @ def The default value to fill the array with export resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a -resize s' def arr = fromFunction' s' (getOrder arr) (fromMaybe def . (arr !?) . toNB) +resize s' def arr = fromFunction {rep=getRep arr} @{getRepC arr} s' (fromMaybe def . (arr !?) . toNB) ||| Resize the array to a new shape, preserving the coordinates of the original ||| elements. This function requires a proof that the new shape is strictly @@ -427,20 +398,11 @@ resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) => resizeLTE s' arr = resize s' (believe_me ()) arr -||| List all of the values in an array in row-major order. -export -elements : Array {rk} s a -> Vect (product s) a -elements (MkArray _ sts sh p) = - let elems = map (flip index p . getLocation' sts) (getAllCoords' sh) - -- TODO: prove that the number of elements is `product s` - in assert_total $ case toVect (product sh) elems of Just v => v - - ||| List all of the values in an array along with their coordinates. export enumerateNB : Array {rk} s a -> List (Vect rk Nat, a) -enumerateNB (MkArray _ sts sh p) = - map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh) +enumerateNB (MkArray _ s arr) = + map (\is => (is, PrimArray.indexUnsafe is arr)) (getAllCoords' s) ||| List all of the values in an array along with their coordinates. export @@ -448,6 +410,12 @@ enumerate : Array s a -> List (Coords s, a) enumerate {s} arr with (viewShape arr) _ | Shape s = map (\is => (is, index is arr)) (getAllCoords s) +||| List all of the values in an array in row-major order. +export +elements : Array {rk} s a -> Vect (product s) a +elements (MkArray _ s arr) = + believe_me $ Vect.fromList $ + map (flip PrimArray.indexUnsafe arr) (getAllCoords' s) ||| Join two arrays along a particular axis, e.g. combining two matrices ||| vertically or horizontally. All other axes of the arrays must have the @@ -457,16 +425,12 @@ enumerate {s} arr with (viewShape arr) export concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a -> Array (updateAt axis (+d) s) a -concat axis a b = let sA = shape a - sB = shape b - dA = index axis sA - dB = index axis sB - s = replaceAt axis (dA + dB) sA - sts = calcStrides COrder s - ins = map (mapFst $ getLocation' sts . toNB) (enumerate a) - ++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNB) (enumerate b) - -- TODO: prove that the type-level shape and `s` are equivalent - in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins) +concat {s,d} axis a b with (viewShape a, viewShape b) + _ | (Shape s, Shape (replaceAt axis d s)) = + believe_me $ Array.fromFunctionNB {rep=mergeRep (getRep a) (getRep b)} + @{mergeRepConstraint (getRepC a) (getRepC b)} (updateAt axis (+ index axis (shape b)) s) + (\is => let limit = index axis s + in if index axis is < limit then a !# is else b !# updateAt axis (`minus` limit) is) ||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix. ||| @@ -500,12 +464,11 @@ splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} -> Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a) splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is'))) - ||| Construct the transpose of an array by reversing the order of its axes. export transpose : Array s a -> Array (reverse s) a transpose {s} arr with (viewShape arr) - _ | Shape s = fromFunctionNB (reverse s) (\is => arr !# reverse is) + _ | Shape s = fromFunctionNB _ (\is => arr !# reverse is) ||| Construct the transpose of an array by reversing the order of its axes. ||| @@ -551,41 +514,56 @@ permuteInAxis {s} axis p arr with (viewShape arr) -- Implementations -------------------------------------------------------------------------------- - --- Most of these implementations apply the operation pointwise. If there are --- multiple arrays involved with different orders, then all of the arrays are --- reordered to match one. - - export Zippable (Array s) where zipWith {s} f a b with (viewShape a) - _ | Shape s = MkArray (getOrder a) (strides a) s $ - if getOrder a == getOrder b - then unsafeZipWith f (getPrim a) (getPrim b) - else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b) + _ | Shape s = MkArray (mergeRepNC (getRep a) (getRep b)) + @{mergeNCRepConstraint} s + $ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _ + (\is => f (a !# is) (b !# is)) zipWith3 {s} f a b c with (viewShape a) - _ | Shape s = MkArray (getOrder a) (strides a) s $ - if (getOrder a == getOrder b) && (getOrder b == getOrder c) - then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c) - else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b) - (getPrim $ reorder (getOrder a) c) + _ | Shape s = MkArray (mergeRepNC (mergeRep (getRep a) (getRep b)) (getRep c)) + @{mergeNCRepConstraint} s + $ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _ + (\is => f (a !# is) (b !# is) (c !# is)) unzipWith {s} f arr with (viewShape arr) - _ | Shape s = case unzipWith f (getPrim arr) of - (a, b) => (MkArray (getOrder arr) (strides arr) s a, - MkArray (getOrder arr) (strides arr) s b) + _ | Shape s = + let rep : Rep + rep = forceRepNC $ getRep arr + in (MkArray rep + @{forceRepConstraint} s + $ PrimArray.fromFunctionNB @{forceRepConstraint} _ + (\is => fst $ f (arr !# is)), + MkArray rep + @{forceRepConstraint} s + $ PrimArray.fromFunctionNB @{forceRepConstraint} _ + (\is => snd $ f (arr !# is))) unzipWith3 {s} f arr with (viewShape arr) - _ | Shape s = case unzipWith3 f (getPrim arr) of - (a, b, c) => (MkArray (getOrder arr) (strides arr) s a, - MkArray (getOrder arr) (strides arr) s b, - MkArray (getOrder arr) (strides arr) s c) + _ | Shape s = + let rep : Rep + rep = forceRepNC $ getRep arr + in (MkArray rep + @{forceRepConstraint} s + $ PrimArray.fromFunctionNB @{forceRepConstraint} _ + (\is => fst $ f (arr !# is)), + MkArray rep + @{forceRepConstraint} s + $ PrimArray.fromFunctionNB @{forceRepConstraint} _ + (\is => fst $ snd $ f (arr !# is)), + MkArray rep + @{forceRepConstraint} s + $ PrimArray.fromFunctionNB @{forceRepConstraint} _ + (\is => snd $ snd $ f (arr !# is))) + export Functor (Array s) where - map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr) + map f (MkArray rep @{rc} s arr) = MkArray (forceRepNC rep) @{forceRepConstraint} s + (mapPrim @{forceRepConstraint} @{forceRepConstraint} f + $ convertRep @{rc} @{forceRepConstraint} arr) export {s : _} -> Applicative (Array s) where @@ -602,15 +580,18 @@ export export Foldable (Array s) where - foldl f z = foldl f z . getPrim - foldr f z = foldr f z . getPrim - null arr = size arr == Z - toList = toList . getPrim + foldl f z (MkArray _ _ arr) = PrimArray.foldl f z arr + foldr f z (MkArray _ _ arr) = PrimArray.foldr f z arr + null (MkArray _ s _) = isZero (product s) + export Traversable (Array s) where - traverse f (MkArray ord sts s arr) = - map (MkArray ord sts s) (traverse f arr) + traverse f (MkArray rep @{rc} s arr) = + map (MkArray (forceRepNC rep) @{forceRepConstraint} s) + (PrimArray.traverse {rep=forceRepNC rep} + @{%search} @{forceRepConstraint} @{forceRepConstraint} f + (PrimArray.convertRep @{rc} @{forceRepConstraint} arr)) export @@ -619,13 +600,11 @@ Cast a b => Cast (Array s a) (Array s b) where export Eq a => Eq (Array s a) where - a == b = if getOrder a == getOrder b - then unsafeEq (getPrim a) (getPrim b) - else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b) + a == b = and $ zipWith (delay .: (==)) (convertRep D a) (convertRep D b) export Semigroup a => Semigroup (Array s a) where - (<+>) = zipWith (<+>) + (<+>) = zipWithArray' (<+>) export {s : _} -> Monoid a => Monoid (Array s a) where @@ -635,36 +614,34 @@ export -- were moved into its own interface, this constraint could be removed. export {s : _} -> Num a => Num (Array s a) where - (+) = zipWith (+) - (*) = zipWith (*) + (+) = zipWithArray' (+) + (*) = zipWithArray' (*) fromInteger = repeat s . fromInteger export {s : _} -> Neg a => Neg (Array s a) where - negate = map negate - (-) = zipWith (-) + negate = mapArray' negate + (-) = zipWithArray' (-) export {s : _} -> Fractional a => Fractional (Array s a) where - recip = map recip - (/) = zipWith (/) + recip = mapArray' recip + (/) = zipWithArray' (/) export Num a => Mult a (Array {rk} s a) (Array s a) where - (*.) x = map (*x) + (*.) x = mapArray' (*x) export Num a => Mult (Array {rk} s a) a (Array s a) where (*.) = flip (*.) - export Show a => Show (Array s a) where - showPrec d arr = let orderedElems = PrimArray.toList $ getPrim $ - if getOrder arr == COrder then arr else reorder COrder arr + showPrec d arr = let orderedElems = toList $ elements arr in showCon d "array " $ concat $ insertPunct (shape arr) $ map show orderedElems where splitWindow : Nat -> List String -> List (List String) @@ -691,7 +668,7 @@ Show a => Show (Array s a) where ||| Linearly interpolate between two arrays. export lerp : Neg a => a -> Array s a -> Array s a -> Array s a -lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t) +lerp t a b = zipWithArray' (+) (a *. (1 - t)) (b *. t) ||| Calculate the square of an array's Eulidean norm. @@ -715,4 +692,3 @@ normalize arr = if all (==0) arr then arr else map (/ norm arr) arr export pnorm : (p : Double) -> Array s Double -> Double pnorm p = (`pow` recip p) . sum . map (`pow` p) --} diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr index 21ac2cd..03b021e 100644 --- a/src/Data/NumIdr/Array/Coords.idr +++ b/src/Data/NumIdr/Array/Coords.idr @@ -43,6 +43,11 @@ toNB : Coords {rk} s -> Vect rk Nat toNB [] = [] toNB (i :: is) = finToNat i :: toNB is +export +validateCoords : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s) +validateCoords [] [] = Just [] +validateCoords (d :: s) (i :: is) = (::) <$> natToFin i d <*> validateCoords s is + namespace Strict public export @@ -155,6 +160,44 @@ namespace Strict namespace NB export + validateCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> Maybe (CoordsRange s) + validateCRange [] [] = Just [] + validateCRange (d :: s) (r :: rs) = [| validate' d r :: validateCRange s rs |] + where + validate' : (n : Nat) -> CRangeNB -> Maybe (CRange n) + validate' n All = Just All + validate' n (StartBound x) = + case isLTE x n of + Yes _ => Just (StartBound (natToFinLT x)) + _ => Nothing + validate' n (EndBound x) = + case isLTE x n of + Yes _ => Just (EndBound (natToFinLT x)) + _ => Nothing + validate' n (Bounds x y) = + case (isLTE x n, isLTE y n) of + (Yes _, Yes _) => Just (Bounds (natToFinLT x) (natToFinLT y)) + _ => Nothing + validate' n (Indices xs) = Indices <$> traverse + (\x => case isLT x n of + Yes _ => Just (natToFinLT x) + No _ => Nothing) xs + validate' n (Filter f) = Just (Filter (f . finToNat)) + + export %unsafe + assertCRange : (s : Vect rk Nat) -> Vect rk CRangeNB -> CoordsRange s + assertCRange [] [] = [] + assertCRange (d :: s) (r :: rs) = assert' d r :: assertCRange s rs + where + assert' : (n : Nat) -> CRangeNB -> CRange n + assert' n All = All + assert' n (StartBound x) = StartBound (believe_me x) + assert' n (EndBound x) = EndBound (believe_me x) + assert' n (Bounds x y) = Bounds (believe_me x) (believe_me y) + assert' n (Indices xs) = Indices (believe_me <$> xs) + assert' n (Filter f) = Filter (f . finToNat) + + public export cRangeNBToList : Nat -> CRangeNB -> List Nat cRangeNBToList s All = range 0 s cRangeNBToList s (StartBound x) = range x s @@ -163,38 +206,11 @@ namespace NB cRangeNBToList s (Indices xs) = nub xs cRangeNBToList s (Filter p) = filter p $ range 0 s - export - validateCRangeNB : Nat -> CRangeNB -> Bool - validateCRangeNB s All = True - validateCRangeNB s (StartBound x) = x < s - validateCRangeNB s (EndBound x) = x <= s - validateCRangeNB s (Bounds x y) = x < s && y <= s - validateCRangeNB s (Indices xs) = all ( Vect rk CRangeNB -> Vect rk Nat newShape = zipWith (length .: cRangeNBToList) - export - validateShape : Vect rk Nat -> Vect rk CRangeNB -> Bool - validateShape = all id .: zipWith validateCRangeNB - - export - getNewPos : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat -> Vect rk Nat - getNewPos = zipWith3 (\d,r,i => assert_total $ - case findIndex (==i) (cRangeNBToList d r) of Just x => cast x) - - export - getCoordsList : Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat, Vect rk Nat) - getCoordsList s rs = map (\is => (is, getNewPos s rs is)) $ go s rs - where - go : {0 rk : _} -> Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat) - go [] [] = [[]] - go (d :: s) (r :: rs) = [| cRangeNBToList d r :: go s rs |] - - export getAllCoords' : Vect rk Nat -> List (Vect rk Nat) getAllCoords' = traverse (\case Z => []; S n => [0..n]) diff --git a/src/Data/NumIdr/Array/Rep.idr b/src/Data/NumIdr/Array/Rep.idr index 31b5ab5..3410d13 100644 --- a/src/Data/NumIdr/Array/Rep.idr +++ b/src/Data/NumIdr/Array/Rep.idr @@ -1,6 +1,8 @@ module Data.NumIdr.Array.Rep import Data.Vect +import Data.Bits +import Data.Buffer %default total @@ -23,6 +25,8 @@ data Order : Type where ||| fashion (the first axis is the least significant). FOrder : Order +%name Order o + public export Eq Order where @@ -52,6 +56,9 @@ data Rep : Type where Linked : Rep Delayed : Rep +%name Rep rep + + public export B : Rep B = Boxed COrder @@ -65,7 +72,231 @@ D : Rep D = Delayed +public export +Eq Rep where + Bytes o == Bytes o' = o == o' + Boxed o == Boxed o' = o == o' + Linked == Linked = True + Delayed == Delayed = True + _ == _ = False + public export data LinearRep : Rep -> Type where BytesIsL : LinearRep (Bytes o) BoxedIsL : LinearRep (Boxed o) + + +public export +forceRepNC : Rep -> Rep +forceRepNC (Bytes o) = Boxed o +forceRepNC r = r + +public export +mergeRep : Rep -> Rep -> Rep +mergeRep r r' = if r == Delayed || r' == Delayed then Delayed else r + +public export +mergeRepNC : Rep -> Rep -> Rep +mergeRepNC r r' = + if r == Delayed || r' == Delayed then Delayed + else case r of + Bytes _ => case r' of + Bytes o => Boxed o + _ => r' + _ => r + +public export +data NoConstraintRep : Rep -> Type where + BoxedNC : NoConstraintRep (Boxed o) + LinkedNC : NoConstraintRep Linked + DelayedNC : NoConstraintRep Delayed + +public export +mergeRepNCCorrect : (r, r' : Rep) -> NoConstraintRep (mergeRepNC r r') +mergeRepNCCorrect Delayed _ = DelayedNC +mergeRepNCCorrect (Bytes y) (Bytes x) = BoxedNC +mergeRepNCCorrect (Boxed y) (Bytes x) = BoxedNC +mergeRepNCCorrect Linked (Bytes x) = LinkedNC +mergeRepNCCorrect (Bytes y) (Boxed x) = BoxedNC +mergeRepNCCorrect (Boxed y) (Boxed x) = BoxedNC +mergeRepNCCorrect Linked (Boxed x) = LinkedNC +mergeRepNCCorrect (Bytes x) Linked = LinkedNC +mergeRepNCCorrect (Boxed x) Linked = BoxedNC +mergeRepNCCorrect Linked Linked = LinkedNC +mergeRepNCCorrect (Bytes x) Delayed = DelayedNC +mergeRepNCCorrect (Boxed x) Delayed = DelayedNC +mergeRepNCCorrect Linked Delayed = DelayedNC + +public export +forceRepNCCorrect : (r : Rep) -> NoConstraintRep (forceRepNC r) +forceRepNCCorrect (Bytes x) = BoxedNC +forceRepNCCorrect (Boxed x) = BoxedNC +forceRepNCCorrect Linked = LinkedNC +forceRepNCCorrect Delayed = DelayedNC + + +-------------------------------------------------------------------------------- +-- Byte representations of elements +-------------------------------------------------------------------------------- + + +public export +interface ByteRep a where + bytes : Nat + + toBytes : a -> Vect bytes Bits8 + fromBytes : Vect bytes Bits8 -> a + +export +ByteRep Bits8 where + bytes = 1 + + toBytes x = [x] + fromBytes [x] = x + +export +ByteRep Bits16 where + bytes = 2 + + toBytes x = [cast (x `shiftR` 8), cast x] + fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2 + +export +ByteRep Bits32 where + bytes = 4 + + toBytes x = [cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4] = + cast b1 `shiftL` 24 .|. + cast b2 `shiftL` 16 .|. + cast b3 `shiftL` 8 .|. + cast b4 + +export +ByteRep Bits64 where + bytes = 8 + + toBytes x = [cast (x `shiftR` 56), + cast (x `shiftR` 48), + cast (x `shiftR` 40), + cast (x `shiftR` 32), + cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = + cast b1 `shiftL` 56 .|. + cast b2 `shiftL` 48 .|. + cast b3 `shiftL` 40 .|. + cast b4 `shiftL` 32 .|. + cast b5 `shiftL` 24 .|. + cast b6 `shiftL` 16 .|. + cast b7 `shiftL` 8 .|. + cast b8 + +export +ByteRep Int where + bytes = 8 + + toBytes x = [cast (x `shiftR` 56), + cast (x `shiftR` 48), + cast (x `shiftR` 40), + cast (x `shiftR` 32), + cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = + cast b1 `shiftL` 56 .|. + cast b2 `shiftL` 48 .|. + cast b3 `shiftL` 40 .|. + cast b4 `shiftL` 32 .|. + cast b5 `shiftL` 24 .|. + cast b6 `shiftL` 16 .|. + cast b7 `shiftL` 8 .|. + cast b8 + +export +ByteRep Int8 where + bytes = 1 + + toBytes x = [cast x] + fromBytes [x] = cast x + +export +ByteRep Int16 where + bytes = 2 + + toBytes x = [cast (x `shiftR` 8), cast x] + fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2 + +export +ByteRep Int32 where + bytes = 4 + + toBytes x = [cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4] = + cast b1 `shiftL` 24 .|. + cast b2 `shiftL` 16 .|. + cast b3 `shiftL` 8 .|. + cast b4 + +export +ByteRep Int64 where + bytes = 8 + + toBytes x = [cast (x `shiftR` 56), + cast (x `shiftR` 48), + cast (x `shiftR` 40), + cast (x `shiftR` 32), + cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = + cast b1 `shiftL` 56 .|. + cast b2 `shiftL` 48 .|. + cast b3 `shiftL` 40 .|. + cast b4 `shiftL` 32 .|. + cast b5 `shiftL` 24 .|. + cast b6 `shiftL` 16 .|. + cast b7 `shiftL` 8 .|. + cast b8 + +export +ByteRep Bool where + bytes = 1 + + toBytes b = [if b then 1 else 0] + fromBytes [x] = x /= 0 + +export +ByteRep a => ByteRep b => ByteRep (a, b) where + bytes = bytes {a} + bytes {a=b} + + toBytes (x,y) = toBytes x ++ toBytes y + fromBytes = bimap fromBytes fromBytes . splitAt _ + +export +{n : _} -> ByteRep a => ByteRep (Vect n a) where + bytes = n * bytes {a} + + toBytes xs = concat $ map toBytes xs + fromBytes {n = 0} bs = [] + fromBytes {n = S n} bs = + let (bs1, bs2) = splitAt _ bs + in fromBytes bs1 :: fromBytes bs2 + + +public export +RepConstraint : Rep -> Type -> Type +RepConstraint (Bytes _) a = ByteRep a +RepConstraint (Boxed _) a = () +RepConstraint Linked a = () +RepConstraint Delayed a = () diff --git a/src/Data/NumIdr/PrimArray.idr b/src/Data/NumIdr/PrimArray.idr index a4e958f..d7d1009 100644 --- a/src/Data/NumIdr/PrimArray.idr +++ b/src/Data/NumIdr/PrimArray.idr @@ -2,23 +2,39 @@ module Data.NumIdr.PrimArray import Data.Buffer import Data.Vect +import Data.Fin import Data.NP import Data.NumIdr.Array.Rep import Data.NumIdr.Array.Coords -import Data.NumIdr.PrimArray.Bytes -import Data.NumIdr.PrimArray.Boxed -import Data.NumIdr.PrimArray.Linked -import Data.NumIdr.PrimArray.Delayed +import public Data.NumIdr.PrimArray.Bytes +import public Data.NumIdr.PrimArray.Boxed +import public Data.NumIdr.PrimArray.Linked +import public Data.NumIdr.PrimArray.Delayed %default total -public export -RepConstraint : Rep -> Type -> Type -RepConstraint (Bytes _) a = ByteRep a -RepConstraint (Boxed _) a = () -RepConstraint Linked a = () -RepConstraint Delayed a = () + +noRCCorrect : (rep : Rep) -> NoConstraintRep rep -> RepConstraint rep a +noRCCorrect (Boxed _) BoxedNC = () +noRCCorrect Linked LinkedNC = () +noRCCorrect Delayed DelayedNC = () + +export +mergeRepConstraint : {r, r' : Rep} -> RepConstraint r a -> RepConstraint r' a -> + RepConstraint (mergeRep r r') a +mergeRepConstraint {r,r'} a b with (r == Delayed || r' == Delayed) + _ | True = () + _ | False = a + +export +mergeNCRepConstraint : {r, r' : Rep} -> RepConstraint (mergeRepNC r r') a +mergeNCRepConstraint = noRCCorrect _ (mergeRepNCCorrect r r') + +export +forceRepConstraint : {r : Rep} -> RepConstraint (forceRepNC r) a +forceRepConstraint = noRCCorrect _ (forceRepNCCorrect r) + export @@ -29,6 +45,11 @@ PrimArray Linked s a = Vects s a PrimArray Delayed s a = Coords s -> a +export +reshape : {rep : Rep} -> LinearRep rep => RepConstraint rep a => (s' : Vect rk Nat) -> PrimArray rep s a -> PrimArray rep s' a +reshape {rep = Bytes o} s' (MkPABytes _ arr) = MkPABytes (calcStrides o s') arr +reshape {rep = Boxed o} s' (MkPABoxed _ arr) = MkPABoxed (calcStrides o s') arr + export constant : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> a -> PrimArray rep s a constant {rep = Bytes o} = Bytes.constant @@ -66,25 +87,68 @@ index {rep = Linked} is arr = Linked.index is arr index {rep = Delayed} is arr = arr is export -indexNB : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> Maybe a -indexNB {rep = Bytes o} is arr@(MkPABytes sts _) = - if and (zipWith (delay .: (<)) is s) - then Just $ index (getLocation' sts is) arr - else Nothing -indexNB {rep = Boxed o} is arr@(MkPABoxed sts _) = - if and (zipWith (delay .: (<)) is s) - then Just $ index (getLocation' sts is) arr - else Nothing -indexNB {rep = Linked} is arr = (`Linked.index` arr) <$> checkRange s is -indexNB {rep = Delayed} is arr = arr <$> checkRange s is +indexUpdate : {rep,s : _} -> RepConstraint rep a => Coords s -> (a -> a) -> PrimArray rep s a -> PrimArray rep s a +indexUpdate {rep = Bytes o} @{rc} is f arr@(MkPABytes sts _) = + let i = getLocation sts is + in create @{rc} s (\n => let x = index n arr in if n == i then f x else x) +indexUpdate {rep = Boxed o} is f arr@(MkPABoxed sts _) = + let i = getLocation sts is + in create s (\n => let x = index n arr in if n == i then f x else x) +indexUpdate {rep = Linked} is f arr = update is f arr +indexUpdate {rep = Delayed} is f arr = + \is' => + let x = arr is' + in if eqNP is is' then f x else x + +export +indexRange : {rep,s : _} -> RepConstraint rep a => (rs : CoordsRange s) -> PrimArray rep s a + -> PrimArray rep (newShape rs) a +indexRange {rep = Bytes o} @{rc} rs arr@(MkPABytes sts _) = + let s' : Vect ? Nat + s' = newShape rs + sts' := calcStrides o s' + in unsafeFromIns @{rc} (newShape rs) $ + map (\(is,is') => (getLocation' sts' is', + index (getLocation' sts is) arr)) + $ getCoordsList rs +indexRange {rep = Boxed o} rs arr@(MkPABoxed sts _) = + let s' : Vect ? Nat + s' = newShape rs + sts' := calcStrides o s' + in unsafeFromIns (newShape rs) $ + map (\(is,is') => (getLocation' sts' is', + index (getLocation' sts is) arr)) + $ getCoordsList rs +indexRange {rep = Linked} rs arr = Linked.indexRange rs arr +indexRange {rep = Delayed} rs arr = Delayed.indexRange rs arr + +export +indexSetRange : {rep,s : _} -> RepConstraint rep a => (rs : CoordsRange s) + -> PrimArray rep (newShape rs) a -> PrimArray rep s a -> PrimArray rep s a +indexSetRange {rep = Bytes o} @{rc} rs rpl@(MkPABytes sts' _) arr@(MkPABytes sts _) = + unsafePerformIO $ do + let arr' = copy arr + unsafeDoIns @{rc} (map (\(is,is') => + (getLocation' sts is, index (getLocation' sts' is') rpl)) + $ getCoordsList rs) arr' + pure arr' +indexSetRange {rep = Boxed o} rs rpl@(MkPABoxed sts' _) arr@(MkPABoxed sts _) = + unsafePerformIO $ do + let arr' = copy arr + unsafeDoIns (map (\(is,is') => + (getLocation' sts is, index (getLocation' sts' is') rpl)) + $ getCoordsList rs) arr' + pure arr' +indexSetRange {rep = Linked} rs rpl arr = Linked.indexSetRange rs rpl arr +indexSetRange {rep = Delayed} rs rpl arr = Delayed.indexSetRange rs rpl arr export indexUnsafe : {rep,s : _} -> RepConstraint rep a => Vect rk Nat -> PrimArray {rk} rep s a -> a indexUnsafe {rep = Bytes o} is arr@(MkPABytes sts _) = index (getLocation' sts is) arr indexUnsafe {rep = Boxed o} is arr@(MkPABoxed sts _) = index (getLocation' sts is) arr -indexUnsafe {rep = Linked} is arr = assert_total $ case checkRange s is of +indexUnsafe {rep = Linked} is arr = assert_total $ case validateCoords s is of Just is' => Linked.index is' arr -indexUnsafe {rep = Delayed} is arr = assert_total $ case checkRange s is of +indexUnsafe {rep = Delayed} is arr = assert_total $ case validateCoords s is of Just is' => arr is' export @@ -100,3 +164,41 @@ convertRep {r1, r2} arr = fromFunction s (\is => PrimArray.index is arr) export fromVects : {rep : Rep} -> RepConstraint rep a => (s : Vect rk Nat) -> Vects s a -> PrimArray rep s a fromVects s v = convertRep {r1=Linked} v + +export +fromList : {rep : Rep} -> LinearRep rep => RepConstraint rep a => (s : Vect rk Nat) -> List a -> PrimArray rep s a +fromList {rep = (Boxed x)} = Boxed.fromList +fromList {rep = (Bytes x)} = Bytes.fromList + +export +mapPrim : {rep,s : _} -> RepConstraint rep a => RepConstraint rep b => + (a -> b) -> PrimArray rep s a -> PrimArray rep s b +mapPrim {rep = (Bytes x)} @{_} @{rc} f arr = Bytes.create @{rc} _ (\i => f $ Bytes.index i arr) +mapPrim {rep = (Boxed x)} f arr = Boxed.create _ (\i => f $ Boxed.index i arr) +mapPrim {rep = Linked} f arr = Linked.mapVects f arr +mapPrim {rep = Delayed} f arr = f . arr + + +export +foldl : {rep,s : _} -> RepConstraint rep a => (b -> a -> b) -> b -> PrimArray rep s a -> b +foldl {rep = Bytes _} = Bytes.foldl +foldl {rep = Boxed _} = Boxed.foldl +foldl {rep = Linked} = Linked.foldl +foldl {rep = Delayed} = \f,z => Boxed.foldl f z . convertRep {r1=Delayed,r2=B} + +export +foldr : {rep,s : _} -> RepConstraint rep a => (a -> b -> b) -> b -> PrimArray rep s a -> b +foldr {rep = Bytes _} = Bytes.foldr +foldr {rep = Boxed _} = Boxed.foldr +foldr {rep = Linked} = Linked.foldr +foldr {rep = Delayed} = \f,z => Boxed.foldr f z . convertRep {r1=Delayed,r2=B} + +export +traverse : {rep,s : _} -> Applicative f => RepConstraint rep a => RepConstraint rep b => + (a -> f b) -> PrimArray rep s a -> f (PrimArray rep s b) +traverse {rep = Bytes o} = Bytes.traverse +traverse {rep = Boxed o} = Boxed.traverse +traverse {rep = Linked} = Linked.traverse +traverse {rep = Delayed} = \f => map (convertRep @{()} @{()} {r1=B,r2=Delayed}) . + Boxed.traverse f . + convertRep @{()} @{()} {r1=Delayed,r2=B} diff --git a/src/Data/NumIdr/PrimArray/Boxed.idr b/src/Data/NumIdr/PrimArray/Boxed.idr index 9b038bd..54716d7 100644 --- a/src/Data/NumIdr/PrimArray/Boxed.idr +++ b/src/Data/NumIdr/PrimArray/Boxed.idr @@ -2,6 +2,7 @@ module Data.NumIdr.PrimArray.Boxed import Data.Buffer import System +import Data.IORef import Data.IOArray.Prims import Data.Vect import Data.NumIdr.Array.Rep @@ -108,6 +109,34 @@ fromList s xs = arrayAction s $ go xs 0 go (x :: xs) i buf = arrayDataSet i x buf >> go xs (S i) buf +export +foldl : {s : _} -> (b -> a -> b) -> b -> PrimArrayBoxed o s a -> b +foldl f z (MkPABoxed sts arr) = + if product s == 0 then z + else unsafePerformIO $ do + ref <- newIORef z + for_ [0..pred $ product s] $ \n => do + x <- readIORef ref + y <- arrayDataGet n arr + writeIORef ref (f x y) + readIORef ref + +export +foldr : {s : _} -> (a -> b -> b) -> b -> PrimArrayBoxed o s a -> b +foldr f z (MkPABoxed sts arr) = + if product s == 0 then z + else unsafePerformIO $ do + ref <- newIORef z + for_ [pred $ product s..0] $ \n => do + x <- arrayDataGet n arr + y <- readIORef ref + writeIORef ref (f x y) + readIORef ref + +export +traverse : {o,s : _} -> Applicative f => (a -> f b) -> PrimArrayBoxed o s a -> f (PrimArrayBoxed o s b) +traverse f = map (Boxed.fromList _) . traverse f . foldr (::) [] + {- export updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a @@ -145,10 +174,6 @@ fromList xs = create (length xs) fromJust : Maybe a -> a fromJust (Just x) = x -||| Map a function over a primitive array. -export -map : (a -> b) -> PrimArrayBoxed a -> PrimArrayBoxed b -map f arr = create (length arr) (\n => f $ index n arr) export diff --git a/src/Data/NumIdr/PrimArray/Bytes.idr b/src/Data/NumIdr/PrimArray/Bytes.idr index 4c0ce72..bd944f0 100644 --- a/src/Data/NumIdr/PrimArray/Bytes.idr +++ b/src/Data/NumIdr/PrimArray/Bytes.idr @@ -2,8 +2,6 @@ module Data.NumIdr.PrimArray.Bytes import System import Data.IORef -import Data.Bits -import Data.So import Data.Buffer import Data.Vect import Data.NumIdr.Array.Coords @@ -12,160 +10,6 @@ import Data.NumIdr.Array.Rep %default total -public export -interface ByteRep a where - bytes : Nat - - toBytes : a -> Vect bytes Bits8 - fromBytes : Vect bytes Bits8 -> a - -export -ByteRep Bits8 where - bytes = 1 - - toBytes x = [x] - fromBytes [x] = x - -export -ByteRep Bits16 where - bytes = 2 - - toBytes x = [cast (x `shiftR` 8), cast x] - fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2 - -export -ByteRep Bits32 where - bytes = 4 - - toBytes x = [cast (x `shiftR` 24), - cast (x `shiftR` 16), - cast (x `shiftR` 8), - cast x] - fromBytes [b1,b2,b3,b4] = - cast b1 `shiftL` 24 .|. - cast b2 `shiftL` 16 .|. - cast b3 `shiftL` 8 .|. - cast b4 - -export -ByteRep Bits64 where - bytes = 8 - - toBytes x = [cast (x `shiftR` 56), - cast (x `shiftR` 48), - cast (x `shiftR` 40), - cast (x `shiftR` 32), - cast (x `shiftR` 24), - cast (x `shiftR` 16), - cast (x `shiftR` 8), - cast x] - fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = - cast b1 `shiftL` 56 .|. - cast b2 `shiftL` 48 .|. - cast b3 `shiftL` 40 .|. - cast b4 `shiftL` 32 .|. - cast b5 `shiftL` 24 .|. - cast b6 `shiftL` 16 .|. - cast b7 `shiftL` 8 .|. - cast b8 - -export -ByteRep Int where - bytes = 8 - - toBytes x = [cast (x `shiftR` 56), - cast (x `shiftR` 48), - cast (x `shiftR` 40), - cast (x `shiftR` 32), - cast (x `shiftR` 24), - cast (x `shiftR` 16), - cast (x `shiftR` 8), - cast x] - fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = - cast b1 `shiftL` 56 .|. - cast b2 `shiftL` 48 .|. - cast b3 `shiftL` 40 .|. - cast b4 `shiftL` 32 .|. - cast b5 `shiftL` 24 .|. - cast b6 `shiftL` 16 .|. - cast b7 `shiftL` 8 .|. - cast b8 - -export -ByteRep Int8 where - bytes = 1 - - toBytes x = [cast x] - fromBytes [x] = cast x - -export -ByteRep Int16 where - bytes = 2 - - toBytes x = [cast (x `shiftR` 8), cast x] - fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2 - -export -ByteRep Int32 where - bytes = 4 - - toBytes x = [cast (x `shiftR` 24), - cast (x `shiftR` 16), - cast (x `shiftR` 8), - cast x] - fromBytes [b1,b2,b3,b4] = - cast b1 `shiftL` 24 .|. - cast b2 `shiftL` 16 .|. - cast b3 `shiftL` 8 .|. - cast b4 - -export -ByteRep Int64 where - bytes = 8 - - toBytes x = [cast (x `shiftR` 56), - cast (x `shiftR` 48), - cast (x `shiftR` 40), - cast (x `shiftR` 32), - cast (x `shiftR` 24), - cast (x `shiftR` 16), - cast (x `shiftR` 8), - cast x] - fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = - cast b1 `shiftL` 56 .|. - cast b2 `shiftL` 48 .|. - cast b3 `shiftL` 40 .|. - cast b4 `shiftL` 32 .|. - cast b5 `shiftL` 24 .|. - cast b6 `shiftL` 16 .|. - cast b7 `shiftL` 8 .|. - cast b8 - -export -ByteRep Bool where - bytes = 1 - - toBytes b = [if b then 1 else 0] - fromBytes [x] = x /= 0 - -export -ByteRep a => ByteRep b => ByteRep (a, b) where - bytes = bytes {a} + bytes {a=b} - - toBytes (x,y) = toBytes x ++ toBytes y - fromBytes = bimap fromBytes fromBytes . splitAt _ - -export -{n : _} -> ByteRep a => ByteRep (Vect n a) where - bytes = n * bytes {a} - - toBytes xs = concat $ map toBytes xs - fromBytes {n = 0} bs = [] - fromBytes {n = S n} bs = - let (bs1, bs2) = splitAt _ bs - in fromBytes bs1 :: fromBytes bs2 - - export readBuffer : ByteRep a => (i : Nat) -> Buffer -> IO a readBuffer i buf = do @@ -226,6 +70,15 @@ export index : ByteRep a => Nat -> PrimArrayBytes o s -> a index i (MkPABytes _ buf) = unsafePerformIO $ readBuffer i buf +export +copy : {o,s : _} -> PrimArrayBytes o s -> PrimArrayBytes o s +copy (MkPABytes sts buf) = + MkPABytes sts (unsafePerformIO $ do + Just buf' <- newBuffer !(rawSize buf) + | Nothing => die "Cannot create array" + copyData buf 0 !(rawSize buf) buf' 0 + pure buf) + export reorder : {o,o',s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBytes o' s reorder @{br} arr@(MkPABytes sts buf) = @@ -242,3 +95,31 @@ fromList @{br} s xs = bufferAction @{br} s $ go xs 0 go : List a -> Nat -> Buffer -> IO () go [] _ buf = pure () go (x :: xs) i buf = writeBuffer i x buf >> go xs (S i) buf + +export +foldl : {s : _} -> ByteRep a => (b -> a -> b) -> b -> PrimArrayBytes o s -> b +foldl f z (MkPABytes sts buf) = + if product s == 0 then z + else unsafePerformIO $ do + ref <- newIORef z + for_ [0..pred $ product s] $ \n => do + x <- readIORef ref + y <- readBuffer n buf + writeIORef ref (f x y) + readIORef ref + +export +foldr : {s : _} -> ByteRep a => (a -> b -> b) -> b -> PrimArrayBytes o s -> b +foldr f z (MkPABytes sts buf) = + if product s == 0 then z + else unsafePerformIO $ do + ref <- newIORef z + for_ [pred $ product s..0] $ \n => do + x <- readBuffer n buf + y <- readIORef ref + writeIORef ref (f x y) + readIORef ref + +export +traverse : {o,s : _} -> ByteRep a => ByteRep b => Applicative f => (a -> f b) -> PrimArrayBytes o s -> f (PrimArrayBytes o s) +traverse f = map (Bytes.fromList _) . traverse f . foldr (::) [] diff --git a/src/Data/NumIdr/PrimArray/Delayed.idr b/src/Data/NumIdr/PrimArray/Delayed.idr index c7a5825..96030e8 100644 --- a/src/Data/NumIdr/PrimArray/Delayed.idr +++ b/src/Data/NumIdr/PrimArray/Delayed.idr @@ -1,5 +1,6 @@ module Data.NumIdr.PrimArray.Delayed +import Data.List import Data.Vect import Data.NP import Data.NumIdr.Array.Rep @@ -18,6 +19,20 @@ constant s x _ = x export -checkRange : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s) -checkRange [] [] = Just [] -checkRange (d :: s) (i :: is) = (::) <$> natToFin i d <*> checkRange s is +indexRange : {s : _} -> (rs : CoordsRange s) -> PrimArrayDelayed s a -> PrimArrayDelayed (newShape rs) a +indexRange [] v = v +indexRange (r :: rs) v with (cRangeToList r) + _ | Left i = indexRange rs (\is' => v (believe_me i :: is')) + _ | Right is = \(i::is') => indexRange rs (\is'' => v (believe_me (Vect.index i (fromList is)) :: is'')) is' + +export +indexSetRange : {s : _} -> (rs : CoordsRange s) -> PrimArrayDelayed (newShape rs) a + -> PrimArrayDelayed s a -> PrimArrayDelayed s a +indexSetRange {s=[]} [] rv _ = rv +indexSetRange {s=_::_} (r :: rs) rv v with (cRangeToList r) + _ | Left i = \(i'::is) => if i == finToNat i' + then indexSetRange rs rv (v . (i'::)) is + else v (i'::is) + _ | Right is = \(i'::is') => case findIndex (== finToNat i') is of + Just x => indexSetRange rs (rv . (x::)) (v . (i'::)) is' + Nothing => v (i'::is') diff --git a/src/Data/NumIdr/PrimArray/Linked.idr b/src/Data/NumIdr/PrimArray/Linked.idr index 54af7c0..2e9817e 100644 --- a/src/Data/NumIdr/PrimArray/Linked.idr +++ b/src/Data/NumIdr/PrimArray/Linked.idr @@ -1,6 +1,7 @@ module Data.NumIdr.PrimArray.Linked import Data.Vect +import Data.DPair import Data.NP import Data.NumIdr.Array.Rep import Data.NumIdr.Array.Coords @@ -18,6 +19,25 @@ index : Coords s -> Vects s a -> a index [] x = x index (c :: cs) xs = index cs (index c xs) +export +update : Coords s -> (a -> a) -> Vects s a -> Vects s a +update [] f v = f v +update (i :: is) f v = updateAt i (update is f) v + +export +indexRange : {s : _} -> (rs : CoordsRange s) -> Vects s a -> Vects (newShape rs) a +indexRange [] v = v +indexRange (r :: rs) v with (cRangeToList r) + _ | Left i = indexRange rs (Vect.index (believe_me i) v) + _ | Right is = believe_me $ map (\i => indexRange rs (Vect.index (believe_me i) v)) is + +export +indexSetRange : {s : _} -> (rs : CoordsRange s) -> Vects (newShape rs) a -> Vects s a -> Vects s a +indexSetRange {s=[]} [] rv _ = rv +indexSetRange {s=_::_} (r :: rs) rv v with (cRangeToList r) + _ | Left i = updateAt (believe_me i) (indexSetRange rs rv) v + _ | Right is = foldl (\v,i => updateAt (believe_me i) (indexSetRange rs (Vect.index (believe_me i) rv)) v) v is + export fromFunction : {s : _} -> (Coords s -> a) -> Vects s a @@ -32,3 +52,23 @@ fromFunctionNB {s = d :: s} f = tabulate' {n=d} (\i => fromFunctionNB (f . (i::) tabulate' : forall a. {n : _} -> (Nat -> a) -> Vect n a tabulate' {n=Z} f = [] tabulate' {n=S _} f = f Z :: tabulate' (f . S) + +export +mapVects : {s : _} -> (a -> b) -> Vects s a -> Vects s b +mapVects {s = []} f x = f x +mapVects {s = _ :: _} f v = Prelude.map (mapVects f) v + +export +foldl : {s : _} -> (b -> a -> b) -> b -> Vects s a -> b +foldl {s=[]} f z v = f z v +foldl {s=_::_} f z v = Prelude.foldl (Linked.foldl f) z v + +export +foldr : {s : _} -> (a -> b -> b) -> b -> Vects s a -> b +foldr {s=[]} f z v = f v z +foldr {s=_::_} f z v = Prelude.foldr (flip $ Linked.foldr f) z v + +export +traverse : {s : _} -> Applicative f => (a -> f b) -> Vects s a -> f (Vects s b) +traverse {s=[]} f v = f v +traverse {s=_::_} f v = Prelude.traverse (Linked.traverse f) v