Refactor Data.NumIdr.Array.Coords
This commit is contained in:
parent
8f1eef25dc
commit
1d6b9d3be9
|
@ -1,117 +1,143 @@
|
||||||
module Data.NumIdr.Array.Coords
|
module Data.NumIdr.Array.Coords
|
||||||
|
|
||||||
import Data.Either
|
import Data.Either
|
||||||
|
import Data.List
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.NP
|
import Data.NP
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
namespace Coords
|
--------------------------------------------------------------------------------
|
||||||
|
-- Array coordinate types
|
||||||
||| A type-safe coordinate system for an array. The coordinates are
|
--------------------------------------------------------------------------------
|
||||||
||| values of `Fin dim`, where `dim` is the dimension of each axis.
|
|
||||||
public export
|
|
||||||
Coords : (s : Vect rk Nat) -> Type
|
|
||||||
Coords = NP Fin
|
|
||||||
|
|
||||||
||| Forget the shape of the array by converting each index to type `Nat`.
|
|
||||||
export
|
|
||||||
toNB : Coords {rk} s -> Vect rk Nat
|
|
||||||
toNB [] = []
|
|
||||||
toNB (i :: is) = finToNat i :: toNB is
|
|
||||||
|
|
||||||
|
|
||||||
public export
|
||| A type-safe coordinate system for an array. The coordinates are
|
||||||
Vects : Vect rk Nat -> Type -> Type
|
||| values of `Fin dim`, where `dim` is the dimension of each axis.
|
||||||
Vects [] a = a
|
public export
|
||||||
Vects (d::s) a = Vect d (Vects s a)
|
Coords : (s : Vect rk Nat) -> Type
|
||||||
|
Coords = NP Fin
|
||||||
|
|
||||||
export
|
||| Forget the shape of the array by converting each index to type `Nat`.
|
||||||
collapse : {s : _} -> Vects s a -> List a
|
export
|
||||||
collapse {s=[]} = pure
|
toNB : Coords {rk} s -> Vect rk Nat
|
||||||
collapse {s=_::_} = concat . map collapse
|
toNB [] = []
|
||||||
|
toNB (i :: is) = finToNat i :: toNB is
|
||||||
|
|
||||||
|
|
||||||
export
|
public export
|
||||||
mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b
|
Vects : Vect rk Nat -> Type -> Type
|
||||||
mapWithIndex {s=[]} f x = f [] x
|
Vects [] a = a
|
||||||
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
|
Vects (d::s) a = Vect d (Vects s a)
|
||||||
|
|
||||||
|
export
|
||||||
|
collapse : {s : _} -> Vects s a -> List a
|
||||||
|
collapse {s=[]} = pure
|
||||||
|
collapse {s=_::_} = concat . map collapse
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b
|
||||||
|
mapWithIndex {s=[]} f x = f [] x
|
||||||
|
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
|
||||||
where
|
where
|
||||||
mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b
|
mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b
|
||||||
mapWithIndex' f [] = []
|
mapWithIndex' f [] = []
|
||||||
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
|
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
|
||||||
getLocation' = sum .: zipWith (*)
|
getLocation' = sum .: zipWith (*)
|
||||||
|
|
||||||
||| Compute the memory location of an array element
|
||| Compute the memory location of an array element
|
||||||
||| given its coordinate and the strides of the array.
|
||| given its coordinate and the strides of the array.
|
||||||
export
|
export
|
||||||
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
|
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
|
||||||
getLocation sts is = getLocation' sts (toNB is)
|
getLocation sts is = getLocation' sts (toNB is)
|
||||||
|
|
||||||
|
|
||||||
namespace CoordsRange
|
--------------------------------------------------------------------------------
|
||||||
|
-- Array coordinate types
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
public export
|
|
||||||
data CRange : Nat -> Type where
|
-- A Nat-based range function with better semantics
|
||||||
|
public export
|
||||||
|
range : Nat -> Nat -> List Nat
|
||||||
|
range x y = if x < y then assert_total $ takeBefore (>= y) (countFrom x S)
|
||||||
|
else []
|
||||||
|
|
||||||
|
-- helpful theorems for working with ranges
|
||||||
|
|
||||||
|
export
|
||||||
|
rangeLen : (x,y : Nat) -> length (range x y) = minus y x
|
||||||
|
rangeLen x y = believe_me $ Refl {x = minus y x}
|
||||||
|
|
||||||
|
export
|
||||||
|
rangeLenZ : (x : Nat) -> length (range 0 x) = x
|
||||||
|
rangeLenZ x = rangeLen 0 x `trans` minusZeroRight x
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
data CRange : Nat -> Type where
|
||||||
One : Fin n -> CRange n
|
One : Fin n -> CRange n
|
||||||
All : CRange n
|
All : CRange n
|
||||||
StartBound : Fin (S n) -> CRange n
|
StartBound : Fin (S n) -> CRange n
|
||||||
EndBound : Fin (S n) -> CRange n
|
EndBound : Fin (S n) -> CRange n
|
||||||
Bounds : Fin (S n) -> Fin (S n) -> CRange n
|
Bounds : Fin (S n) -> Fin (S n) -> CRange n
|
||||||
|
Indices : List (Fin n) -> CRange n
|
||||||
|
Filter : (Fin n -> Bool) -> CRange n
|
||||||
|
|
||||||
infix 0 ...
|
infix 0 ...
|
||||||
|
public export
|
||||||
public export
|
(...) : Fin (S n) -> Fin (S n) -> CRange n
|
||||||
(...) : Fin (S n) -> Fin (S n) -> CRange n
|
(...) = Bounds
|
||||||
(...) = Bounds
|
|
||||||
|
|
||||||
|
|
||||||
public export
|
public export
|
||||||
CoordsRange : (s : Vect rk Nat) -> Type
|
CoordsRange : (s : Vect rk Nat) -> Type
|
||||||
CoordsRange = NP CRange
|
CoordsRange = NP CRange
|
||||||
|
|
||||||
public export
|
public export
|
||||||
cRangeToBounds : {n : Nat} -> CRange n -> Either Nat (Nat, Nat)
|
cRangeToList : {n : Nat} -> CRange n -> Either Nat (List Nat)
|
||||||
cRangeToBounds (One x) = Left (cast x)
|
cRangeToList (One x) = Left (cast x)
|
||||||
cRangeToBounds All = Right (0, n)
|
cRangeToList All = Right $ range 0 n
|
||||||
cRangeToBounds (StartBound x) = Right (cast x, n)
|
cRangeToList (StartBound x) = Right $ range (cast x) n
|
||||||
cRangeToBounds (EndBound x) = Right (0, cast x)
|
cRangeToList (EndBound x) = Right $ range 0 (cast x)
|
||||||
cRangeToBounds (Bounds x y) = Right (cast x, cast y)
|
cRangeToList (Bounds x y) = Right $ range (cast x) (cast y)
|
||||||
|
cRangeToList (Indices xs) = Right $ map cast xs
|
||||||
|
cRangeToList (Filter p) = Right $ map cast $ filter p $ toList Fin.range
|
||||||
|
|
||||||
|
|
||||||
public export
|
public export
|
||||||
newRank : {s : _} -> CoordsRange s -> Nat
|
newRank : {s : _} -> CoordsRange s -> Nat
|
||||||
newRank [] = 0
|
newRank [] = 0
|
||||||
newRank (r :: rs) = case cRangeToBounds r of
|
newRank (r :: rs) = case cRangeToList r of
|
||||||
Left _ => newRank rs
|
Left _ => newRank rs
|
||||||
Right _ => S (newRank rs)
|
Right _ => S (newRank rs)
|
||||||
|
|
||||||
||| Calculate the new shape given by a coordinate range.
|
||| Calculate the new shape given by a coordinate range.
|
||||||
public export
|
public export
|
||||||
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
|
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
|
||||||
newShape [] = []
|
newShape [] = []
|
||||||
newShape (r :: rs) with (cRangeToBounds r)
|
newShape (r :: rs) with (cRangeToList r)
|
||||||
newShape (r :: rs) | Left _ = newShape rs
|
newShape (r :: rs) | Left _ = newShape rs
|
||||||
newShape (r :: rs) | Right (x,y) = minus y x :: newShape rs
|
newShape (r :: rs) | Right xs = length xs :: newShape rs
|
||||||
|
|
||||||
|
|
||||||
getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat
|
getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat
|
||||||
getNewPos [] [] = []
|
getNewPos [] [] = []
|
||||||
getNewPos (r :: rs) (i :: is) with (cRangeToBounds r)
|
getNewPos (r :: rs) (i :: is) with (cRangeToList r)
|
||||||
_ | Left _ = getNewPos rs is
|
_ | Left _ = getNewPos rs is
|
||||||
_ | Right (x, _) = minus i x :: getNewPos rs is
|
_ | Right xs = cast (assert_total $ case findIndex (==i) xs of Just x => x)
|
||||||
|
:: getNewPos rs is
|
||||||
|
|
||||||
export
|
export
|
||||||
getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat)
|
getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat)
|
||||||
getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs
|
getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs
|
||||||
where
|
where
|
||||||
go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
|
go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
|
||||||
go [] = pure []
|
go [] = pure []
|
||||||
go (r :: rs) = [| (case cRangeToBounds r of
|
go (r :: rs) = [| (either pure id (cRangeToList r)) :: go rs |]
|
||||||
Left x => pure x
|
|
||||||
Right (x,y) => [x,S x..pred y]) :: go rs |]
|
|
||||||
|
|
Loading…
Reference in a new issue