Shape view

This commit is contained in:
Kiana Sheibani 2022-06-15 11:45:06 -04:00
parent 4b293d7e2a
commit cb0157cd8c
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
4 changed files with 56 additions and 46 deletions

View file

@ -87,9 +87,6 @@ export
shapeEq : (arr : Array s a) -> s = shape arr
shapeEq (MkArray _ _ _ _) = Refl
export
withShape : {0 s' : Vect rk Nat} -> Array s' a -> ((s : Vect rk Nat) -> Array s a -> b) -> b
withShape arr f = f (shape arr) (rewrite sym (shapeEq arr) in arr)
-- Get a list of all coordinates
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
@ -101,6 +98,22 @@ getAllCoords (Z :: s) = []
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
--------------------------------------------------------------------------------
-- Shape view
--------------------------------------------------------------------------------
public export
data ShapeView : Array s a -> Type where
Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
public export
viewShape : (arr : Array s a) -> ShapeView arr
viewShape arr = rewrite shapeEq arr in
Shape (shape arr)
{arr = rewrite sym (shapeEq arr) in arr}
--------------------------------------------------------------------------------
-- Array constructors
--------------------------------------------------------------------------------
@ -233,13 +246,17 @@ export
||| Index the array using the given `CoordsRange` object.
export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
indexRange rs arr = let ord = getOrder arr
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
sts = calcStrides ord s
-- TODO: Make this actually typecheck without resorting to this mess
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $
map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $
getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
indexRange rs arr with (viewShape arr)
_ | Shape s =
let ord = getOrder arr
sts = calcStrides ord s'
in MkArray ord sts s'
(unsafeFromIns (product s') $
map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList rs)
where s' : Vect ? Nat
s' = newShape rs
export
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
@ -273,9 +290,8 @@ reshape s' arr = reshape' s' (getOrder arr) arr
||| Change the internal order of the array's elements.
export
reorder : Order -> Array s a -> Array s a
reorder ord' arr = let s = shape arr
sts = calcStrides ord' s
in rewrite shapeEq arr
reorder ord' arr with (viewShape arr)
_ | Shape s = let sts = calcStrides ord' s
in MkArray ord' sts _ (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
@ -289,8 +305,8 @@ enumerate' (MkArray _ sts sh p) =
export
enumerate : Array {rk} s a -> List (Coords s, a)
enumerate arr = map (\is => (is, index is arr))
(rewrite shapeEq arr in getAllCoords $ shape arr)
enumerate arr with (viewShape arr)
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
export
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
@ -329,29 +345,29 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in
export
Zippable (Array s) where
zipWith f a b = rewrite shapeEq a
in MkArray (getOrder a) (strides a) _ $
zipWith f a b with (viewShape a)
_ | Shape s = MkArray (getOrder a) (strides a) s $
if getOrder a == getOrder b
then unsafeZipWith f (getPrim a) (getPrim b)
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
zipWith3 f a b c = rewrite shapeEq a
in MkArray (getOrder a) (strides a) _ $
zipWith3 f a b c with (viewShape a)
_ | Shape s = MkArray (getOrder a) (strides a) s $
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
(getPrim $ reorder (getOrder a) c)
unzipWith f arr = rewrite shapeEq arr
in case unzipWith f (getPrim arr) of
(a, b) => (MkArray (getOrder arr) (strides arr) _ a,
MkArray (getOrder arr) (strides arr) _ b)
unzipWith f arr with (viewShape arr)
_ | Shape s = case unzipWith f (getPrim arr) of
(a, b) => (MkArray (getOrder arr) (strides arr) s a,
MkArray (getOrder arr) (strides arr) s b)
unzipWith3 f arr = rewrite shapeEq arr
in case unzipWith3 f (getPrim arr) of
(a, b, c) => (MkArray (getOrder arr) (strides arr) _ a,
MkArray (getOrder arr) (strides arr) _ b,
MkArray (getOrder arr) (strides arr) _ c)
unzipWith3 f arr with (viewShape arr)
_ | Shape s = case unzipWith3 f (getPrim arr) of
(a, b, c) => (MkArray (getOrder arr) (strides arr) s a,
MkArray (getOrder arr) (strides arr) s b,
MkArray (getOrder arr) (strides arr) s c)
export
Functor (Array s) where

View file

@ -16,7 +16,8 @@ HMatrix m n = Matrix (S m) (S n)
export
fromMatrix : Num a => Matrix m n a -> HMatrix m n a
fromMatrix mat = withDims mat $ \m,n,mat => ?h2 -- rewrite plusCommutative 1 m in ?h -- (mat `vconcat` repeat [1,n] 0) `hconcat` repeat [m,1] 0
fromMatrix mat with (viewShape mat)
_ | Shape [m,n] = ?h2
export
toMatrix : HMatrix m n a -> Matrix m n a

View file

@ -17,10 +17,6 @@ Matrix' : Nat -> Type -> Type
Matrix' n = Matrix n n
export
withDims : {0 m',n' : Nat} -> {0 b : Nat -> Nat -> Type} -> Matrix m' n' a -> ((m,n : Nat) -> Matrix m n a -> b m n) -> b m' n'
withDims mat f = rewrite shapeEq mat in f (head $ shape mat) (index 1 $ shape mat) (rewrite sym (shapeEq mat) in mat)
--------------------------------------------------------------------------------
-- Matrix constructors
--------------------------------------------------------------------------------

View file

@ -19,9 +19,6 @@ export
dimEq : (v : Vector n a) -> n = dim v
dimEq v = cong head $ shapeEq v
export
withDim : {0 n' : Nat} -> Vector n' a -> ((n : Nat) -> Vector n a -> b) -> b
withDim v f = f (dim v) (rewrite sym (dimEq v) in v)
--------------------------------------------------------------------------------
-- Vector constructors