diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index faad3e9..cbbbfc6 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -83,11 +83,6 @@ rank : Array s a -> Nat rank = length . shape -export -shapeEq : (arr : Array s a) -> s = shape arr -shapeEq (MkArray _ _ _ _) = Refl - - -- Get a list of all coordinates getAllCoords' : Vect rk Nat -> List (Vect rk Nat) getAllCoords' = traverse (\case Z => []; S n => [0..n]) @@ -103,11 +98,16 @@ getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |] -------------------------------------------------------------------------------- +export +shapeEq : (arr : Array s a) -> s = shape arr +shapeEq (MkArray _ _ _ _) = Refl + + public export data ShapeView : Array s a -> Type where Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr -public export +export viewShape : (arr : Array s a) -> ShapeView arr viewShape arr = rewrite shapeEq arr in Shape (shape arr) @@ -233,15 +233,36 @@ export (!!) : Array s a -> Coords s -> a (!!) = flip index +-- TODO: Create set/update at index functions + +export +indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a +indexUpdate is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation sts is) f arr) + +export +indexSet : Coords s -> a -> Array s a -> Array s a +indexSet is = indexUpdate is . const + export indexMaybe : Vect rk Nat -> Array {rk} s a -> Maybe a -indexMaybe is arr = safeIndex (getLocation' (strides arr) is) (getPrim arr) +indexMaybe is arr with (viewShape arr) + _ | Shape s = if concat @{All} $ zipWith (<) s (shape arr) + then Just $ index (getLocation' (strides arr) is) (getPrim arr) + else Nothing export (!?) : Array {rk} s a -> Vect rk Nat -> Maybe a (!?) = flip indexMaybe +export +indexUpdateMaybe : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a +indexUpdateMaybe is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation' sts is) f arr) + +export +indexSetMaybe : Vect rk Nat -> a -> Array {rk} s a -> Array s a +indexSetMaybe is = indexUpdateMaybe is . const + ||| Index the array using the given `CoordsRange` object. export @@ -297,6 +318,10 @@ reorder ord' arr with (viewShape arr) getAllCoords' s) +export +resize : (s' : Vect rk Nat) -> a -> Array {rk} s a -> Array s' a +resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNats) + export enumerate' : Array {rk} s a -> List (Vect rk Nat, a) enumerate' (MkArray _ sts sh p) = diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index c1559da..17feb04 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -105,7 +105,5 @@ hconcat = concat 1 export kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a -kronecker a b = rewrite dimEq a in rewrite dimEq b in - fromFunction [dim a, dim b] - (\[i,j] => Vector.index (rewrite dimEq a in i) a * - Vector.index (rewrite dimEq b in j) b) +kronecker a b with (viewShape a, viewShape b) + _ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b) diff --git a/src/Data/NumIdr/PrimArray.idr b/src/Data/NumIdr/PrimArray.idr index c7247ee..f484364 100644 --- a/src/Data/NumIdr/PrimArray.idr +++ b/src/Data/NumIdr/PrimArray.idr @@ -59,7 +59,6 @@ create size f = unsafePerformIO $ do = do arrayDataSet loc (f loc) arr addToArray (S loc) n arr - ||| Index into a primitive array. This function is unsafe, as it ||| performs no boundary check on the index given. export @@ -73,6 +72,18 @@ safeIndex n arr = if n < length arr then Just $ index n arr else Nothing +export +copy : PrimArray a -> PrimArray a +copy arr = create (length arr) (\n => index n arr) + +export +updateAt : Nat -> (a -> a) -> PrimArray a -> PrimArray a +updateAt n f arr = if n >= length arr then arr else + unsafePerformIO $ do + let cpy = copy arr + x <- arrayDataGet n cpy.content + arrayDataSet n (f x) cpy.content + pure cpy ||| Convert a primitive array to a list. export