From b0e4253b888ddd76fc39b433b47b9d41d19d8602 Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Fri, 13 May 2022 08:28:42 -0400 Subject: [PATCH] Create coordinate type --- src/Data/NumIdr/Array/Array.idr | 61 ++++++++++++++++++++++---------- src/Data/NumIdr/Array/Coords.idr | 49 +++++++++++++++++++++++++ src/Data/NumIdr/Array/Order.idr | 18 +++++----- 3 files changed, 100 insertions(+), 28 deletions(-) create mode 100644 src/Data/NumIdr/Array/Coords.idr diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 51ddb81..90ae18b 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -3,27 +3,23 @@ module Data.NumIdr.Array.Array import Data.Vect import Data.NumIdr.PrimArray import Data.NumIdr.Array.Order +import Data.NumIdr.Array.Coords %default total export data Array : Vect rk Nat -> Type -> Type where - MkArray : (ord : Order rk) -> (sts : Vect rk Nat) -> - (s : Vect rk Nat) -> PrimArray a -> Array s a + MkArray : (sts : Vect rk Nat) -> (s : Vect rk Nat) -> PrimArray a -> Array s a export getPrim : Array s a -> PrimArray a -getPrim (MkArray _ _ _ arr) = arr - -export -getOrder : Array {rk} s a -> Order rk -getOrder (MkArray ord _ _ _) = ord +getPrim (MkArray _ _ arr) = arr export getStrides : Array {rk} s a -> Vect rk Nat -getStrides (MkArray _ sts _ _) = sts +getStrides (MkArray sts _ _) = sts export size : Array s a -> Nat @@ -31,29 +27,56 @@ size = length . getPrim export shape : Array {rk} s a -> Vect rk Nat -shape (MkArray _ _ s _) = s +shape (MkArray _ s _) = s export rank : Array s a -> Nat -rank arr = length $ shape arr +rank = length . shape export fromVect' : (s : Vect rk Nat) -> Order rk -> Vect (product s) a -> Array s a -fromVect' s ord v = MkArray (MkReify s) (calcStrides ord s) (fromList $ toList v) +fromVect' s ord v = MkArray (calcStrides ord s) s (fromList $ toList v) export fromVect : (s : Vect rk Nat) -> Vect (product s) a -> Array s a -fromVect s = fromVect' s $ rewrite sym (lengthCorrect s) in COrder {rk = length s} +fromVect s = fromVect' s (orderOfShape s COrder) -public export -Vects : Vect rk Nat -> Type -> Type -Vects [] a = a -Vects (d::s) a = Vect d (Vects s a) export -reshape' : (s' : Vect rk' Nat) -> Order rk -> Array {rk} s a -> - product s = product s' => Array rk' s' a -reshape' s' ord arr = MkArray +array' : (s : Vect rk Nat) -> Order rk -> Vects s a -> Array s a +array' s ord v = MkArray sts s (unsafeFromIns (product s) ins) + where + sts : Vect rk Nat + sts = calcStrides ord s + + ins : List (Nat, a) + ins = collapse $ mapWithIndex (\i,x => (sum $ zipWith (*) i sts, x)) v + + +export +reshape' : (s' : Vect rk' Nat) -> Order rk' -> Array {rk} s a -> + product s = product s' => Array s' a +reshape' s' ord' arr = MkArray (calcStrides ord' s') s' (getPrim arr) + +export +reshape : (s' : Vect rk' Nat) -> Array {rk} s a -> + product s = product s' => Array s' a +reshape s' = reshape' s' (orderOfShape s' COrder) + + +export +index : Coords s -> Array s a -> a +index is arr = unsafeIndex (computeLoc (getStrides arr) is) (getPrim arr) + + +export +test : Array [2,2,3] Int +test = array' _ FOrder [[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]] + +export +main : IO () +main = do + printLn $ index [0,1,0] test diff --git a/src/Data/NumIdr/Array/Coords.idr b/src/Data/NumIdr/Array/Coords.idr new file mode 100644 index 0000000..8fcacb6 --- /dev/null +++ b/src/Data/NumIdr/Array/Coords.idr @@ -0,0 +1,49 @@ +module Data.NumIdr.Array.Coords + +import Data.Vect + +%default total + + +public export +data Coords : (s : Vect rk Nat) -> Type where + Nil : Coords Nil + (::) : Fin dim -> Coords s -> Coords (dim :: s) + +export +toNats : Coords {rk} s -> Vect rk Nat +toNats [] = [] +toNats (i :: is) = finToNat i :: toNats is + + + +public export +Vects : Vect rk Nat -> Type -> Type +Vects [] a = a +Vects (d::s) a = Vect d (Vects s a) + +export +collapse : {s : _} -> Vects s a -> List a +collapse {s=[]} = (::[]) +collapse {s=_::_} = concat . map collapse + + +export +mapWithIndex : {s : Vect rk Nat} -> (Vect rk Nat -> a -> b) -> Vects {rk} s a -> Vects s b +mapWithIndex {s=[]} f x = f [] x +mapWithIndex {s=_::_} f v = mapWithIndex' (\i => mapWithIndex (\is => f (i::is))) v + where + mapWithIndex' : {0 a,b : Type} -> (Nat -> a -> b) -> Vect n a -> Vect n b + mapWithIndex' f [] = [] + mapWithIndex' f (x::xs) = f Z x :: mapWithIndex' (f . S) xs + + +export +index : Coords s -> Vects s a -> a +index [] x = x +index (i::is) v = index is $ index i v + + +export +computeLoc : Vect rk Nat -> Coords {rk} s -> Nat +computeLoc sts is = sum $ zipWith (*) sts (toNats is) diff --git a/src/Data/NumIdr/Array/Order.idr b/src/Data/NumIdr/Array/Order.idr index 022ffdc..5032851 100644 --- a/src/Data/NumIdr/Array/Order.idr +++ b/src/Data/NumIdr/Array/Order.idr @@ -6,17 +6,16 @@ import Data.Permutation %default total -export -Order : (rk : Nat) -> Type -Order = Permutation +public export +data Order : (rk : Nat) -> Type where + COrder : Order rk + FOrder : Order rk + export -COrder : {rk : Nat} -> Order rk -COrder = identity +orderOfShape : (0 s : Vect rk Nat) -> Order (length s) -> Order rk +orderOfShape s ord = rewrite sym (lengthCorrect s) in ord -export -FOrder : {rk : Nat} -> Order rk -FOrder = reversed scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res @@ -27,4 +26,5 @@ scanr f q0 (x::xs) = f x (head qs) :: qs export calcStrides : Order rk -> Vect rk Nat -> Vect rk Nat -calcStrides ord v = permuteVect ord $ tail $ scanr (*) 1 v +calcStrides COrder v = tail $ scanr (*) 1 v +calcStrides FOrder v = init $ scanl (*) 1 v