Remove rank-index on Order type
This commit is contained in:
parent
a499d14e87
commit
76e16574f1
|
@ -31,20 +31,27 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where
|
||||||
||| *strides*, which determine how indexes into the internal array should be
|
||| *strides*, which determine how indexes into the internal array should be
|
||||||
||| performed. This is how the order of the array is configurable.
|
||| performed. This is how the order of the array is configurable.
|
||||||
|||
|
|||
|
||||||
||| @ s The shape of the array
|
||| @ ord The order of the elements of the array
|
||||||
||| @ sts The strides of the array
|
||| @ sts The strides of the array
|
||||||
MkArray : (s : Vect rk Nat) -> (sts : Vect rk Nat) -> PrimArray a -> Array s a
|
||| @ s The shape of the array
|
||||||
|
MkArray : (ord : Order) -> (sts : Vect rk Nat) ->
|
||||||
|
(s : Vect rk Nat) -> PrimArray a -> Array s a
|
||||||
|
|
||||||
|
|
||||||
||| Extract the primitive array value.
|
||| Extract the primitive array value.
|
||||||
export
|
export
|
||||||
getPrim : Array s a -> PrimArray a
|
getPrim : Array s a -> PrimArray a
|
||||||
getPrim (MkArray _ _ arr) = arr
|
getPrim (MkArray _ _ _ arr) = arr
|
||||||
|
|
||||||
|
||| The order of the elements of the array
|
||||||
|
export
|
||||||
|
getOrder : Array s a -> Order
|
||||||
|
getOrder (MkArray ord _ _ _) = ord
|
||||||
|
|
||||||
||| The strides of the array, returned in the same axis order as in the shape.
|
||| The strides of the array, returned in the same axis order as in the shape.
|
||||||
export
|
export
|
||||||
getStrides : Array {rk} s a -> Vect rk Nat
|
getStrides : Array {rk} s a -> Vect rk Nat
|
||||||
getStrides (MkArray _ sts _) = sts
|
getStrides (MkArray _ sts _ _) = sts
|
||||||
|
|
||||||
||| The total number of elements of the array
|
||| The total number of elements of the array
|
||||||
||| This is equivalent to `product s`.
|
||| This is equivalent to `product s`.
|
||||||
|
@ -55,7 +62,7 @@ size = length . getPrim
|
||||||
||| The shape of the array
|
||| The shape of the array
|
||||||
export
|
export
|
||||||
shape : Array {rk} s a -> Vect rk Nat
|
shape : Array {rk} s a -> Vect rk Nat
|
||||||
shape (MkArray s _ _) = s
|
shape (MkArray _ _ s _) = s
|
||||||
|
|
||||||
||| The rank of the array
|
||| The rank of the array
|
||||||
export
|
export
|
||||||
|
@ -69,8 +76,8 @@ rank = length . shape
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ ord The order to interpret the elements
|
||| @ ord The order to interpret the elements
|
||||||
export
|
export
|
||||||
fromVect' : (s : Vect rk Nat) -> (ord : Order rk) -> Vect (product s) a -> Array s a
|
fromVect' : (s : Vect rk Nat) -> (ord : Order) -> Vect (product s) a -> Array s a
|
||||||
fromVect' s ord v = MkArray s (calcStrides ord s) (fromList $ toList v)
|
fromVect' s ord v = MkArray ord (calcStrides ord s) s (fromList $ toList v)
|
||||||
|
|
||||||
||| Create an array given a vector of its elements. The elements of the vector
|
||| Create an array given a vector of its elements. The elements of the vector
|
||||||
||| are assembled into the provided shape using row-major order (the last axis is the
|
||| are assembled into the provided shape using row-major order (the last axis is the
|
||||||
|
@ -88,19 +95,19 @@ fromVect s = fromVect' s COrder
|
||||||
||| @ s The shape of the constructed array
|
||| @ s The shape of the constructed array
|
||||||
||| @ ord The order of the constructed array
|
||| @ ord The order of the constructed array
|
||||||
export
|
export
|
||||||
array' : (s : Vect rk Nat) -> (ord : Order rk) -> Vects s a -> Array s a
|
array' : (s : Vect rk Nat) -> (ord : Order) -> Vects s a -> Array s a
|
||||||
array' s ord v = MkArray s sts (unsafeFromIns (product s) ins)
|
array' s ord v = MkArray ord sts s (unsafeFromIns (product s) ins)
|
||||||
where
|
where
|
||||||
sts : Vect rk Nat
|
sts : Vect rk Nat
|
||||||
sts = calcStrides ord s
|
sts = calcStrides ord s
|
||||||
|
|
||||||
ins : List (Nat, a)
|
ins : List (Nat, a)
|
||||||
ins = collapse $ mapWithIndex (\i,x => (sum $ zipWith (*) i sts, x)) v
|
ins = collapse $ mapWithIndex (\is,x => (getLocation' sts is, x)) v
|
||||||
|
|
||||||
||| Construct an array using a structure of nested vectors.
|
||| Construct an array using a structure of nested vectors.
|
||||||
export
|
export
|
||||||
array : {s : _} -> Vects s a -> Array s a
|
array : {s : _} -> Vects s a -> Array s a
|
||||||
array v = MkArray s (calcStrides COrder s) (fromList $ collapse v)
|
array v = MkArray COrder (calcStrides COrder s) s (fromList $ collapse v)
|
||||||
|
|
||||||
|
|
||||||
||| Reshape the array into the given shape and reinterpret it according to
|
||| Reshape the array into the given shape and reinterpret it according to
|
||||||
|
@ -109,20 +116,17 @@ array v = MkArray s (calcStrides COrder s) (fromList $ collapse v)
|
||||||
||| @ s' The shape to convert the array to
|
||| @ s' The shape to convert the array to
|
||||||
||| @ ord The order to reinterpret the array by
|
||| @ ord The order to reinterpret the array by
|
||||||
export
|
export
|
||||||
reshape' : (s' : Vect rk' Nat) -> (ord : Order rk') -> Array {rk} s a ->
|
reshape' : (s' : Vect rk' Nat) -> (ord : Order) -> Array {rk} s a ->
|
||||||
product s = product s' => Array s' a
|
product s = product s' => Array s' a
|
||||||
reshape' s' ord' arr = MkArray s' (calcStrides ord' s') (getPrim arr)
|
reshape' s' ord' arr = MkArray ord' (calcStrides ord' s') s' (getPrim arr)
|
||||||
|
|
||||||
||| Reshape the array into the given shape.
|
||| Reshape the array into the given shape.
|
||||||
|||
|
|||
|
||||||
||| The array is also reinterpreted in row-major order; if this is undesirable,
|
|
||||||
||| then `reshape'` must be used instead.
|
|
||||||
|||
|
|
||||||
||| @ s' The shape to convert the array to
|
||| @ s' The shape to convert the array to
|
||||||
export
|
export
|
||||||
reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
reshape : (s' : Vect rk' Nat) -> Array {rk} s a ->
|
||||||
product s = product s' => Array s' a
|
product s = product s' => Array s' a
|
||||||
reshape s' = reshape' s' COrder
|
reshape s' arr = reshape' s' (getOrder arr) arr
|
||||||
|
|
||||||
|
|
||||||
||| Index the array using the given `Coords` object.
|
||| Index the array using the given `Coords` object.
|
||||||
|
|
|
@ -46,8 +46,12 @@ index [] x = x
|
||||||
index (i::is) v = index is $ index i v
|
index (i::is) v = index is $ index i v
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
getLocation' : (sts : Vect rk Nat) -> (is : Vect rk Nat) -> Nat
|
||||||
|
getLocation' = sum .: zipWith (*)
|
||||||
|
|
||||||
||| Compute the memory location of an array element
|
||| Compute the memory location of an array element
|
||||||
||| given its coordinate and the strides of the array.
|
||| given its coordinate and the strides of the array.
|
||||||
export
|
export
|
||||||
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
|
getLocation : Vect rk Nat -> Coords {rk} s -> Nat
|
||||||
getLocation sts is = sum $ zipWith (*) sts (toNats is)
|
getLocation sts is = getLocation' sts (toNats is)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
module Data.NumIdr.Array.Order
|
module Data.NumIdr.Array.Order
|
||||||
|
|
||||||
import Data.Vect
|
import Data.Vect
|
||||||
import Data.Permutation
|
|
||||||
|
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
@ -10,24 +9,16 @@ import Data.Permutation
|
||||||
||| elements are stored in memory. Orders are used to calculate strides,
|
||| elements are stored in memory. Orders are used to calculate strides,
|
||||||
||| which provide a method of converting an array coordinate into a linear
|
||| which provide a method of converting an array coordinate into a linear
|
||||||
||| memory location.
|
||| memory location.
|
||||||
|||
|
|
||||||
||| @ rk The rank of the array this order applies to
|
|
||||||
public export
|
public export
|
||||||
data Order : (rk : Nat) -> Type where
|
data Order : Type where
|
||||||
|
|
||||||
||| C-like order, or contiguous order. This order stores elements in a
|
||| C-like order, or contiguous order. This order stores elements in a
|
||||||
||| row-major fashion (the last axis is the least significant).
|
||| row-major fashion (the last axis is the least significant).
|
||||||
COrder : Order rk
|
COrder : Order
|
||||||
|
|
||||||
||| Fortran-like order. This order stores elements in a column-major
|
||| Fortran-like order. This order stores elements in a column-major
|
||||||
||| fashion (the first axis is the least significant).
|
||| fashion (the first axis is the least significant).
|
||||||
FOrder : Order rk
|
FOrder : Order
|
||||||
|
|
||||||
|
|
||||||
export
|
|
||||||
orderOfShape : (0 s : Vect rk Nat) -> Order (length s) -> Order rk
|
|
||||||
orderOfShape s ord = rewrite sym (lengthCorrect s) in ord
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
scanr : (el -> res -> res) -> res -> Vect len el -> Vect (S len) res
|
||||||
|
@ -38,7 +29,7 @@ scanr f q0 (x::xs) = f x (head qs) :: qs
|
||||||
|
|
||||||
||| Calculate an array's strides given its order and shape.
|
||| Calculate an array's strides given its order and shape.
|
||||||
export
|
export
|
||||||
calcStrides : Order rk -> Vect rk Nat -> Vect rk Nat
|
calcStrides : Order -> Vect rk Nat -> Vect rk Nat
|
||||||
calcStrides _ [] = []
|
calcStrides _ [] = []
|
||||||
calcStrides COrder v@(_::_) = scanr (*) 1 $ tail v
|
calcStrides COrder v@(_::_) = scanr (*) 1 $ tail v
|
||||||
calcStrides FOrder v@(_::_) = scanl (*) 1 $ init v
|
calcStrides FOrder v@(_::_) = scanl (*) 1 $ init v
|
||||||
|
|
|
@ -1,32 +0,0 @@
|
||||||
module Data.Permutation
|
|
||||||
|
|
||||||
import Data.Vect
|
|
||||||
|
|
||||||
%default total
|
|
||||||
|
|
||||||
|
|
||||||
||| A permutation of `n` elements represented as a vector of indices.
|
|
||||||
||| For example, `[1,2,0]` is a permutation that maps element `0` to
|
|
||||||
||| element `1`, element `1` to element `2`, and element `2` to element `0`.
|
|
||||||
public export
|
|
||||||
Permutation : (n : Nat) -> Type
|
|
||||||
Permutation n = Vect n (Fin n)
|
|
||||||
|
|
||||||
|
|
||||||
||| The identity permutation.
|
|
||||||
public export
|
|
||||||
identity : {n : _} -> Permutation n
|
|
||||||
identity {n=Z} = []
|
|
||||||
identity {n=S n} = FZ :: map FS identity
|
|
||||||
|
|
||||||
||| The permutation that reverses the order of elements.
|
|
||||||
public export
|
|
||||||
reversed : {n : _} -> Permutation n
|
|
||||||
reversed {n=Z} = []
|
|
||||||
reversed {n=S n} = last :: map weaken reversed
|
|
||||||
|
|
||||||
|
|
||||||
||| Apply a permutation to a vector.
|
|
||||||
public export
|
|
||||||
permuteVect : Permutation n -> Vect n a -> Vect n a
|
|
||||||
permuteVect p v = map (\i => index i v) p
|
|
Loading…
Reference in a new issue