Create coordinate type
This commit is contained in:
parent
a95f38202c
commit
b0e4253b88
|
@ -3,27 +3,23 @@ module Data.NumIdr.Array.Array
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.NumIdr.PrimArray
|
import Data.NumIdr.PrimArray
|
||||||
import Data.NumIdr.Array.Order
|
import Data.NumIdr.Array.Order
|
||||||
|
import Data.NumIdr.Array.Coords
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
data Array : Vect rk Nat -> Type -> Type where
|
data Array : Vect rk Nat -> Type -> Type where
|
||||||
MkArray : (ord : Order rk) -> (sts : Vect rk Nat) ->
|
MkArray : (sts : Vect rk Nat) -> (s : Vect rk Nat) -> PrimArray a -> Array s a
|
||||||
(s : Vect rk Nat) -> PrimArray a -> Array s a
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
getPrim : Array s a -> PrimArray a
|
getPrim : Array s a -> PrimArray a
|
||||||
getPrim (MkArray _ _ _ arr) = arr
|
getPrim (MkArray _ _ arr) = arr
|
||||||
|
|
||||||
export
|
|
||||||
getOrder : Array {rk} s a -> Order rk
|
|
||||||
getOrder (MkArray ord _ _ _) = ord
|
|
||||||
|
|
||||||
export
|
export
|
||||||
getStrides : Array {rk} s a -> Vect rk Nat
|
getStrides : Array {rk} s a -> Vect rk Nat
|
||||||
getStrides (MkArray _ sts _ _) = sts
|
getStrides (MkArray sts _ _) = sts
|
||||||
|
|
||||||
export
|
export
|
||||||
size : Array s a -> Nat
|
size : Array s a -> Nat
|
||||||
|
@ -31,29 +27,56 @@ size = length . getPrim
|
||||||
|
|
||||||
export
|
export
|
||||||
shape : Array {rk} s a -> Vect rk Nat
|
shape : Array {rk} s a -> Vect rk Nat
|
||||||
shape (MkArray _ _ s _) = s
|
shape (MkArray _ s _) = s
|
||||||
|
|
||||||
export
|
export
|
||||||
rank : Array s a -> Nat
|
rank : Array s a -> Nat
|
||||||
rank arr = length $ shape arr
|
rank = length . shape
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
fromVect' : (s : Vect rk Nat) -> Order rk -> Vect (product s) a -> Array s a
|
fromVect' : (s : Vect rk Nat) -> Order rk -> Vect (product s) a -> Array s a
|
||||||
fromVect' s ord v = MkArray (MkReify s) (calcStrides ord s) (fromList $ toList v)
|
fromVect' s ord v = MkArray (calcStrides ord s) s (fromList $ toList v)
|
||||||
|
|
||||||
export
|
export
|
||||||
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
||||||
fromVect s = fromVect' s $ rewrite sym (lengthCorrect s) in COrder {rk = length s}
|
fromVect s = fromVect' s (orderOfShape s COrder)
|
||||||
|
|
||||||
public export
|
|
||||||
Vects : Vect rk Nat -> Type -> Type
|
|
||||||
Vects [] a = a
|
|
||||||
Vects (d::s) a = Vect d (Vects s a)
|
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
reshape' : (s' : Vect rk' Nat) -> Order rk -> Array {rk} s a ->
|
array' : (s : Vect rk Nat) -> Order rk -> Vects s a -> Array s a
|
||||||
product s = product s' => Array rk' s' a
|
array' s ord v = MkArray sts s (unsafeFromIns (product s) ins)
|
||||||
reshape' s' ord arr = MkArray
|
where
|
||||||
|
sts : Vect rk Nat
|
||||||
|
sts = calcStrides ord s
|
||||||
|
|
||||||
|
ins : List (Nat, a)
|
||||||
|
ins = collapse $ mapWithIndex (\i,x => (sum $ zipWith (*) i sts, x)) v
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
reshape' : (s' : Vect rk' Nat) -> Order rk' -> Array {rk} s a ->
|
||||||
|
product s = product s' => Array s' a
|
||||||
|
reshape' s' ord' arr = MkArray (calcStrides ord' s') s' (getPrim arr)
|
||||||
|
|
||||||
|
export
|
||||||
|
reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
||||||
|
product s = product s' => Array s' a
|
||||||
|
reshape s' = reshape' s' (orderOfShape s' COrder)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
index : Coords s -> Array s a -> a
|
||||||
|
index is arr = unsafeIndex (computeLoc (getStrides arr) is) (getPrim arr)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
test : Array [2,2,3] Int
|
||||||
|
test = array' _ FOrder [[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]]
|
||||||
|
|
||||||
|
export
|
||||||
|
main : IO ()
|
||||||
|
main = do
|
||||||
|
printLn $ index [0,1,0] test
|
||||||
|
|
49
src/Data/NumIdr/Array/Coords.idr
Normal file
49
src/Data/NumIdr/Array/Coords.idr
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
module Data.NumIdr.Array.Coords
|
||||||
|
|
||||||
|
import Data.Vect
|
||||||
|
|
||||||
|
%default total
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
data Coords : (s : Vect rk Nat) -> Type where
|
||||||
|
Nil : Coords Nil
|
||||||
|
(::) : Fin dim -> Coords s -> Coords (dim :: s)
|
||||||
|
|
||||||
|
export
|
||||||
|
toNats : Coords {rk} s -> Vect rk Nat
|
||||||
|
toNats [] = []
|
||||||
|
toNats (i :: is) = finToNat i :: toNats is
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
Vects : Vect rk Nat -> Type -> Type
|
||||||
|
Vects [] a = a
|
||||||
|
Vects (d::s) a = Vect d (Vects s a)
|
||||||
|
|
||||||
|
export
|
||||||
|
collapse : {s : _} -> Vects s a -> List a
|
||||||
|
collapse {s=[]} = (::[])
|
||||||
|
collapse {s=_::_} = concat . map collapse
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b
|
||||||
|
mapWithIndex {s=[]} f x = f [] x
|
||||||
|
mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v
|
||||||
|
where
|
||||||
|
mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b
|
||||||
|
mapWithIndex' f [] = []
|
||||||
|
mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
index : Coords s -> Vects s a -> a
|
||||||
|
index [] x = x
|
||||||
|
index (i::is) v = index is $ index i v
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
computeLoc : Vect rk Nat -> Coords {rk} s -> Nat
|
||||||
|
computeLoc sts is = sum $ zipWith (*) sts (toNats is)
|
|
@ -6,17 +6,16 @@ import Data.Permutation
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
export
|
public export
|
||||||
Order : (rk : Nat) -> Type
|
data Order : (rk : Nat) -> Type where
|
||||||
Order = Permutation
|
COrder : Order rk
|
||||||
|
FOrder : Order rk
|
||||||
|
|
||||||
|
|
||||||
export
|
export
|
||||||
COrder : {rk : Nat} -> Order rk
|
orderOfShape : (0 s : Vect rk Nat) -> Order (length s) -> Order rk
|
||||||
COrder = identity
|
orderOfShape s ord = rewrite sym (lengthCorrect s) in ord
|
||||||
|
|
||||||
export
|
|
||||||
FOrder : {rk : Nat} -> Order rk
|
|
||||||
FOrder = reversed
|
|
||||||
|
|
||||||
|
|
||||||
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
||||||
|
@ -27,4 +26,5 @@ scanr f q0 (x::xs) = f x (head qs) :: qs
|
||||||
|
|
||||||
export
|
export
|
||||||
calcStrides : Order rk -> Vect rk Nat -> Vect rk Nat
|
calcStrides : Order rk -> Vect rk Nat -> Vect rk Nat
|
||||||
calcStrides ord v = permuteVect ord $ tail $ scanr (*) 1 v
|
calcStrides COrder v = tail $ scanr (*) 1 v
|
||||||
|
calcStrides FOrder v = init $ scanl (*) 1 v
|
||||||
|
|
Loading…
Reference in a new issue