Reorganize Data.NumIdr.Array.Coords

This commit is contained in:
Kiana Sheibani 2022-08-29 13:55:02 -04:00
parent 1d6b9d3be9
commit 9ece7e6963
Signed by: toki
GPG key ID: 6CB106C25E86A9F7

View file

@ -8,6 +8,23 @@ import Data.NP
%default total
-- A Nat-based range function with better semantics
public export
range : Nat -> Nat -> List Nat
range x y = if x < y then assert_total $ takeBefore (>= y) (countFrom x S)
else []
-- helpful theorems for working with ranges
export
rangeLen : (x,y : Nat) -> length (range x y) = minus y x
rangeLen x y = believe_me $ Refl {x = minus y x}
export
rangeLenZ : (x : Nat) -> length (range 0 x) = x
rangeLenZ x = rangeLen 0 x `trans` minusZeroRight x
--------------------------------------------------------------------------------
-- Array coordinate types
--------------------------------------------------------------------------------
@ -26,60 +43,7 @@ toNB [] = []
toNB (i :: is) = finToNat i :: toNB is
public export
Vects : Vect rk Nat -> Type -> Type
Vects [] a = a
Vects (d::s) a = Vect d (Vects s a)
export
collapse : {s : _} -> Vects s a -> List a
collapse {s=[]} = pure
collapse {s=_::_} = concat . map collapse
export
mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b
mapWithIndex {s=[]} f x = f [] x
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
where
mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b
mapWithIndex' f [] = []
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
export
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
getLocation' = sum .: zipWith (*)
||| Compute the memory location of an array element
||| given its coordinate and the strides of the array.
export
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
getLocation sts is = getLocation' sts (toNB is)
--------------------------------------------------------------------------------
-- Array coordinate types
--------------------------------------------------------------------------------
-- A Nat-based range function with better semantics
public export
range : Nat -> Nat -> List Nat
range x y = if x < y then assert_total $ takeBefore (>= y) (countFrom x S)
else []
-- helpful theorems for working with ranges
export
rangeLen : (x,y : Nat) -> length (range x y) = minus y x
rangeLen x y = believe_me $ Refl {x = minus y x}
export
rangeLenZ : (x : Nat) -> length (range 0 x) = x
rangeLenZ x = rangeLen 0 x `trans` minusZeroRight x
namespace Strict
public export
data CRange : Nat -> Type where
One : Fin n -> CRange n
@ -100,6 +64,56 @@ public export
CoordsRange : (s : Vect rk Nat) -> Type
CoordsRange = NP CRange
namespace NB
public export
data CRangeNB : Type where
All : CRangeNB
StartBound : Nat -> CRangeNB
EndBound : Nat -> CRangeNB
Bounds : Nat -> Nat -> CRangeNB
Indices : List Nat -> CRangeNB
Filter : (Nat -> Bool) -> CRangeNB
--------------------------------------------------------------------------------
-- Indexing helper functions
--------------------------------------------------------------------------------
public export
Vects : Vect rk Nat -> Type -> Type
Vects [] a = a
Vects (d::s) a = Vect d (Vects s a)
export
collapse : {s : _} -> Vects s a -> List a
collapse {s=[]} = pure
collapse {s=_::_} = concat . map collapse
export
mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects s a -> Vects s b
mapWithIndex {s=[]} f x = f [] x
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
where
mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b
mapWithIndex' f [] = []
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
export
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
getLocation' = sum .: zipWith (*)
||| Compute the memory location of an array element
||| given its coordinate and the strides of the array.
export
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
getLocation sts is = getLocation' sts (toNB is)
namespace Strict
public export
cRangeToList : {n : Nat} -> CRange n -> Either Nat (List Nat)
cRangeToList (One x) = Left (cast x)
@ -139,5 +153,33 @@ getCoordsList : {s : Vect rk Nat} -> (rs : CoordsRange s) -> List (Vect rk Nat,
getCoordsList rs = map (\is => (is, getNewPos rs is)) $ go rs
where
go : {0 rk : _} -> {s : Vect rk Nat} -> CoordsRange s -> List (Vect rk Nat)
go [] = pure []
go [] = [[]]
go (r :: rs) = [| (either pure id (cRangeToList r)) :: go rs |]
namespace NB
export
cRangeNBToList : Nat -> CRangeNB -> List Nat
cRangeNBToList s All = range 0 s
cRangeNBToList s (StartBound x) = range x s
cRangeNBToList s (EndBound x) = range 0 x
cRangeNBToList s (Bounds x y) = range x (minimum y s)
cRangeNBToList s (Indices xs) = filter (<s) xs
cRangeNBToList s (Filter p) = filter p $ range 0 s
export
newShape : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat
newShape = zipWith (length .: cRangeNBToList)
export
getNewPos : Vect rk Nat -> Vect rk CRangeNB -> Vect rk Nat -> Vect rk Nat
getNewPos = zipWith3 (\d,r,i => assert_total $
case findIndex (==i) (cRangeNBToList d r) of Just x => cast x)
export
getCoordsList : Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat, Vect rk Nat)
getCoordsList s rs = map (\is => (is, getNewPos s rs is)) $ go s rs
where
go : {0 rk : _} -> Vect rk Nat -> Vect rk CRangeNB -> List (Vect rk Nat)
go [] [] = [[]]
go (d :: s) (r :: rs) = [| cRangeNBToList d r :: go s rs |]