Revamp all of the standard array API

This commit is contained in:
Kiana Sheibani 2023-09-25 22:55:33 -04:00
parent ccb689af42
commit a0068469c5
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
10 changed files with 655 additions and 364 deletions

View file

@ -112,7 +112,7 @@ viewShape arr = rewrite shapeEq arr in
export
repeat : {default B rep : Rep} -> RepConstraint rep a =>
(s : Vect rk Nat) -> a -> Array s a
repeat s x = MkArray rep s (constant s x)
repeat s x = MkArray rep s (PrimArray.constant s x)
||| Create an array filled with zeros.
|||
@ -138,9 +138,9 @@ ones {rep} s = repeat {rep} s 1
||| @ s The shape of the constructed array
||| @ rep The internal representation of the constructed array
export
fromVect : {default B rep : Rep} -> RepConstraint rep a =>
fromVect : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
(s : Vect rk Nat) -> Vect (product s) a -> Array s a
fromVect s v = ?fv
fromVect {rep} s v = MkArray rep s (PrimArray.fromList {rep} s $ toList v)
||| Create an array by taking values from a stream.
@ -148,7 +148,7 @@ fromVect s v = ?fv
||| @ s The shape of the constructed array
||| @ rep The internal representation of the constructed array
export
fromStream : {default B rep : Rep} -> RepConstraint rep a =>
fromStream : {default B rep : Rep} -> LinearRep rep => RepConstraint rep a =>
(s : Vect rk Nat) -> Stream a -> Array s a
fromStream {rep} s str = fromVect {rep} s (take _ str)
@ -211,12 +211,12 @@ index is (MkArray _ _ arr) = PrimArray.index is arr
export %inline
(!!) : Array s a -> Coords s -> a
arr !! is = index is arr
{-
||| Update the entry at the given coordinates using the function.
export
indexUpdate : Coords s -> (a -> a) -> Array s a -> Array s a
indexUpdate is f (MkArray ord sts s arr) =
MkArray ord sts s (updateAt (getLocation sts is) f arr)
indexUpdate is f (MkArray rep @{rc} s arr) =
MkArray rep @{rc} s (indexUpdate @{rc} is f arr)
||| Set the entry at the given coordinates to the given value.
export
@ -227,17 +227,8 @@ indexSet is = indexUpdate is . const
||| Index the array using the given range of coordinates, returning a new array.
export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
indexRange {s} rs arr with (viewShape arr)
_ | Shape s =
let ord = getOrder arr
sts = calcStrides ord s'
in MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList rs)
where s' : Vect ? Nat
s' = newShape rs
indexRange rs (MkArray rep @{rc} s arr) =
MkArray rep @{rc} _ (PrimArray.indexRange @{rc} rs arr)
||| Index the array using the given range of coordinates, returning a new array.
|||
@ -250,13 +241,8 @@ arr !!.. rs = indexRange rs arr
export
indexSetRange : (rs : CoordsRange s) -> Array (newShape rs) a ->
Array s a -> Array s a
indexSetRange rs rpl (MkArray ord sts s arr) =
MkArray ord sts s (unsafePerformIO $ do
let arr' = copy arr
unsafeDoIns (map (\(is,is') =>
(getLocation' sts is, index (getLocation' (strides rpl) is') (getPrim rpl)))
$ getCoordsList rs) arr'
pure arr')
indexSetRange rs (MkArray _ _ rpl) (MkArray rep s arr) =
MkArray rep s (PrimArray.indexSetRange {rep} rs (convertRep rpl) arr)
||| Update the sub-array at the given range of coordinates by applying
@ -266,13 +252,14 @@ indexUpdateRange : (rs : CoordsRange s) ->
(Array (newShape rs) a -> Array (newShape rs) a) ->
Array s a -> Array s a
indexUpdateRange rs f arr = indexSetRange rs (f $ arr !!.. rs) arr
-}
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
export
indexNB : Vect rk Nat -> Array {rk} s a -> Maybe a
indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
indexNB is (MkArray rep @{rc} s arr) =
map (\is => index is (MkArray rep @{rc} s arr)) (validateCoords s is)
||| Index the array using the given coordinates, returning `Nothing` if the
||| coordinates are out of bounds.
@ -281,15 +268,13 @@ indexNB is (MkArray _ _ arr) = PrimArray.indexNB is arr
export %inline
(!?) : Array {rk} s a -> Vect rk Nat -> Maybe a
arr !? is = indexNB is arr
{-
||| Update the entry at the given coordinates using the function. `Nothing` is
||| returned if the coordinates are out of bounds.
export
indexUpdateNB : Vect rk Nat -> (a -> a) -> Array {rk} s a -> Maybe (Array s a)
indexUpdateNB is f (MkArray ord sts s arr) =
if all id $ zipWith (<) is s
then Just $ MkArray ord sts s (updateAt (getLocation' sts is) f arr)
else Nothing
indexUpdateNB is f (MkArray rep @{rc} s arr) =
map (\is => indexUpdate is f (MkArray rep @{rc} s arr)) (validateCoords s is)
||| Set the entry at the given coordinates using the function. `Nothing` is
||| returned if the coordinates are out of bounds.
@ -302,20 +287,8 @@ indexSetNB is = indexUpdateNB is . const
||| If any of the given indices are out of bounds, then `Nothing` is returned.
export
indexRangeNB : (rs : Vect rk CRangeNB) -> Array s a -> Maybe (Array (newShape s rs) a)
indexRangeNB {s} rs arr with (viewShape arr)
_ | Shape s =
if validateShape s rs
then
let ord = getOrder arr
sts = calcStrides ord s'
in Just $ MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList s rs)
else Nothing
where s' : Vect ? Nat
s' = newShape s rs
indexRangeNB rs (MkArray rep @{rc} s arr) =
map (\rs => believe_me $ Array.indexRange rs (MkArray rep @{rc} s arr)) (validateCRange s rs)
||| Index the array using the given range of coordinates, returning a new array.
||| If any of the given indices are out of bounds, then `Nothing` is returned.
@ -324,7 +297,7 @@ indexRangeNB {s} rs arr with (viewShape arr)
export %inline
(!?..) : Array s a -> (rs : Vect rk CRangeNB) -> Maybe (Array (newShape s rs) a)
arr !?.. rs = indexRangeNB rs arr
-}
||| Index the array using the given coordinates.
||| WARNING: This function does not perform any bounds check on its inputs.
@ -342,30 +315,21 @@ export %inline
(!#) : Array {rk} s a -> Vect rk Nat -> a
arr !# is = indexUnsafe is arr
{-
||| Index the array using the given range of coordinates, returning a new array.
||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety.
export
export %unsafe
indexRangeUnsafe : (rs : Vect rk CRangeNB) -> Array s a -> Array (newShape s rs) a
indexRangeUnsafe {s} rs arr with (viewShape arr)
_ | Shape s =
let ord = getOrder arr
sts = calcStrides ord s'
in MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList s rs)
where s' : Vect ? Nat
s' = newShape s rs
indexRangeUnsafe rs (MkArray rep @{rc} s arr) =
believe_me $ Array.indexRange (assertCRange s rs) (MkArray rep @{rc} s arr)
||| Index the array using the given range of coordinates, returning a new array.
||| WARNING: This function does not perform any bounds check on its inputs.
||| Misuse of this function can easily break memory safety.
|||
||| This is the operator form of `indexRangeUnsafe`.
export %inline
export %inline %unsafe
(!#..) : Array s a -> (rs : Vect rk CRangeNB) -> Array (newShape s rs) a
arr !#.. is = indexRangeUnsafe is arr
@ -375,35 +339,42 @@ arr !#.. is = indexRangeUnsafe is arr
-- Operations on arrays
--------------------------------------------------------------------------------
||| Reshape the array into the given shape and reinterpret it according to
||| the given order.
|||
||| @ s' The shape to convert the array to
||| @ ord The order to reinterpret the array by
export
reshape' : (s' : Vect rk' Nat) -> (ord : Order) -> Array {rk} s a ->
(0 ok : product s = product s') => Array s' a
reshape' s' ord' arr = MkArray ord' (calcStrides ord' s') s' (getPrim arr)
mapArray' : (a -> a) -> Array s a -> Array s a
mapArray' f (MkArray rep _ arr) = MkArray rep _ (mapPrim f arr)
export
mapArray : (a -> b) -> (arr : Array s a) -> RepConstraint (getRep arr) b => Array s b
mapArray f (MkArray rep _ arr) @{rc} = MkArray rep @{rc} _ (mapPrim f arr)
export
zipWithArray' : (a -> a -> a) -> Array s a -> Array s a -> Array s a
zipWithArray' {s} f a b with (viewShape a)
_ | Shape s = MkArray (mergeRep (getRep a) (getRep b))
@{mergeRepConstraint (getRepC a) (getRepC b)} s
$ PrimArray.fromFunctionNB @{mergeRepConstraint (getRepC a) (getRepC b)} _
(\is => f (a !# is) (b !# is))
export
zipWithArray : (a -> b -> c) -> (arr : Array s a) -> (arr' : Array s b) ->
RepConstraint (mergeRep (getRep arr) (getRep arr')) c => Array s c
zipWithArray {s} f a b @{rc} with (viewShape a)
_ | Shape s = MkArray (mergeRep (getRep a) (getRep b)) @{rc} s
$ PrimArray.fromFunctionNB _ (\is => f (a !# is) (b !# is))
||| Reshape the array into the given shape.
|||
||| @ s' The shape to convert the array to
export
reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
reshape : (s' : Vect rk' Nat) -> (arr : Array {rk} s a) -> LinearRep (getRep arr) =>
(0 ok : product s = product s') => Array s' a
reshape s' arr = reshape' s' (getOrder arr) arr
reshape s' (MkArray rep _ arr) = MkArray rep s' (PrimArray.reshape s' arr)
||| Change the internal order of the array's elements.
export
reorder : Order -> Array s a -> Array s a
reorder {s} ord' arr with (viewShape arr)
_ | Shape s = let sts = calcStrides ord' s
in MkArray ord' sts _ (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
convertRep : (rep : Rep) -> RepConstraint rep a => Array s a -> Array s a
convertRep rep (MkArray _ s arr) = MkArray rep s (PrimArray.convertRep arr)
||| Resize the array to a new shape, preserving the coordinates of the original
||| elements. New coordinates are filled with a default value.
@ -412,7 +383,7 @@ reorder {s} ord' arr with (viewShape arr)
||| @ def The default value to fill the array with
export
resize : (s' : Vect rk Nat) -> (def : a) -> Array {rk} s a -> Array s' a
resize s' def arr = fromFunction' s' (getOrder arr) (fromMaybe def . (arr !?) . toNB)
resize s' def arr = fromFunction {rep=getRep arr} @{getRepC arr} s' (fromMaybe def . (arr !?) . toNB)
||| Resize the array to a new shape, preserving the coordinates of the original
||| elements. This function requires a proof that the new shape is strictly
@ -427,20 +398,11 @@ resizeLTE : (s' : Vect rk Nat) -> (0 ok : NP Prelude.id (zipWith LTE s' s)) =>
resizeLTE s' arr = resize s' (believe_me ()) arr
||| List all of the values in an array in row-major order.
export
elements : Array {rk} s a -> Vect (product s) a
elements (MkArray _ sts sh p) =
let elems = map (flip index p . getLocation' sts) (getAllCoords' sh)
-- TODO: prove that the number of elements is `product s`
in assert_total $ case toVect (product sh) elems of Just v => v
||| List all of the values in an array along with their coordinates.
export
enumerateNB : Array {rk} s a -> List (Vect rk Nat, a)
enumerateNB (MkArray _ sts sh p) =
map (\is => (is, index (getLocation' sts is) p)) (getAllCoords' sh)
enumerateNB (MkArray _ s arr) =
map (\is => (is, PrimArray.indexUnsafe is arr)) (getAllCoords' s)
||| List all of the values in an array along with their coordinates.
export
@ -448,6 +410,12 @@ enumerate : Array s a -> List (Coords s, a)
enumerate {s} arr with (viewShape arr)
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
||| List all of the values in an array in row-major order.
export
elements : Array {rk} s a -> Vect (product s) a
elements (MkArray _ s arr) =
believe_me $ Vect.fromList $
map (flip PrimArray.indexUnsafe arr) (getAllCoords' s)
||| Join two arrays along a particular axis, e.g. combining two matrices
||| vertically or horizontally. All other axes of the arrays must have the
@ -457,16 +425,12 @@ enumerate {s} arr with (viewShape arr)
export
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
Array (updateAt axis (+d) s) a
concat axis a b = let sA = shape a
sB = shape b
dA = index axis sA
dB = index axis sB
s = replaceAt axis (dA + dB) sA
sts = calcStrides COrder s
ins = map (mapFst $ getLocation' sts . toNB) (enumerate a)
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNB) (enumerate b)
-- TODO: prove that the type-level shape and `s` are equivalent
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
concat {s,d} axis a b with (viewShape a, viewShape b)
_ | (Shape s, Shape (replaceAt axis d s)) =
believe_me $ Array.fromFunctionNB {rep=mergeRep (getRep a) (getRep b)}
@{mergeRepConstraint (getRepC a) (getRepC b)} (updateAt axis (+ index axis (shape b)) s)
(\is => let limit = index axis s
in if index axis is < limit then a !# is else b !# updateAt axis (`minus` limit) is)
||| Stack multiple arrays along a new axis, e.g. stacking vectors to form a matrix.
|||
@ -500,12 +464,11 @@ splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} ->
Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a)
splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is')))
||| Construct the transpose of an array by reversing the order of its axes.
export
transpose : Array s a -> Array (reverse s) a
transpose {s} arr with (viewShape arr)
_ | Shape s = fromFunctionNB (reverse s) (\is => arr !# reverse is)
_ | Shape s = fromFunctionNB _ (\is => arr !# reverse is)
||| Construct the transpose of an array by reversing the order of its axes.
|||
@ -551,41 +514,56 @@ permuteInAxis {s} axis p arr with (viewShape arr)
-- Implementations
--------------------------------------------------------------------------------
-- Most of these implementations apply the operation pointwise. If there are
-- multiple arrays involved with different orders, then all of the arrays are
-- reordered to match one.
export
Zippable (Array s) where
zipWith {s} f a b with (viewShape a)
_ | Shape s = MkArray (getOrder a) (strides a) s $
if getOrder a == getOrder b
then unsafeZipWith f (getPrim a) (getPrim b)
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
_ | Shape s = MkArray (mergeRepNC (getRep a) (getRep b))
@{mergeNCRepConstraint} s
$ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
(\is => f (a !# is) (b !# is))
zipWith3 {s} f a b c with (viewShape a)
_ | Shape s = MkArray (getOrder a) (strides a) s $
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
(getPrim $ reorder (getOrder a) c)
_ | Shape s = MkArray (mergeRepNC (mergeRep (getRep a) (getRep b)) (getRep c))
@{mergeNCRepConstraint} s
$ PrimArray.fromFunctionNB @{mergeNCRepConstraint} _
(\is => f (a !# is) (b !# is) (c !# is))
unzipWith {s} f arr with (viewShape arr)
_ | Shape s = case unzipWith f (getPrim arr) of
(a, b) => (MkArray (getOrder arr) (strides arr) s a,
MkArray (getOrder arr) (strides arr) s b)
_ | Shape s =
let rep : Rep
rep = forceRepNC $ getRep arr
in (MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => snd $ f (arr !# is)))
unzipWith3 {s} f arr with (viewShape arr)
_ | Shape s = case unzipWith3 f (getPrim arr) of
(a, b, c) => (MkArray (getOrder arr) (strides arr) s a,
MkArray (getOrder arr) (strides arr) s b,
MkArray (getOrder arr) (strides arr) s c)
_ | Shape s =
let rep : Rep
rep = forceRepNC $ getRep arr
in (MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => fst $ snd $ f (arr !# is)),
MkArray rep
@{forceRepConstraint} s
$ PrimArray.fromFunctionNB @{forceRepConstraint} _
(\is => snd $ snd $ f (arr !# is)))
export
Functor (Array s) where
map f (MkArray ord sts s arr) = MkArray ord sts s (map f arr)
map f (MkArray rep @{rc} s arr) = MkArray (forceRepNC rep) @{forceRepConstraint} s
(mapPrim @{forceRepConstraint} @{forceRepConstraint} f
$ convertRep @{rc} @{forceRepConstraint} arr)
export
{s : _} -> Applicative (Array s) where
@ -602,15 +580,18 @@ export
export
Foldable (Array s) where
foldl f z = foldl f z . getPrim
foldr f z = foldr f z . getPrim
null arr = size arr == Z
toList = toList . getPrim
foldl f z (MkArray _ _ arr) = PrimArray.foldl f z arr
foldr f z (MkArray _ _ arr) = PrimArray.foldr f z arr
null (MkArray _ s _) = isZero (product s)
export
Traversable (Array s) where
traverse f (MkArray ord sts s arr) =
map (MkArray ord sts s) (traverse f arr)
traverse f (MkArray rep @{rc} s arr) =
map (MkArray (forceRepNC rep) @{forceRepConstraint} s)
(PrimArray.traverse {rep=forceRepNC rep}
@{%search} @{forceRepConstraint} @{forceRepConstraint} f
(PrimArray.convertRep @{rc} @{forceRepConstraint} arr))
export
@ -619,13 +600,11 @@ Cast a b => Cast (Array s a) (Array s b) where
export
Eq a => Eq (Array s a) where
a == b = if getOrder a == getOrder b
then unsafeEq (getPrim a) (getPrim b)
else unsafeEq (getPrim a) (getPrim $ reorder (getOrder a) b)
a == b = and $ zipWith (delay .: (==)) (convertRep D a) (convertRep D b)
export
Semigroup a => Semigroup (Array s a) where
(<+>) = zipWith (<+>)
(<+>) = zipWithArray' (<+>)
export
{s : _} -> Monoid a => Monoid (Array s a) where
@ -635,36 +614,34 @@ export
-- were moved into its own interface, this constraint could be removed.
export
{s : _} -> Num a => Num (Array s a) where
(+) = zipWith (+)
(*) = zipWith (*)
(+) = zipWithArray' (+)
(*) = zipWithArray' (*)
fromInteger = repeat s . fromInteger
export
{s : _} -> Neg a => Neg (Array s a) where
negate = map negate
(-) = zipWith (-)
negate = mapArray' negate
(-) = zipWithArray' (-)
export
{s : _} -> Fractional a => Fractional (Array s a) where
recip = map recip
(/) = zipWith (/)
recip = mapArray' recip
(/) = zipWithArray' (/)
export
Num a => Mult a (Array {rk} s a) (Array s a) where
(*.) x = map (*x)
(*.) x = mapArray' (*x)
export
Num a => Mult (Array {rk} s a) a (Array s a) where
(*.) = flip (*.)
export
Show a => Show (Array s a) where
showPrec d arr = let orderedElems = PrimArray.toList $ getPrim $
if getOrder arr == COrder then arr else reorder COrder arr
showPrec d arr = let orderedElems = toList $ elements arr
in showCon d "array " $ concat $ insertPunct (shape arr) $ map show orderedElems
where
splitWindow : Nat -> List String -> List (List String)
@ -691,7 +668,7 @@ Show a => Show (Array s a) where
||| Linearly interpolate between two arrays.
export
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
lerp t a b = zipWithArray' (+) (a *. (1 - t)) (b *. t)
||| Calculate the square of an array's Eulidean norm.
@ -715,4 +692,3 @@ normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
export
pnorm : (p : Double) -> Array s Double -> Double
pnorm p = (`pow` recip p) . sum . map (`pow` p)
-}