Shape view
This commit is contained in:
parent
4b293d7e2a
commit
cb0157cd8c
|
@ -87,9 +87,6 @@ export
|
|||
shapeEq : (arr : Array s a) -> s = shape arr
|
||||
shapeEq (MkArray _ _ _ _) = Refl
|
||||
|
||||
export
|
||||
withShape : {0 s' : Vect rk Nat} -> Array s' a -> ((s : Vect rk Nat) -> Array s a -> b) -> b
|
||||
withShape arr f = f (shape arr) (rewrite sym (shapeEq arr) in arr)
|
||||
|
||||
-- Get a list of all coordinates
|
||||
getAllCoords' : Vect rk Nat -> List (Vect rk Nat)
|
||||
|
@ -101,6 +98,22 @@ getAllCoords (Z :: s) = []
|
|||
getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |]
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Shape view
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
public export
|
||||
data ShapeView : Array s a -> Type where
|
||||
Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr
|
||||
|
||||
public export
|
||||
viewShape : (arr : Array s a) -> ShapeView arr
|
||||
viewShape arr = rewrite shapeEq arr in
|
||||
Shape (shape arr)
|
||||
{arr = rewrite sym (shapeEq arr) in arr}
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Array constructors
|
||||
--------------------------------------------------------------------------------
|
||||
|
@ -233,13 +246,17 @@ export
|
|||
||| Index the array using the given `CoordsRange` object.
|
||||
export
|
||||
indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a
|
||||
indexRange rs arr = let ord = getOrder arr
|
||||
s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs
|
||||
sts = calcStrides ord s
|
||||
-- TODO: Make this actually typecheck without resorting to this mess
|
||||
in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $
|
||||
map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||
getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs)
|
||||
indexRange rs arr with (viewShape arr)
|
||||
_ | Shape s =
|
||||
let ord = getOrder arr
|
||||
sts = calcStrides ord s'
|
||||
in MkArray ord sts s'
|
||||
(unsafeFromIns (product s') $
|
||||
map (\(is,is') => (getLocation' sts is',
|
||||
index (getLocation' (strides arr) is) (getPrim arr)))
|
||||
$ getCoordsList rs)
|
||||
where s' : Vect ? Nat
|
||||
s' = newShape rs
|
||||
|
||||
export
|
||||
(!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a
|
||||
|
@ -273,9 +290,8 @@ reshape s' arr = reshape' s' (getOrder arr) arr
|
|||
||| Change the internal order of the array's elements.
|
||||
export
|
||||
reorder : Order -> Array s a -> Array s a
|
||||
reorder ord' arr = let s = shape arr
|
||||
sts = calcStrides ord' s
|
||||
in rewrite shapeEq arr
|
||||
reorder ord' arr with (viewShape arr)
|
||||
_ | Shape s = let sts = calcStrides ord' s
|
||||
in MkArray ord' sts _ (unsafeFromIns (product s) $
|
||||
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||
getAllCoords' s)
|
||||
|
@ -289,8 +305,8 @@ enumerate' (MkArray _ sts sh p) =
|
|||
|
||||
export
|
||||
enumerate : Array {rk} s a -> List (Coords s, a)
|
||||
enumerate arr = map (\is => (is, index is arr))
|
||||
(rewrite shapeEq arr in getAllCoords $ shape arr)
|
||||
enumerate arr with (viewShape arr)
|
||||
_ | Shape s = map (\is => (is, index is arr)) (getAllCoords s)
|
||||
|
||||
export
|
||||
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||
|
@ -329,29 +345,29 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in
|
|||
|
||||
export
|
||||
Zippable (Array s) where
|
||||
zipWith f a b = rewrite shapeEq a
|
||||
in MkArray (getOrder a) (strides a) _ $
|
||||
zipWith f a b with (viewShape a)
|
||||
_ | Shape s = MkArray (getOrder a) (strides a) s $
|
||||
if getOrder a == getOrder b
|
||||
then unsafeZipWith f (getPrim a) (getPrim b)
|
||||
else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||
|
||||
zipWith3 f a b c = rewrite shapeEq a
|
||||
in MkArray (getOrder a) (strides a) _ $
|
||||
zipWith3 f a b c with (viewShape a)
|
||||
_ | Shape s = MkArray (getOrder a) (strides a) s $
|
||||
if (getOrder a == getOrder b) && (getOrder b == getOrder c)
|
||||
then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c)
|
||||
else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b)
|
||||
(getPrim $ reorder (getOrder a) c)
|
||||
|
||||
unzipWith f arr = rewrite shapeEq arr
|
||||
in case unzipWith f (getPrim arr) of
|
||||
(a, b) => (MkArray (getOrder arr) (strides arr) _ a,
|
||||
MkArray (getOrder arr) (strides arr) _ b)
|
||||
unzipWith f arr with (viewShape arr)
|
||||
_ | Shape s = case unzipWith f (getPrim arr) of
|
||||
(a, b) => (MkArray (getOrder arr) (strides arr) s a,
|
||||
MkArray (getOrder arr) (strides arr) s b)
|
||||
|
||||
unzipWith3 f arr = rewrite shapeEq arr
|
||||
in case unzipWith3 f (getPrim arr) of
|
||||
(a, b, c) => (MkArray (getOrder arr) (strides arr) _ a,
|
||||
MkArray (getOrder arr) (strides arr) _ b,
|
||||
MkArray (getOrder arr) (strides arr) _ c)
|
||||
unzipWith3 f arr with (viewShape arr)
|
||||
_ | Shape s = case unzipWith3 f (getPrim arr) of
|
||||
(a, b, c) => (MkArray (getOrder arr) (strides arr) s a,
|
||||
MkArray (getOrder arr) (strides arr) s b,
|
||||
MkArray (getOrder arr) (strides arr) s c)
|
||||
|
||||
export
|
||||
Functor (Array s) where
|
||||
|
|
|
@ -16,7 +16,8 @@ HMatrix m n = Matrix (S m) (S n)
|
|||
|
||||
export
|
||||
fromMatrix : Num a => Matrix m n a -> HMatrix m n a
|
||||
fromMatrix mat = withDims mat $ \m,n,mat => ?h2 -- rewrite plusCommutative 1 m in ?h -- (mat `vconcat` repeat [1,n] 0) `hconcat` repeat [m,1] 0
|
||||
fromMatrix mat with (viewShape mat)
|
||||
_ | Shape [m,n] = ?h2
|
||||
|
||||
export
|
||||
toMatrix : HMatrix m n a -> Matrix m n a
|
||||
|
|
|
@ -17,10 +17,6 @@ Matrix' : Nat -> Type -> Type
|
|||
Matrix' n = Matrix n n
|
||||
|
||||
|
||||
export
|
||||
withDims : {0 m',n' : Nat} -> {0 b : Nat -> Nat -> Type} -> Matrix m' n' a -> ((m,n : Nat) -> Matrix m n a -> b m n) -> b m' n'
|
||||
withDims mat f = rewrite shapeEq mat in f (head $ shape mat) (index 1 $ shape mat) (rewrite sym (shapeEq mat) in mat)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Matrix constructors
|
||||
--------------------------------------------------------------------------------
|
||||
|
|
|
@ -19,9 +19,6 @@ export
|
|||
dimEq : (v : Vector n a) -> n = dim v
|
||||
dimEq v = cong head $ shapeEq v
|
||||
|
||||
export
|
||||
withDim : {0 n' : Nat} -> Vector n' a -> ((n : Nat) -> Vector n a -> b) -> b
|
||||
withDim v f = f (dim v) (rewrite sym (dimEq v) in v)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Vector constructors
|
||||
|
|
Loading…
Reference in a new issue