Fix indexMaybe and add more utility functions
This commit is contained in:
parent
de66efe75b
commit
b0d48eaf00
|
@ -83,11 +83,6 @@ rank : Array s a -> Nat
|
||||||
rank = length . shape
|
rank = length . shape
|
||||||
|
|
||||||
|
|
||||||
export
|
|
||||||
shapeEq : (arr : Array s a) -> s = shape arr
|
|
||||||
shapeEq (MkArray _ _ _ _) = Refl
|
|
||||||
|
|
||||||
|
|
||||||
-- Get a list of all coordinates
|
-- Get a list of all coordinates
|
||||||
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||||
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
getAllCoords' = traverse (\case Z => []; S n => [0..n])
|
||||||
|
@ -103,11 +98,16 @@ getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
shapeEq : (arr : Array s a) -> s = shape arr
|
||||||
|
shapeEq (MkArray _ _ _ _) = Refl
|
||||||
|
|
||||||
|
|
||||||
public export
|
public export
|
||||||
data ShapeView : Array s a -> Type where
|
data ShapeView : Array s a -> Type where
|
||||||
Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
|
Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
|
||||||
|
|
||||||
public export
|
export
|
||||||
viewShape : (arr : Array s a) -> ShapeView arr
|
viewShape : (arr : Array s a) -> ShapeView arr
|
||||||
viewShape arr = rewrite shapeEq arr in
|
viewShape arr = rewrite shapeEq arr in
|
||||||
Shape (shape arr)
|
Shape (shape arr)
|
||||||
|
@ -233,15 +233,36 @@ export
|
||||||
(!!) : Array s a -> Coords s -> a
|
(!!) : Array s a -> Coords s -> a
|
||||||
(!!) = flip index
|
(!!) = flip index
|
||||||
|
|
||||||
|
-- TODO: Create set/update at index functions
|
||||||
|
|
||||||
|
export
|
||||||
|
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
|
||||||
|
indexUpdate is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation sts is) f arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
indexSet : Coords s -> a -> Array s a -> Array s a
|
||||||
|
indexSet is = indexUpdate is . const
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
indexMaybe : Vect rk Nat -> Array {rk} s a -> Maybe a
|
indexMaybe : Vect rk Nat -> Array {rk} s a -> Maybe a
|
||||||
indexMaybe is arr = safeIndex (getLocation' (strides arr) is) (getPrim arr)
|
indexMaybe is arr with (viewShape arr)
|
||||||
|
_ | Shape s = if concat @{All} $ zipWith (<) s (shape arr)
|
||||||
|
then Just $ index (getLocation' (strides arr) is) (getPrim arr)
|
||||||
|
else Nothing
|
||||||
|
|
||||||
export
|
export
|
||||||
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
|
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
|
||||||
(!?) = flip indexMaybe
|
(!?) = flip indexMaybe
|
||||||
|
|
||||||
|
export
|
||||||
|
indexUpdateMaybe : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Array s a
|
||||||
|
indexUpdateMaybe is f (MkArray ord sts s arr) = MkArray ord sts s (updateAt (getLocation' sts is) f arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
indexSetMaybe : Vect rk Nat -> a -> Array {rk} s a -> Array s a
|
||||||
|
indexSetMaybe is = indexUpdateMaybe is . const
|
||||||
|
|
||||||
|
|
||||||
||| Index the array using the given `CoordsRange` object.
|
||| Index the array using the given `CoordsRange` object.
|
||||||
export
|
export
|
||||||
|
@ -297,6 +318,10 @@ reorder ord' arr with (viewShape arr)
|
||||||
getAllCoords' s)
|
getAllCoords' s)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
resize : (s' : Vect rk Nat) -> a -> Array {rk} s a -> Array s' a
|
||||||
|
resize s' x arr = fromFunction' s' (getOrder arr) (fromMaybe x . (arr !?) . toNats)
|
||||||
|
|
||||||
export
|
export
|
||||||
enumerate' : Array {rk} s a -> List (Vect rk Nat, a)
|
enumerate' : Array {rk} s a -> List (Vect rk Nat, a)
|
||||||
enumerate' (MkArray _ sts sh p) =
|
enumerate' (MkArray _ sts sh p) =
|
||||||
|
|
|
@ -105,7 +105,5 @@ hconcat = concat 1
|
||||||
|
|
||||||
export
|
export
|
||||||
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
|
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
|
||||||
kronecker a b = rewrite dimEq a in rewrite dimEq b in
|
kronecker a b with (viewShape a, viewShape b)
|
||||||
fromFunction [dim a, dim b]
|
_ | (Shape [m], Shape [n]) = fromFunction [m,n] (\[i,j] => index i a * index j b)
|
||||||
(\[i,j] => Vector.index (rewrite dimEq a in i) a *
|
|
||||||
Vector.index (rewrite dimEq b in j) b)
|
|
||||||
|
|
|
@ -59,7 +59,6 @@ create size f = unsafePerformIO $ do
|
||||||
= do arrayDataSet loc (f loc) arr
|
= do arrayDataSet loc (f loc) arr
|
||||||
addToArray (S loc) n arr
|
addToArray (S loc) n arr
|
||||||
|
|
||||||
|
|
||||||
||| Index into a primitive array. This function is unsafe, as it
|
||| Index into a primitive array. This function is unsafe, as it
|
||||||
||| performs no boundary check on the index given.
|
||| performs no boundary check on the index given.
|
||||||
export
|
export
|
||||||
|
@ -73,6 +72,18 @@ safeIndex n arr = if n < length arr
|
||||||
then Just $ index n arr
|
then Just $ index n arr
|
||||||
else Nothing
|
else Nothing
|
||||||
|
|
||||||
|
export
|
||||||
|
copy : PrimArray a -> PrimArray a
|
||||||
|
copy arr = create (length arr) (\n => index n arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
updateAt : Nat -> (a -> a) -> PrimArray a -> PrimArray a
|
||||||
|
updateAt n f arr = if n >= length arr then arr else
|
||||||
|
unsafePerformIO $ do
|
||||||
|
let cpy = copy arr
|
||||||
|
x <- arrayDataGet n cpy.content
|
||||||
|
arrayDataSet n (f x) cpy.content
|
||||||
|
pure cpy
|
||||||
|
|
||||||
||| Convert a primitive array to a list.
|
||| Convert a primitive array to a list.
|
||||||
export
|
export
|
||||||
|
|
Loading…
Reference in a new issue