Shape view

This commit is contained in:
Kiana Sheibani 2022-06-15 11:45:06 -04:00
parent 4b293d7e2a
commit cb0157cd8c
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
4 changed files with 56 additions and 46 deletions

View file

@ -87,9 +87,6 @@ export
shapeEq : (arr : Array s a) -> s = shape arr shapeEq : (arr : Array s a) -> s = shape arr
shapeEq (MkArray _ _ _ _) = Refl shapeEq (MkArray _ _ _ _) = Refl
export
withShape : {0 s' : Vect rk Nat} -> Array s' a -> ((s : Vect rk Nat) -> Array s a -> b) -> b
withShape arr f = f (shape arr) (rewrite sym (shapeEq arr) in arr)
-- Get a list of all coordinates -- Get a list of all coordinates
getAllCoords' : Vect rk Nat -> List (Vect rk Nat) getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
@ -101,6 +98,22 @@ getAllCoords (Z :: s) = []
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |] getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
--------------------------------------------------------------------------------
-- Shape view
--------------------------------------------------------------------------------
public export
data ShapeView : Array s a -> Type where
Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
public export
viewShape : (arr : Array s a) -> ShapeView arr
viewShape arr = rewrite shapeEq arr in
Shape (shape arr)
{arr = rewrite sym (shapeEq arr) in arr}
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Array constructors -- Array constructors
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -233,13 +246,17 @@ export
||| Index the array using the given `CoordsRange` object. ||| Index the array using the given `CoordsRange` object.
export export
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
indexRange rs arr = let ord = getOrder arr indexRange rs arr with (viewShape arr)
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs _ | Shape s =
sts = calcStrides ord s let ord = getOrder arr
-- TODO: Make this actually typecheck without resorting to this mess sts = calcStrides ord s'
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ in MkArray ord sts s'
map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ (unsafeFromIns (product s') $
getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs) map (\(is,is') => (getLocation' sts is',
index (getLocation' (strides arr) is) (getPrim arr)))
$ getCoordsList rs)
where s' : Vect ? Nat
s' = newShape rs
export export
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a (!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
@ -273,12 +290,11 @@ reshape s' arr = reshape' s' (getOrder arr) arr
||| Change the internal order of the array's elements. ||| Change the internal order of the array's elements.
export export
reorder : Order -> Array s a -> Array s a reorder : Order -> Array s a -> Array s a
reorder ord' arr = let s = shape arr reorder ord' arr with (viewShape arr)
sts = calcStrides ord' s _ | Shape s = let sts = calcStrides ord' s
in rewrite shapeEq arr in MkArray ord' sts _ (unsafeFromIns (product s) $
in MkArray ord' sts _ (unsafeFromIns (product s) $ map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $ getAllCoords' s)
getAllCoords' s)
export export
@ -289,8 +305,8 @@ enumerate' (MkArray _ sts sh p) =
export export
enumerate : Array {rk} s a -> List (Coords s, a) enumerate : Array {rk} s a -> List (Coords s, a)
enumerate arr = map (\is => (is, index is arr)) enumerate arr with (viewShape arr)
(rewrite shapeEq arr in getAllCoords $ shape arr) _ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
export export
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a -> concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
@ -329,29 +345,29 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in
export export
Zippable (Array s) where Zippable (Array s) where
zipWith f a b = rewrite shapeEq a zipWith f a b with (viewShape a)
in MkArray (getOrder a) (strides a) _ $ _ | Shape s = MkArray (getOrder a) (strides a) s $
if getOrder a == getOrder b if getOrder a == getOrder b
then unsafeZipWith f (getPrim a) (getPrim b) then unsafeZipWith f (getPrim a) (getPrim b)
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b) else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
zipWith3 f a b c = rewrite shapeEq a zipWith3 f a b c with (viewShape a)
in MkArray (getOrder a) (strides a) _ $ _ | Shape s = MkArray (getOrder a) (strides a) s $
if (getOrder a == getOrder b) && (getOrder b == getOrder c) if (getOrder a == getOrder b) && (getOrder b == getOrder c)
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c) then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b) else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
(getPrim $ reorder (getOrder a) c) (getPrim $ reorder (getOrder a) c)
unzipWith f arr = rewrite shapeEq arr unzipWith f arr with (viewShape arr)
in case unzipWith f (getPrim arr) of _ | Shape s = case unzipWith f (getPrim arr) of
(a, b) => (MkArray (getOrder arr) (strides arr) _ a, (a, b) => (MkArray (getOrder arr) (strides arr) s a,
MkArray (getOrder arr) (strides arr) _ b) MkArray (getOrder arr) (strides arr) s b)
unzipWith3 f arr = rewrite shapeEq arr unzipWith3 f arr with (viewShape arr)
in case unzipWith3 f (getPrim arr) of _ | Shape s = case unzipWith3 f (getPrim arr) of
(a, b, c) => (MkArray (getOrder arr) (strides arr) _ a, (a, b, c) => (MkArray (getOrder arr) (strides arr) s a,
MkArray (getOrder arr) (strides arr) _ b, MkArray (getOrder arr) (strides arr) s b,
MkArray (getOrder arr) (strides arr) _ c) MkArray (getOrder arr) (strides arr) s c)
export export
Functor (Array s) where Functor (Array s) where

View file

@ -16,7 +16,8 @@ HMatrix m n = Matrix (S m) (S n)
export export
fromMatrix : Num a => Matrix m n a -> HMatrix m n a fromMatrix : Num a => Matrix m n a -> HMatrix m n a
fromMatrix mat = withDims mat $ \m,n,mat => ?h2 -- rewrite plusCommutative 1 m in ?h -- (mat `vconcat` repeat [1,n] 0) `hconcat` repeat [m,1] 0 fromMatrix mat with (viewShape mat)
_ | Shape [m,n] = ?h2
export export
toMatrix : HMatrix m n a -> Matrix m n a toMatrix : HMatrix m n a -> Matrix m n a

View file

@ -17,10 +17,6 @@ Matrix' : Nat -> Type -> Type
Matrix' n = Matrix n n Matrix' n = Matrix n n
export
withDims : {0 m',n' : Nat} -> {0 b : Nat -> Nat -> Type} -> Matrix m' n' a -> ((m,n : Nat) -> Matrix m n a -> b m n) -> b m' n'
withDims mat f = rewrite shapeEq mat in f (head $ shape mat) (index 1 $ shape mat) (rewrite sym (shapeEq mat) in mat)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Matrix constructors -- Matrix constructors
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View file

@ -19,9 +19,6 @@ export
dimEq : (v : Vector n a) -> n = dim v dimEq : (v : Vector n a) -> n = dim v
dimEq v = cong head $ shapeEq v dimEq v = cong head $ shapeEq v
export
withDim : {0 n' : Nat} -> Vector n' a -> ((n : Nat) -> Vector n a -> b) -> b
withDim v f = f (dim v) (rewrite sym (dimEq v) in v)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Vector constructors -- Vector constructors