diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index abb363c..6e5eaf6 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -46,6 +46,11 @@ data Array : (s : Vect rk Nat) -> (a : Type) -> Type where %name Array arr +export +unsafeMkArray : Order -> Vect rk Nat -> (s : Vect rk Nat) -> PrimArray a -> Array s a +unsafeMkArray = MkArray + + -------------------------------------------------------------------------------- -- Properties of arrays -------------------------------------------------------------------------------- diff --git a/src/Data/NumIdr/LArray.idr b/src/Data/NumIdr/LArray.idr new file mode 100644 index 0000000..1360e34 --- /dev/null +++ b/src/Data/NumIdr/LArray.idr @@ -0,0 +1,149 @@ +module Data.NumIdr.LArray + +import Data.Vect +import Data.NumIdr.PrimArray +import Data.NumIdr.Array + +%default total + + +export +data LArray : Vect rk Nat -> Type -> Type where + MkLArray : (ord : Order) -> (sts : Vect rk Nat) -> + (s : Vect rk Nat) -> PrimArray a -> LArray s a + +thaw : Array s a -> LArray s a +thaw {s} arr with (viewShape arr) + _ | Shape s = MkLArray (getOrder arr) (strides arr) s (getPrim arr) + +export +freeze : (1 _ : LArray s a) -> Array s a +freeze (MkLArray ord sts s arr) = unsafeMkArray ord sts s arr + + +-------------------------------------------------------------------------------- +-- Constructor +-------------------------------------------------------------------------------- + + +export +repeatL : (s : Vect rk Nat) -> a -> ((1 _ : LArray s a) -> LArray s' b) -> Array s' b +repeatL s x f = freeze $ f (thaw $ repeat s x) + + +export +zerosL : Num a => (s : Vect rk Nat) -> ((1 _ : LArray s a) -> LArray s' b) -> Array s' b +zerosL s f = freeze $ f (thaw $ zeros s) + +export +onesL : Num a => (s : Vect rk Nat) -> ((1 _ : LArray s a) -> LArray s' b) -> Array s' b +onesL s f = freeze $ f (thaw $ ones s) + + +export +fromVectL : (s : Vect rk Nat) -> Vect (product s) a -> ((1 _ : LArray s a) -> LArray s' b) -> Array s' b +fromVectL s v f = freeze $ f (thaw $ fromVect s v) + +export +fromStreamL : (s : Vect rk Nat) -> Stream a -> ((1 _ : LArray s a) -> LArray s' b) -> Array s' b +fromStreamL s xs f = freeze $ f (thaw $ fromStream s xs) + +export +copyToL : Array s a -> ((1 _ : LArray s a) -> LArray s' b) -> Array s' b +copyToL {s} arr f with (viewShape arr) + _ | Shape s = + let ord = getOrder arr + sts = strides arr + prim = copy $ getPrim arr + in freeze $ f (thaw $ unsafeMkArray ord sts s prim) + + +-------------------------------------------------------------------------------- +-- Indexing +-------------------------------------------------------------------------------- + + +infixl 10 !! +infixl 10 !? +infixl 10 !# +infixl 11 !!.. +infixl 11 !?.. +infixl 11 !#.. + + +||| Index the array using the given coordinates. +export +index : Coords s -> (1 _ : LArray s a) -> Res a (const $ LArray s a) +index is (MkLArray ord sts s arr) = + index (getLocation sts is) arr + # MkLArray ord sts s arr + +||| Index the array using the given coordinates. +||| +||| This is the operator form of `index`. +export %inline +(!!) : (1 _ : LArray s a) -> Coords s -> Res a (const $ LArray s a) +arr !! is = index is arr + +||| Update the entry at the given coordinates using the function. +export +indexUpdate : Coords s -> (a -> a) -> (1 _ : LArray s a) -> LArray s a +indexUpdate is f (MkLArray ord sts s arr) = + MkLArray ord sts s (unsafeUpdateInPlace (getLocation sts is) f arr) + +||| Set the entry at the given coordinates to the given value. +export +indexSet : Coords s -> a -> (1 _ : LArray s a) -> LArray s a +indexSet is = indexUpdate is . const + + +||| Index the array using the given coordinates, returning `Nothing` if the +||| coordinates are out of bounds. +export +indexNB : Vect rk Nat -> (1 _ : LArray {rk} s a) -> Res (Maybe a) (const $ LArray s a) +indexNB is (MkLArray ord sts s arr) = + (if all id $ zipWith (<) is s + then Just $ index (getLocation' sts is) arr + else Nothing) + # MkLArray ord sts s arr + +||| Index the array using the given coordinates, returning `Nothing` if the +||| coordinates are out of bounds. +||| +||| This is the operator form of `indexNB`. +export %inline +(!?) : (1 _ : LArray {rk} s a) -> Vect rk Nat -> Res (Maybe a) (const $ LArray s a) +arr !? is = indexNB is arr + +||| Update the entry at the given coordinates using the function. `Nothing` is +||| returned if the coordinates are out of bounds. +export +indexUpdateNB : Vect rk Nat -> (a -> a) -> LArray {rk} s a -> Maybe (LArray s a) +indexUpdateNB is f (MkLArray ord sts s arr) = + if all id $ zipWith (<) is s + then Just $ MkLArray ord sts s (unsafeUpdateInPlace (getLocation' sts is) f arr) + else Nothing + +||| Set the entry at the given coordinates using the function. `Nothing` is +||| returned if the coordinates are out of bounds. +export +indexSetNB : Vect rk Nat -> a -> LArray {rk} s a -> Maybe (LArray s a) +indexSetNB is = indexUpdateNB is . const + +||| Index the array using the given coordinates. +||| WARNING: This function does not perform any bounds check on its inputs. +||| Misuse of this function can easily break memory safety. +export +indexUnsafe : Vect rk Nat -> (1 _ : LArray {rk} s a) -> Res a (const $ LArray s a) +indexUnsafe is (MkLArray ord sts s arr) = + index (getLocation' sts is) arr + # MkLArray ord sts s arr + +||| Index the array using the given coordinates. +||| WARNING: This function does not perform any bounds check on its inputs. +||| Misuse of this function can easily break memory safety. +||| +||| This is the operator form of `indexUnsafe`. +export %inline +(!#) : (1 _ : LArray {rk} s a) -> Vect rk Nat -> Res a (const $ LArray s a) +arr !# is = indexUnsafe is arr diff --git a/src/Data/NumIdr/PrimArray.idr b/src/Data/NumIdr/PrimArray.idr index fe4fa83..0d66593 100644 --- a/src/Data/NumIdr/PrimArray.idr +++ b/src/Data/NumIdr/PrimArray.idr @@ -89,6 +89,13 @@ updateAt n f arr = if n >= length arr then arr else arrayDataSet n (f x) cpy.content pure cpy +export +unsafeUpdateInPlace : Nat -> (a -> a) -> PrimArray a -> PrimArray a +unsafeUpdateInPlace n f arr = unsafePerformIO $ do + x <- arrayDataGet n arr.content + arrayDataSet n (f x) arr.content + pure arr + ||| Convert a primitive array to a list. export toList : PrimArray a -> List a