Refactor Data.NumIdr.Array.Coords

This commit is contained in:
Kiana Sheibani 2022-08-27 19:11:13 -04:00
parent 8f1eef25dc
commit 1d6b9d3be9
Signed by: toki
GPG key ID: 6CB106C25E86A9F7

View file

@ -1,117 +1,143 @@
module Data.NumIdr.Array.Coords module Data.NumIdr.Array.Coords
import Data.Either import Data.Either
import Data.List
import Data.Vect import Data.Vect
import Data.NP import Data.NP
%default total %default total
namespace Coords --------------------------------------------------------------------------------
-- Array coordinate types
||| A type-safe coordinate system for an array. The coordinates are --------------------------------------------------------------------------------
||| values of `Fin dim`, where `dim` is the dimension of each axis.
public export
Coords : (s : Vect rk Nat) -> Type
Coords = NP Fin
||| Forget the shape of the array by converting each index to type `Nat`.
export
toNB : Coords {rk} s -> Vect rk Nat
toNB [] = []
toNB (i :: is) = finToNat i :: toNB is
public export ||| A type-safe coordinate system for an array. The coordinates are
Vects : Vect rk Nat -> Type -> Type ||| values of `Fin dim`, where `dim` is the dimension of each axis.
Vects [] a = a public export
Vects (d::s) a = Vect d (Vects s a) Coords : (s : Vect rk Nat) -> Type
Coords = NP Fin
export ||| Forget the shape of the array by converting each index to type `Nat`.
collapse : {s : _} -> Vects s a -> List a export
collapse {s=[]} = pure toNB : Coords {rk} s -> Vect rk Nat
collapse {s=_::_} = concat . map collapse toNB [] = []
toNB (i :: is) = finToNat i :: toNB is
export public export
mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b Vects : Vect rk Nat -> Type -> Type
mapWithIndex {s=[]} f x = f [] x Vects [] a = a
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v Vects (d::s) a = Vect d (Vects s a)
where
mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b export
mapWithIndex' f [] = [] collapse : {s : _} -> Vects s a -> List a
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs collapse {s=[]} = pure
collapse {s=_::_} = concat . map collapse
export export
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b
getLocation' = sum .: zipWith (*) mapWithIndex {s=[]} f x = f [] x
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
||| Compute the memory location of an array element where
||| given its coordinate and the strides of the array. mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b
export mapWithIndex' f [] = []
getLocation : Vect rk Nat -> Coords {rk} s -> Nat mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
getLocation sts is = getLocation' sts (toNB is)
namespace CoordsRange export
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
getLocation' = sum .: zipWith (*)
public export ||| Compute the memory location of an array element
data CRange : Nat -> Type where ||| given its coordinate and the strides of the array.
One : Fin n -> CRange n export
All : CRange n getLocation : Vect rk Nat -> Coords {rk} s -> Nat
StartBound : Fin (S n) -> CRange n getLocation sts is = getLocation' sts (toNB is)
EndBound : Fin (S n) -> CRange n
Bounds : Fin (S n) -> Fin (S n) -> CRange n
infix 0 ...
public export
(...) : Fin (S n) -> Fin (S n) -> CRange n
(...) = Bounds
public export --------------------------------------------------------------------------------
CoordsRange : (s : Vect rk Nat) -> Type -- Array coordinate types
CoordsRange = NP CRange --------------------------------------------------------------------------------
public export
cRangeToBounds : {n : Nat} -> CRange n -> Either Nat (Nat, Nat)
cRangeToBounds (One x) = Left (cast x)
cRangeToBounds All = Right (0, n)
cRangeToBounds (StartBound x) = Right (cast x, n)
cRangeToBounds (EndBound x) = Right (0, cast x)
cRangeToBounds (Bounds x y) = Right (cast x, cast y)
public export -- A Nat-based range function with better semantics
newRank : {s : _} -> CoordsRange s -> Nat public export
newRank [] = 0 range : Nat -> Nat -> List Nat
newRank (r :: rs) = case cRangeToBounds r of range x y = if x < y then assert_total $ takeBefore (>= y) (countFrom x S)
Left _ => newRank rs else []
Right _ => S (newRank rs)
||| Calculate the new shape given by a coordinate range. -- helpful theorems for working with ranges
public export
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat export
newShape [] = [] rangeLen : (x,y : Nat) -> length (range x y) = minus y x
newShape (r :: rs) with (cRangeToBounds r) rangeLen x y = believe_me $ Refl {x = minus y x}
newShape (r :: rs) | Left _ = newShape rs
newShape (r :: rs) | Right (x,y) = minus y x :: newShape rs export
rangeLenZ : (x : Nat) -> length (range 0 x) = x
rangeLenZ x = rangeLen 0 x `trans` minusZeroRight x
getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat public export
getNewPos [] [] = [] data CRange : Nat -> Type where
getNewPos (r :: rs) (i :: is) with (cRangeToBounds r) One : Fin n -> CRange n
_ | Left _ = getNewPos rs is All : CRange n
_ | Right (x, _) = minus i x :: getNewPos rs is StartBound : Fin (S n) -> CRange n
EndBound : Fin (S n) -> CRange n
Bounds : Fin (S n) -> Fin (S n) -> CRange n
Indices : List (Fin n) -> CRange n
Filter : (Fin n -> Bool) -> CRange n
export infix 0 ...
getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat) public export
getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs (...) : Fin (S n) -> Fin (S n) -> CRange n
where (...) = Bounds
go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
go [] = pure []
go (r :: rs) = [| (case cRangeToBounds r of public export
Left x => pure x CoordsRange : (s : Vect rk Nat) -> Type
Right (x,y) => [x,S x..pred y]) :: go rs |] CoordsRange = NP CRange
public export
cRangeToList : {n : Nat} -> CRange n -> Either Nat (List Nat)
cRangeToList (One x) = Left (cast x)
cRangeToList All = Right $ range 0 n
cRangeToList (StartBound x) = Right $ range (cast x) n
cRangeToList (EndBound x) = Right $ range 0 (cast x)
cRangeToList (Bounds x y) = Right $ range (cast x) (cast y)
cRangeToList (Indices xs) = Right $ map cast xs
cRangeToList (Filter p) = Right $ map cast $ filter p $ toList Fin.range
public export
newRank : {s : _} -> CoordsRange s -> Nat
newRank [] = 0
newRank (r :: rs) = case cRangeToList r of
Left _ => newRank rs
Right _ => S (newRank rs)
||| Calculate the new shape given by a coordinate range.
public export
newShape : {s : _} -> (rs : CoordsRange s) -> Vect (newRank rs) Nat
newShape [] = []
newShape (r :: rs) with (cRangeToList r)
newShape (r :: rs) | Left _ = newShape rs
newShape (r :: rs) | Right xs = length xs :: newShape rs
getNewPos : {s : _} -> (rs : CoordsRange {rk} s) -> Vect rk Nat -> Vect (newRank rs) Nat
getNewPos [] [] = []
getNewPos (r :: rs) (i :: is) with (cRangeToList r)
_ | Left _ = getNewPos rs is
_ | Right xs = cast (assert_total $ case findIndex (==i) xs of Just x => x)
:: getNewPos rs is
export
getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat, Vect (newRank rs) Nat)
getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs
where
go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
go [] = pure []
go (r :: rs) = [| (either pure id (cRangeToList r)) :: go rs |]