Shape view
This commit is contained in:
parent
4b293d7e2a
commit
cb0157cd8c
|
@ -87,9 +87,6 @@ export
|
||||||
shapeEq : (arr : Array s a) -> s = shape arr
|
shapeEq : (arr : Array s a) -> s = shape arr
|
||||||
shapeEq (MkArray _ _ _ _) = Refl
|
shapeEq (MkArray _ _ _ _) = Refl
|
||||||
|
|
||||||
export
|
|
||||||
withShape : {0 s' : Vect rk Nat} -> Array s' a -> ((s : Vect rk Nat) -> Array s a -> b) -> b
|
|
||||||
withShape arr f = f (shape arr) (rewrite sym (shapeEq arr) in arr)
|
|
||||||
|
|
||||||
-- Get a list of all coordinates
|
-- Get a list of all coordinates
|
||||||
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||||
|
@ -101,6 +98,22 @@ getAllCoords (Z :: s) = []
|
||||||
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Shape view
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
data ShapeView : Array s a -> Type where
|
||||||
|
Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
|
||||||
|
|
||||||
|
public export
|
||||||
|
viewShape : (arr : Array s a) -> ShapeView arr
|
||||||
|
viewShape arr = rewrite shapeEq arr in
|
||||||
|
Shape (shape arr)
|
||||||
|
{arr = rewrite sym (shapeEq arr) in arr}
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Array constructors
|
-- Array constructors
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -233,13 +246,17 @@ export
|
||||||
||| Index the array using the given `CoordsRange` object.
|
||| Index the array using the given `CoordsRange` object.
|
||||||
export
|
export
|
||||||
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
||||||
indexRange rs arr = let ord = getOrder arr
|
indexRange rs arr with (viewShape arr)
|
||||||
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
|
_ | Shape s =
|
||||||
sts = calcStrides ord s
|
let ord = getOrder arr
|
||||||
-- TODO: Make this actually typecheck without resorting to this mess
|
sts = calcStrides ord s'
|
||||||
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $
|
in MkArray ord sts s'
|
||||||
map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $
|
(unsafeFromIns (product s') $
|
||||||
getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
map (\(is,is') => (getLocation' sts is',
|
||||||
|
index (getLocation' (strides arr) is) (getPrim arr)))
|
||||||
|
$ getCoordsList rs)
|
||||||
|
where s' : Vect ? Nat
|
||||||
|
s' = newShape rs
|
||||||
|
|
||||||
export
|
export
|
||||||
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
|
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
|
||||||
|
@ -273,12 +290,11 @@ reshape s' arr = reshape' s' (getOrder arr) arr
|
||||||
||| Change the internal order of the array's elements.
|
||| Change the internal order of the array's elements.
|
||||||
export
|
export
|
||||||
reorder : Order -> Array s a -> Array s a
|
reorder : Order -> Array s a -> Array s a
|
||||||
reorder ord' arr = let s = shape arr
|
reorder ord' arr with (viewShape arr)
|
||||||
sts = calcStrides ord' s
|
_ | Shape s = let sts = calcStrides ord' s
|
||||||
in rewrite shapeEq arr
|
in MkArray ord' sts _ (unsafeFromIns (product s) $
|
||||||
in MkArray ord' sts _ (unsafeFromIns (product s) $
|
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||||
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
|
getAllCoords' s)
|
||||||
getAllCoords' s)
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
|
@ -289,8 +305,8 @@ enumerate' (MkArray _ sts sh p) =
|
||||||
|
|
||||||
export
|
export
|
||||||
enumerate : Array {rk} s a -> List (Coords s, a)
|
enumerate : Array {rk} s a -> List (Coords s, a)
|
||||||
enumerate arr = map (\is => (is, index is arr))
|
enumerate arr with (viewShape arr)
|
||||||
(rewrite shapeEq arr in getAllCoords $ shape arr)
|
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
|
||||||
|
|
||||||
export
|
export
|
||||||
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||||
|
@ -329,29 +345,29 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in
|
||||||
|
|
||||||
export
|
export
|
||||||
Zippable (Array s) where
|
Zippable (Array s) where
|
||||||
zipWith f a b = rewrite shapeEq a
|
zipWith f a b with (viewShape a)
|
||||||
in MkArray (getOrder a) (strides a) _ $
|
_ | Shape s = MkArray (getOrder a) (strides a) s $
|
||||||
if getOrder a == getOrder b
|
if getOrder a == getOrder b
|
||||||
then unsafeZipWith f (getPrim a) (getPrim b)
|
then unsafeZipWith f (getPrim a) (getPrim b)
|
||||||
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||||
|
|
||||||
zipWith3 f a b c = rewrite shapeEq a
|
zipWith3 f a b c with (viewShape a)
|
||||||
in MkArray (getOrder a) (strides a) _ $
|
_ | Shape s = MkArray (getOrder a) (strides a) s $
|
||||||
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
|
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
|
||||||
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
|
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
|
||||||
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||||
(getPrim $ reorder (getOrder a) c)
|
(getPrim $ reorder (getOrder a) c)
|
||||||
|
|
||||||
unzipWith f arr = rewrite shapeEq arr
|
unzipWith f arr with (viewShape arr)
|
||||||
in case unzipWith f (getPrim arr) of
|
_ | Shape s = case unzipWith f (getPrim arr) of
|
||||||
(a, b) => (MkArray (getOrder arr) (strides arr) _ a,
|
(a, b) => (MkArray (getOrder arr) (strides arr) s a,
|
||||||
MkArray (getOrder arr) (strides arr) _ b)
|
MkArray (getOrder arr) (strides arr) s b)
|
||||||
|
|
||||||
unzipWith3 f arr = rewrite shapeEq arr
|
unzipWith3 f arr with (viewShape arr)
|
||||||
in case unzipWith3 f (getPrim arr) of
|
_ | Shape s = case unzipWith3 f (getPrim arr) of
|
||||||
(a, b, c) => (MkArray (getOrder arr) (strides arr) _ a,
|
(a, b, c) => (MkArray (getOrder arr) (strides arr) s a,
|
||||||
MkArray (getOrder arr) (strides arr) _ b,
|
MkArray (getOrder arr) (strides arr) s b,
|
||||||
MkArray (getOrder arr) (strides arr) _ c)
|
MkArray (getOrder arr) (strides arr) s c)
|
||||||
|
|
||||||
export
|
export
|
||||||
Functor (Array s) where
|
Functor (Array s) where
|
||||||
|
|
|
@ -16,7 +16,8 @@ HMatrix m n = Matrix (S m) (S n)
|
||||||
|
|
||||||
export
|
export
|
||||||
fromMatrix : Num a => Matrix m n a -> HMatrix m n a
|
fromMatrix : Num a => Matrix m n a -> HMatrix m n a
|
||||||
fromMatrix mat = withDims mat $ \m,n,mat => ?h2 -- rewrite plusCommutative 1 m in ?h -- (mat `vconcat` repeat [1,n] 0) `hconcat` repeat [m,1] 0
|
fromMatrix mat with (viewShape mat)
|
||||||
|
_ | Shape [m,n] = ?h2
|
||||||
|
|
||||||
export
|
export
|
||||||
toMatrix : HMatrix m n a -> Matrix m n a
|
toMatrix : HMatrix m n a -> Matrix m n a
|
||||||
|
|
|
@ -17,10 +17,6 @@ Matrix' : Nat -> Type -> Type
|
||||||
Matrix' n = Matrix n n
|
Matrix' n = Matrix n n
|
||||||
|
|
||||||
|
|
||||||
export
|
|
||||||
withDims : {0 m',n' : Nat} -> {0 b : Nat -> Nat -> Type} -> Matrix m' n' a -> ((m,n : Nat) -> Matrix m n a -> b m n) -> b m' n'
|
|
||||||
withDims mat f = rewrite shapeEq mat in f (head $ shape mat) (index 1 $ shape mat) (rewrite sym (shapeEq mat) in mat)
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Matrix constructors
|
-- Matrix constructors
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
|
|
@ -19,9 +19,6 @@ export
|
||||||
dimEq : (v : Vector n a) -> n = dim v
|
dimEq : (v : Vector n a) -> n = dim v
|
||||||
dimEq v = cong head $ shapeEq v
|
dimEq v = cong head $ shapeEq v
|
||||||
|
|
||||||
export
|
|
||||||
withDim : {0 n' : Nat} -> Vector n' a -> ((n : Nat) -> Vector n a -> b) -> b
|
|
||||||
withDim v f = f (dim v) (rewrite sym (dimEq v) in v)
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Vector constructors
|
-- Vector constructors
|
||||||
|
|
Loading…
Reference in a new issue