diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 237e94d..faad3e9 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -87,9 +87,6 @@ export shapeEq : (arr : Array s a) -> s = shape arr shapeEq (MkArray _ _ _ _) = Refl -export -withShape : {0 s' : Vect rk Nat} -> Array s' a -> ((s : Vect rk Nat) -> Array s a -> b) -> b -withShape arr f = f (shape arr) (rewrite sym (shapeEq arr) in arr) -- Get a list of all coordinates getAllCoords' : Vect rk Nat -> List (Vect rk Nat) @@ -101,6 +98,22 @@ getAllCoords (Z :: s) = [] getAllCoords (S d :: s) = [| forget (allFins d) :: getAllCoords s |] +-------------------------------------------------------------------------------- +-- Shape view +-------------------------------------------------------------------------------- + + +public export +data ShapeView : Array s a -> Type where + Shape : (s : Vect rk Nat) -> {0 arr : Array s a} -> ShapeView arr + +public export +viewShape : (arr : Array s a) -> ShapeView arr +viewShape arr = rewrite shapeEq arr in + Shape (shape arr) + {arr = rewrite sym (shapeEq arr) in arr} + + -------------------------------------------------------------------------------- -- Array constructors -------------------------------------------------------------------------------- @@ -233,13 +246,17 @@ export ||| Index the array using the given `CoordsRange` object. export indexRange : (rs : CoordsRange s) -> Array s a -> Array (newShape rs) a -indexRange rs arr = let ord = getOrder arr - s = newShape {s = shape arr} $ rewrite sym $ shapeEq arr in rs - sts = calcStrides ord s - -- TODO: Make this actually typecheck without resorting to this mess - in believe_me $ MkArray ord sts s (unsafeFromIns (product s) $ - map (\(is,is') => (getLocation' sts is', index (getLocation' (strides arr) is) (getPrim arr))) $ - getCoordsList {s = shape arr} $ rewrite sym $ shapeEq arr in rs) +indexRange rs arr with (viewShape arr) + _ | Shape s = + let ord = getOrder arr + sts = calcStrides ord s' + in MkArray ord sts s' + (unsafeFromIns (product s') $ + map (\(is,is') => (getLocation' sts is', + index (getLocation' (strides arr) is) (getPrim arr))) + $ getCoordsList rs) + where s' : Vect ? Nat + s' = newShape rs export (!!..) : Array s a -> (rs : CoordsRange s) -> Array (newShape rs) a @@ -273,12 +290,11 @@ reshape s' arr = reshape' s' (getOrder arr) arr ||| Change the internal order of the array's elements. export reorder : Order -> Array s a -> Array s a -reorder ord' arr = let s = shape arr - sts = calcStrides ord' s - in rewrite shapeEq arr - in MkArray ord' sts _ (unsafeFromIns (product s) $ - map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $ - getAllCoords' s) +reorder ord' arr with (viewShape arr) + _ | Shape s = let sts = calcStrides ord' s + in MkArray ord' sts _ (unsafeFromIns (product s) $ + map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $ + getAllCoords' s) export @@ -289,8 +305,8 @@ enumerate' (MkArray _ sts sh p) = export enumerate : Array {rk} s a -> List (Coords s, a) -enumerate arr = map (\is => (is, index is arr)) - (rewrite shapeEq arr in getAllCoords $ shape arr) +enumerate arr with (viewShape arr) + _ | Shape s = map (\is => (is, index is arr)) (getAllCoords s) export concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a -> @@ -329,29 +345,29 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in export Zippable (Array s) where - zipWith f a b = rewrite shapeEq a - in MkArray (getOrder a) (strides a) _ $ - if getOrder a == getOrder b - then unsafeZipWith f (getPrim a) (getPrim b) - else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b) + zipWith f a b with (viewShape a) + _ | Shape s = MkArray (getOrder a) (strides a) s $ + if getOrder a == getOrder b + then unsafeZipWith f (getPrim a) (getPrim b) + else unsafeZipWith f (getPrim a) (getPrim $ reorder (getOrder a) b) - zipWith3 f a b c = rewrite shapeEq a - in MkArray (getOrder a) (strides a) _ $ - if (getOrder a == getOrder b) && (getOrder b == getOrder c) - then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c) - else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b) - (getPrim $ reorder (getOrder a) c) + zipWith3 f a b c with (viewShape a) + _ | Shape s = MkArray (getOrder a) (strides a) s $ + if (getOrder a == getOrder b) && (getOrder b == getOrder c) + then unsafeZipWith3 f (getPrim a) (getPrim b) (getPrim c) + else unsafeZipWith3 f (getPrim a) (getPrim $ reorder (getOrder a) b) + (getPrim $ reorder (getOrder a) c) - unzipWith f arr = rewrite shapeEq arr - in case unzipWith f (getPrim arr) of - (a, b) => (MkArray (getOrder arr) (strides arr) _ a, - MkArray (getOrder arr) (strides arr) _ b) + unzipWith f arr with (viewShape arr) + _ | Shape s = case unzipWith f (getPrim arr) of + (a, b) => (MkArray (getOrder arr) (strides arr) s a, + MkArray (getOrder arr) (strides arr) s b) - unzipWith3 f arr = rewrite shapeEq arr - in case unzipWith3 f (getPrim arr) of - (a, b, c) => (MkArray (getOrder arr) (strides arr) _ a, - MkArray (getOrder arr) (strides arr) _ b, - MkArray (getOrder arr) (strides arr) _ c) + unzipWith3 f arr with (viewShape arr) + _ | Shape s = case unzipWith3 f (getPrim arr) of + (a, b, c) => (MkArray (getOrder arr) (strides arr) s a, + MkArray (getOrder arr) (strides arr) s b, + MkArray (getOrder arr) (strides arr) s c) export Functor (Array s) where diff --git a/src/Data/NumIdr/Homogeneous/Matrix.idr b/src/Data/NumIdr/Homogeneous/Matrix.idr index 1557117..a7ad0f1 100644 --- a/src/Data/NumIdr/Homogeneous/Matrix.idr +++ b/src/Data/NumIdr/Homogeneous/Matrix.idr @@ -16,7 +16,8 @@ HMatrix m n = Matrix (S m) (S n) export fromMatrix : Num a => Matrix m n a -> HMatrix m n a -fromMatrix mat = withDims mat $ \m,n,mat => ?h2 -- rewrite plusCommutative 1 m in ?h -- (mat `vconcat` repeat [1,n] 0) `hconcat` repeat [m,1] 0 +fromMatrix mat with (viewShape mat) + _ | Shape [m,n] = ?h2 export toMatrix : HMatrix m n a -> Matrix m n a diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index 699b6f1..c1559da 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -17,10 +17,6 @@ Matrix' : Nat -> Type -> Type Matrix' n = Matrix n n -export -withDims : {0 m',n' : Nat} -> {0 b : Nat -> Nat -> Type} -> Matrix m' n' a -> ((m,n : Nat) -> Matrix m n a -> b m n) -> b m' n' -withDims mat f = rewrite shapeEq mat in f (head $ shape mat) (index 1 $ shape mat) (rewrite sym (shapeEq mat) in mat) - -------------------------------------------------------------------------------- -- Matrix constructors -------------------------------------------------------------------------------- diff --git a/src/Data/NumIdr/Vector.idr b/src/Data/NumIdr/Vector.idr index 8e06854..1e40068 100644 --- a/src/Data/NumIdr/Vector.idr +++ b/src/Data/NumIdr/Vector.idr @@ -19,9 +19,6 @@ export dimEq : (v : Vector n a) -> n = dim v dimEq v = cong head $ shapeEq v -export -withDim : {0 n' : Nat} -> Vector n' a -> ((n : Nat) -> Vector n a -> b) -> b -withDim v f = f (dim v) (rewrite sym (dimEq v) in v) -------------------------------------------------------------------------------- -- Vector constructors