Create coordinate type
This commit is contained in:
parent
a95f38202c
commit
b0e4253b88
3 changed files with 100 additions and 28 deletions
|
|
@ -3,27 +3,23 @@ module Data.NumIdr.Array.Array
|
|||
import Data.Vect
|
||||
import Data.NumIdr.PrimArray
|
||||
import Data.NumIdr.Array.Order
|
||||
import Data.NumIdr.Array.Coords
|
||||
|
||||
%default total
|
||||
|
||||
|
||||
export
|
||||
data Array : Vect rk Nat -> Type -> Type where
|
||||
MkArray : (ord : Order rk) -> (sts : Vect rk Nat) ->
|
||||
(s : Vect rk Nat) -> PrimArray a -> Array s a
|
||||
MkArray : (sts : Vect rk Nat) -> (s : Vect rk Nat) -> PrimArray a -> Array s a
|
||||
|
||||
|
||||
export
|
||||
getPrim : Array s a -> PrimArray a
|
||||
getPrim (MkArray _ _ _ arr) = arr
|
||||
|
||||
export
|
||||
getOrder : Array {rk} s a -> Order rk
|
||||
getOrder (MkArray ord _ _ _) = ord
|
||||
getPrim (MkArray _ _ arr) = arr
|
||||
|
||||
export
|
||||
getStrides : Array {rk} s a -> Vect rk Nat
|
||||
getStrides (MkArray _ sts _ _) = sts
|
||||
getStrides (MkArray sts _ _) = sts
|
||||
|
||||
export
|
||||
size : Array s a -> Nat
|
||||
|
|
@ -31,29 +27,56 @@ size = length . getPrim
|
|||
|
||||
export
|
||||
shape : Array {rk} s a -> Vect rk Nat
|
||||
shape (MkArray _ _ s _) = s
|
||||
shape (MkArray _ s _) = s
|
||||
|
||||
export
|
||||
rank : Array s a -> Nat
|
||||
rank arr = length $ shape arr
|
||||
rank = length . shape
|
||||
|
||||
|
||||
|
||||
export
|
||||
fromVect' : (s : Vect rk Nat) -> Order rk -> Vect (product s) a -> Array s a
|
||||
fromVect' s ord v = MkArray (MkReify s) (calcStrides ord s) (fromList $ toList v)
|
||||
fromVect' s ord v = MkArray (calcStrides ord s) s (fromList $ toList v)
|
||||
|
||||
export
|
||||
fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a
|
||||
fromVect s = fromVect' s $ rewrite sym (lengthCorrect s) in COrder {rk = length s}
|
||||
fromVect s = fromVect' s (orderOfShape s COrder)
|
||||
|
||||
public export
|
||||
Vects : Vect rk Nat -> Type -> Type
|
||||
Vects [] a = a
|
||||
Vects (d::s) a = Vect d (Vects s a)
|
||||
|
||||
|
||||
export
|
||||
reshape' : (s' : Vect rk' Nat) -> Order rk -> Array {rk} s a ->
|
||||
product s = product s' => Array rk' s' a
|
||||
reshape' s' ord arr = MkArray
|
||||
array' : (s : Vect rk Nat) -> Order rk -> Vects s a -> Array s a
|
||||
array' s ord v = MkArray sts s (unsafeFromIns (product s) ins)
|
||||
where
|
||||
sts : Vect rk Nat
|
||||
sts = calcStrides ord s
|
||||
|
||||
ins : List (Nat, a)
|
||||
ins = collapse $ mapWithIndex (\i,x => (sum $ zipWith (*) i sts, x)) v
|
||||
|
||||
|
||||
export
|
||||
reshape' : (s' : Vect rk' Nat) -> Order rk' -> Array {rk} s a ->
|
||||
product s = product s' => Array s' a
|
||||
reshape' s' ord' arr = MkArray (calcStrides ord' s') s' (getPrim arr)
|
||||
|
||||
export
|
||||
reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
||||
product s = product s' => Array s' a
|
||||
reshape s' = reshape' s' (orderOfShape s' COrder)
|
||||
|
||||
|
||||
export
|
||||
index : Coords s -> Array s a -> a
|
||||
index is arr = unsafeIndex (computeLoc (getStrides arr) is) (getPrim arr)
|
||||
|
||||
|
||||
export
|
||||
test : Array [2,2,3] Int
|
||||
test = array' _ FOrder [[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]]
|
||||
|
||||
export
|
||||
main : IO ()
|
||||
main = do
|
||||
printLn $ index [0,1,0] test
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue