numidr/src/Data/NumIdr/PrimArray.idr
2022-06-25 00:58:36 -04:00

166 lines
4.8 KiB
Idris

module Data.NumIdr.PrimArray
import Data.Nat
import Data.IORef
import Data.IOArray.Prims
%default total
||| A fixed-size, boxed array built on Idris's primitive `ArrayData` type.
||| The number of elements is stored alongside the raw storage.
export
record PrimArray a where
  constructor MkPrimArray
  ||| Number of elements the array was allocated with.
  arraySize : Nat
  ||| The underlying primitive storage.
  content : ArrayData a
||| The number of elements in a primitive array.
export
length : PrimArray a -> Nat
length (MkPrimArray size _) = size
-- Private helpers lifting the `ArrayData` primitives into `IO`. They take
-- `Nat` sizes/indices and `cast` them to the `Int` the primitives expect.

-- Allocate storage with `n` slots, every slot initialised to `x`.
newArrayData : Nat -> a -> IO (ArrayData a)
newArrayData n x = fromPrim (prim__newArray (cast n) x)
-- Read slot `i` of the storage; no bounds check is performed here.
arrayDataGet : Nat -> ArrayData a -> IO a
arrayDataGet i raw = fromPrim (prim__arrayGet raw (cast i))
-- Overwrite slot `i` of the storage with `x`; no bounds check is performed.
arrayDataSet : Nat -> a -> ArrayData a -> IO ()
arrayDataSet i x raw = fromPrim (prim__arraySet raw (cast i) x)
||| Construct an array of the given size in which every element is `x`.
export
constant : Nat -> a -> PrimArray a
constant size x = unsafePerformIO $ MkPrimArray size <$> newArrayData size x
||| Construct an array of the given size from a list of `(index, value)`
||| write "instructions", applied in list order (so a later write to the same
||| index overwrites an earlier one).
|||
||| Unsafe: indices are not bounds-checked, and any index that no instruction
||| writes to is left holding an uninitialised placeholder value.
export
unsafeFromIns : Nat -> List (Nat, a) -> PrimArray a
unsafeFromIns size ins = unsafePerformIO $ do
    raw <- newArrayData size (believe_me ())
    traverse_ (\(i, x) => arrayDataSet i x raw) ins
    pure (MkPrimArray size raw)
||| Create an array of the given size whose element at index `i` is `f i`.
export
create : Nat -> (Nat -> a) -> PrimArray a
create size f = unsafePerformIO $ do
    raw <- newArrayData size (believe_me ())
    fill 0 size raw
    pure (MkPrimArray size raw)
  where
    -- Writes `f i` at index `i`, then moves on to `i + 1`. The second argument
    -- counts the slots still to fill, which keeps the recursion structurally
    -- total.
    fill : Nat -> Nat -> ArrayData a -> IO ()
    fill _ Z _ = pure ()
    fill i (S remaining) storage = do
      arrayDataSet i (f i) storage
      fill (S i) remaining storage
||| Read the element at index `n`.
|||
||| Unsafe: no bounds check is performed, so `n` must be less than the
||| array's length.
export
index : Nat -> PrimArray a -> a
index n (MkPrimArray _ raw) = unsafePerformIO (arrayDataGet n raw)
||| Bounds-checked variant of `index`. Returns `Nothing` when `n` is not a
||| valid index into the array.
export
safeIndex : Nat -> PrimArray a -> Maybe a
safeIndex n arr =
  if n >= length arr
     then Nothing
     else Just (index n arr)
||| Allocate a fresh array holding the same elements as the input.
export
copy : PrimArray a -> PrimArray a
copy arr = create (length arr) (flip index arr)
||| Return a new array in which the element at index `n` has been replaced by
||| `f` applied to it. If `n` is out of range the input array is returned
||| unchanged.
export
updateAt : Nat -> (a -> a) -> PrimArray a -> PrimArray a
updateAt n f arr =
  if n >= length arr
     then arr
     else create (length arr) newElem
  where
    -- Every slot is copied from `arr`, except slot `n`, which gets `f` applied.
    newElem : Nat -> a
    newElem i = if i == n then f (index i arr) else index i arr
||| Convert a primitive array to a list, preserving element order.
export
toList : PrimArray a -> List a
toList arr = collect (length arr) []
  where
    -- Walks indices from the end towards 0, consing each element onto the
    -- accumulator so the finished list is in ascending index order.
    collect : Nat -> List a -> List a
    collect Z acc = acc
    collect (S i) acc = collect i (index i arr :: acc)
||| Construct a primitive array from a list, preserving element order.
|||
||| The list is walked once, so this runs in time linear in its length.
export
fromList : List a -> PrimArray a
fromList xs = unsafeFromIns (length xs) (numbered 0 xs)
  where
    -- Pairs each element with its position, starting from `i`. This replaces
    -- a per-index `getAt` lookup, which was O(n) per element (O(n^2) overall)
    -- and needed a partial `fromJust`.
    numbered : Nat -> List a -> List (Nat, a)
    numbered _ [] = []
    numbered i (y :: ys) = (i, y) :: numbered (S i) ys
||| Apply a function to every element, producing a new array of the same
||| length.
export
map : (a -> b) -> PrimArray a -> PrimArray b
map f arr = create (length arr) (f . flip index arr)
||| Combine two arrays element-wise with `f`.
|||
||| Unsafe: the result has the length of the first array, and the second
||| array is indexed without bounds checks, so it must be at least as long.
export
unsafeZipWith : (a -> b -> c) -> PrimArray a -> PrimArray b -> PrimArray c
unsafeZipWith f a b = create (length a) combine
  where
    combine : Nat -> c
    combine i = f (index i a) (index i b)
||| Combine three arrays element-wise with `f`.
|||
||| Unsafe: the result has the length of the first array, and the other two
||| arrays are indexed without bounds checks, so they must be at least as long.
export
unsafeZipWith3 : (a -> b -> c -> d) ->
                 PrimArray a -> PrimArray b -> PrimArray c -> PrimArray d
unsafeZipWith3 f a b c = create (length a) combine
  where
    combine : Nat -> d
    combine i = f (index i a) (index i b) (index i c)
||| Split an array into two by applying `f` to each element and separating
||| the components of the resulting pairs.
|||
||| `f` is evaluated once per element.
export
unzipWith : (a -> (b, c)) -> PrimArray a -> (PrimArray b, PrimArray c)
unzipWith f arr = unzipPairs (map f arr)
  where
    -- Projects both halves out of an array of already-computed pairs, instead
    -- of recomputing `f` for each half.
    unzipPairs : PrimArray (b, c) -> (PrimArray b, PrimArray c)
    unzipPairs ps = (map fst ps, map snd ps)
||| Split an array into three by applying `f` to each element and separating
||| the components of the resulting triples.
|||
||| `f` is evaluated once per element.
export
unzipWith3 : (a -> (b, c, d)) -> PrimArray a -> (PrimArray b, PrimArray c, PrimArray d)
unzipWith3 f arr = unzipTriples (map f arr)
  where
    -- Projects the three components out of an array of already-computed
    -- triples, instead of recomputing `f` once per component.
    unzipTriples : PrimArray (b, c, d) -> (PrimArray b, PrimArray c, PrimArray d)
    unzipTriples ts = (map (\(x, _, _) => x) ts,
                       map (\(_, y, _) => y) ts,
                       map (\(_, _, z) => z) ts)
||| Left-associative fold over the elements, starting at index 0:
||| `foldl f z [x0, x1, ...] = f (f z x0) x1 ...`.
||| An empty array returns `z` without reading any storage.
export
foldl : (b -> a -> b) -> b -> PrimArray a -> b
foldl f z (MkPrimArray size arr) = go 0 size z
  where
    -- `i` is the next index to read; `remaining` counts the elements still to
    -- visit, so a size-0 array reads nothing. (A `[0 .. pred size]` range is
    -- `[0]` when size is 0, which would read out of bounds.)
    go : Nat -> Nat -> b -> b
    go _ Z acc = acc
    go i (S remaining) acc =
      go (S i) remaining (f acc (unsafePerformIO $ arrayDataGet i arr))
||| Right-associative fold over the elements, combining from the last index
||| down: `foldr f z [x0, ..., xn] = f x0 (... (f xn z))`.
||| An empty array returns `z` without reading any storage.
export
foldr : (a -> b -> b) -> b -> PrimArray a -> b
foldr f z (MkPrimArray size arr) = go size z
  where
    -- The counter is the number of elements not yet folded; the next one to
    -- read is at index `counter - 1`. A size-0 array reads nothing. (A
    -- `[pred size .. 0]` range is `[0]` when size is 0, which would read out
    -- of bounds.)
    go : Nat -> b -> b
    go Z acc = acc
    go (S i) acc = go i (f (unsafePerformIO $ arrayDataGet i arr) acc)
||| Apply an effectful function to every element in index order and collect
||| the results into a new array.
export
traverse : Applicative f => (a -> f b) -> PrimArray a -> f (PrimArray b)
traverse f arr = fromList <$> traverse f (toList arr)
||| Compare two primitive arrays element-wise for equality.
|||
||| This function assumes both arrays have the same length; it must not be
||| used in any other case (only the first array's length is consulted).
||| The comparison stops at the first mismatching element.
export
unsafeEq : Eq a => PrimArray a -> PrimArray a -> Bool
unsafeEq a b = allEqual (length a)
  where
    -- Checks every index below `k`, from the highest down. Counting down from
    -- the length means two empty arrays compare equal without reading any
    -- storage. (A `[0 .. pred size]` range is `[0]` when size is 0, which
    -- would read out of bounds.)
    allEqual : Nat -> Bool
    allEqual Z = True
    allEqual (S k) = index k a == index k b && allEqual k