Create PrimArray submodules for each rep

This commit is contained in:
Kiana Sheibani 2023-05-05 13:42:28 -04:00
parent eeb2967f6f
commit bd235754d2
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
5 changed files with 516 additions and 1 deletions

View file

@ -12,6 +12,11 @@ readme = "README.md"
modules = Data.NP,
Data.Permutation,
Data.NumIdr,
Data.NumIdr.PrimArray,
Data.NumIdr.PrimArray.Bytes,
Data.NumIdr.PrimArray.Boxed,
Data.NumIdr.PrimArray.Linked,
Data.NumIdr.PrimArray.Delayed,
Data.NumIdr.Array,
Data.NumIdr.Array.Array,
Data.NumIdr.Array.Coords,
@ -19,7 +24,6 @@ modules = Data.NP,
Data.NumIdr.Homogeneous,
Data.NumIdr.Matrix,
Data.NumIdr.Interfaces,
Data.NumIdr.PrimArray,
Data.NumIdr.Scalar,
Data.NumIdr.Vector,
Data.NumIdr.LArray,

View file

@ -0,0 +1,210 @@
module Data.NumIdr.PrimArray.Boxed
import Data.Buffer
import System
import Data.IOArray.Prims
import Data.Vect
import Data.NumIdr.Array.Rep
import Data.NumIdr.Array.Coords
import Data.NumIdr.PrimArray.Bytes
%default total
public export
data PrimArrayBoxed : Order -> Vect rk Nat -> Type -> Type where
MkPABoxed : (sts : Vect rk Nat) -> ArrayData a -> PrimArrayBoxed {rk} o s a
-- Private helper functions for ArrayData primitives
export
newArrayData : Nat -> a -> IO (ArrayData a)
newArrayData n x = fromPrim $ prim__newArray (cast n) x
export
arrayDataGet : Nat -> ArrayData a -> IO a
arrayDataGet n arr = fromPrim $ prim__arrayGet arr (cast n)
export
arrayDataSet : Nat -> a -> ArrayData a -> IO ()
arrayDataSet n x arr = fromPrim $ prim__arraySet arr (cast n) x
fromArray : {o : _} -> (s : Vect rk Nat) -> ArrayData a -> PrimArrayBoxed o s a
fromArray s = MkPABoxed (calcStrides o s)
arrayAction : {o : _} -> (s : Vect rk Nat) -> (ArrayData a -> IO ()) -> PrimArrayBoxed o s a
arrayAction s action = fromArray s $ unsafePerformIO $ do
arr <- newArrayData (product s) (believe_me ())
action arr
pure arr
||| Construct an array with a constant value.
export
constant : {o : _} -> (s : Vect rk Nat) -> a -> PrimArrayBoxed o s a
constant s x = fromArray s $ unsafePerformIO $ newArrayData (product s) x
export
unsafeDoIns : List (Nat, a) -> PrimArrayBoxed o s a -> IO ()
unsafeDoIns ins (MkPABoxed _ arr) = for_ ins $ \(i,x) => arrayDataSet i x arr
||| Construct an array from a list of "instructions" to write a value to a
||| particular index.
export
unsafeFromIns : {o : _} -> (s : Vect rk Nat) -> List (Nat, a) -> PrimArrayBoxed o s a
unsafeFromIns s ins = arrayAction s $ \arr =>
for_ ins $ \(i,x) => arrayDataSet i x arr
||| Create an array given its size and a function to generate its elements by
||| its index.
export
create : {o : _} -> (s : Vect rk Nat) -> (Nat -> a) -> PrimArrayBoxed o s a
create s f = arrayAction s $ \arr =>
for_ [0..pred (product s)] $ \i => arrayDataSet i (f i) arr
||| Index into a primitive array. This function is unsafe, as it performs no
||| boundary check on the index given.
export
index : Nat -> PrimArrayBoxed o s a -> a
index n (MkPABoxed _ arr) = unsafePerformIO $ arrayDataGet n arr
export
copy : {o,s : _} -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a
copy arr = create s (\n => index n arr)
export
reorder : {o,o',s : _} -> PrimArrayBoxed o s a -> PrimArrayBoxed o' s a
reorder arr@(MkPABoxed sts ar) =
if o == o' then MkPABoxed sts ar
else let sts' = calcStrides o' s
in unsafeFromIns s
((\is => (getLocation' sts' is, index (getLocation' sts is) arr))
<$> getAllCoords' s)
export
bytesToBoxed : {s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBoxed o s a
bytesToBoxed (MkPABytes sts buf) = MkPABoxed sts $ unsafePerformIO $ do
arr <- newArrayData (product s) (believe_me ())
for_ [0..pred (product s)] $ \i => arrayDataSet i !(readBuffer i buf) arr
pure arr
export
boxedToBytes : {s : _} -> ByteRep a => PrimArrayBoxed o s a -> PrimArrayBytes o s
boxedToBytes @{br} (MkPABoxed sts arr) = MkPABytes sts $ unsafePerformIO $ do
Just buf <- newBuffer $ cast (product s * bytes @{br})
| Nothing => die "Cannot create array"
for_ [0..pred (product s)] $ \i => writeBuffer i !(arrayDataGet i arr) buf
pure buf
export
fromList : {o : _} -> (s : Vect rk Nat) -> List a -> PrimArrayBoxed o s a
fromList s xs = arrayAction s $ go xs 0
where
go : List a -> Nat -> ArrayData a -> IO ()
go [] _ buf = pure ()
go (x :: xs) i buf = arrayDataSet i x buf >> go xs (S i) buf
{-
export
updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a
updateAt n f arr = if n >= length arr then arr else
unsafePerformIO $ do
let cpy = copy arr
x <- arrayDataGet n cpy.content
arrayDataSet n (f x) cpy.content
pure cpy
export
unsafeUpdateInPlace : Nat -> (a -> a) -> PrimArrayBoxed a -> PrimArrayBoxed a
unsafeUpdateInPlace n f arr = unsafePerformIO $ do
x <- arrayDataGet n arr.content
arrayDataSet n (f x) arr.content
pure arr
||| Convert a primitive array to a list.
export
toList : PrimArrayBoxed a -> List a
toList arr = iter (length arr) []
where
iter : Nat -> List a -> List a
iter Z acc = acc
iter (S n) acc = let el = index n arr
in iter n (el :: acc)
||| Construct a primitive array from a list.
export
fromList : List a -> PrimArrayBoxed a
fromList xs = create (length xs)
(\n => assert_total $ fromJust $ getAt n xs)
where
partial
fromJust : Maybe a -> a
fromJust (Just x) = x
||| Map a function over a primitive array.
export
map : (a -> b) -> PrimArrayBoxed a -> PrimArrayBoxed b
map f arr = create (length arr) (\n => f $ index n arr)
export
unsafeZipWith : (a -> b -> c) -> PrimArrayBoxed a -> PrimArrayBoxed b -> PrimArrayBoxed c
unsafeZipWith f a b = create (length a) (\n => f (index n a) (index n b))
export
unsafeZipWith3 : (a -> b -> c -> d) ->
PrimArrayBoxed a -> PrimArrayBoxed b -> PrimArrayBoxed c -> PrimArrayBoxed d
unsafeZipWith3 f a b c = create (length a) (\n => f (index n a) (index n b) (index n c))
export
unzipWith : (a -> (b, c)) -> PrimArrayBoxed a -> (PrimArrayBoxed b, PrimArrayBoxed c)
unzipWith f arr = (map (fst . f) arr, map (snd . f) arr)
export
unzipWith3 : (a -> (b, c, d)) -> PrimArrayBoxed a -> (PrimArrayBoxed b, PrimArrayBoxed c, PrimArrayBoxed d)
unzipWith3 f arr = (map ((\(x,_,_) => x) . f) arr,
map ((\(_,y,_) => y) . f) arr,
map ((\(_,_,z) => z) . f) arr)
export
foldl : (b -> a -> b) -> b -> PrimArrayBoxed a -> b
foldl f z (MkPABoxed size arr) =
if size == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [0..pred size] $ \n => do
x <- readIORef ref
y <- arrayDataGet n arr
writeIORef ref (f x y)
readIORef ref
export
foldr : (a -> b -> b) -> b -> PrimArrayBoxed a -> b
foldr f z (MkPABoxed size arr) =
if size == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [pred size..0] $ \n => do
x <- arrayDataGet n arr
y <- readIORef ref
writeIORef ref (f x y)
readIORef ref
export
traverse : Applicative f => (a -> f b) -> PrimArrayBoxed a -> f (PrimArrayBoxed b)
traverse f = map fromList . traverse f . toList
||| Compares two primitive arrays for equal elements. This function assumes the
||| arrays have the same length; it must not be used in any other case.
export
unsafeEq : Eq a => PrimArrayBoxed a -> PrimArrayBoxed a -> Bool
unsafeEq a b = unsafePerformIO $
map (concat @{All}) $ for [0..pred (arraySize a)] $
\n => (==) <$> arrayDataGet n (content a) <*> arrayDataGet n (content b)
-}

View file

@ -0,0 +1,244 @@
module Data.NumIdr.PrimArray.Bytes
import System
import Data.IORef
import Data.Bits
import Data.So
import Data.Buffer
import Data.Vect
import Data.NumIdr.Array.Coords
import Data.NumIdr.Array.Rep
%default total
public export
interface ByteRep a where
bytes : Nat
toBytes : a -> Vect bytes Bits8
fromBytes : Vect bytes Bits8 -> a
export
ByteRep Bits8 where
bytes = 1
toBytes x = [x]
fromBytes [x] = x
export
ByteRep Bits16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Bits32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Bits64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Int8 where
bytes = 1
toBytes x = [cast x]
fromBytes [x] = cast x
export
ByteRep Int16 where
bytes = 2
toBytes x = [cast (x `shiftR` 8), cast x]
fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2
export
ByteRep Int32 where
bytes = 4
toBytes x = [cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4] =
cast b1 `shiftL` 24 .|.
cast b2 `shiftL` 16 .|.
cast b3 `shiftL` 8 .|.
cast b4
export
ByteRep Int64 where
bytes = 8
toBytes x = [cast (x `shiftR` 56),
cast (x `shiftR` 48),
cast (x `shiftR` 40),
cast (x `shiftR` 32),
cast (x `shiftR` 24),
cast (x `shiftR` 16),
cast (x `shiftR` 8),
cast x]
fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] =
cast b1 `shiftL` 56 .|.
cast b2 `shiftL` 48 .|.
cast b3 `shiftL` 40 .|.
cast b4 `shiftL` 32 .|.
cast b5 `shiftL` 24 .|.
cast b6 `shiftL` 16 .|.
cast b7 `shiftL` 8 .|.
cast b8
export
ByteRep Bool where
bytes = 1
toBytes b = [if b then 1 else 0]
fromBytes [x] = x /= 0
export
ByteRep a => ByteRep b => ByteRep (a, b) where
bytes = bytes {a} + bytes {a=b}
toBytes (x,y) = toBytes x ++ toBytes y
fromBytes = bimap fromBytes fromBytes . splitAt _
export
{n : _} -> ByteRep a => ByteRep (Vect n a) where
bytes = n * bytes {a}
toBytes xs = concat $ map toBytes xs
fromBytes {n = 0} bs = []
fromBytes {n = S n} bs =
let (bs1, bs2) = splitAt _ bs
in fromBytes bs1 :: fromBytes bs2
export
readBuffer : ByteRep a => (i : Nat) -> Buffer -> IO a
readBuffer i buf = do
let offset : Int = cast (i * bytes {a})
map (fromBytes {a}) $ for range $ \n => do
getBits8 buf $ offset + cast (finToInteger n)
export
writeBuffer : ByteRep a => (i : Nat) -> a -> Buffer -> IO ()
writeBuffer i x buf = do
let offset = cast (i * bytes {a})
bs = toBytes x
n <- newIORef 0
for_ bs $ \b => setBits8 buf (offset + !(readIORef n)) b <* modifyIORef n (+1)
public export
data PrimArrayBytes : Order -> Vect rk Nat -> Type where
MkPABytes : (sts : Vect rk Nat) -> Buffer -> PrimArrayBytes {rk} o s
export
getBuffer : PrimArrayBytes o s -> Buffer
getBuffer (MkPABytes _ buf) = buf
fromBuffer : {o : _} -> (s : Vect rk Nat) -> Buffer -> PrimArrayBytes o s
fromBuffer s buf = MkPABytes (calcStrides o s) buf
bufferAction : {o : _} -> ByteRep a => (s : Vect rk Nat) -> (Buffer -> IO ()) -> PrimArrayBytes o s
bufferAction @{br} s action = fromBuffer s $ unsafePerformIO $ do
Just buf <- newBuffer $ cast (product s * bytes @{br})
| Nothing => die "Cannot create array"
action buf
pure buf
export
constant : {o : _} -> ByteRep a => (s : Vect rk Nat) -> a -> PrimArrayBytes o s
constant @{br} s x = bufferAction @{br} s $ \buf => do
let size = product s
for_ [0..pred size] $ \i => writeBuffer i x buf
export
unsafeDoIns : ByteRep a => List (Nat, a) -> PrimArrayBytes o s -> IO ()
unsafeDoIns ins (MkPABytes _ buf) = for_ ins $ \(i,x) => writeBuffer i x buf
export
unsafeFromIns : {o : _} -> ByteRep a => (s : Vect rk Nat) -> List (Nat, a) -> PrimArrayBytes o s
unsafeFromIns @{br} s ins = bufferAction @{br} s $ \buf =>
for_ ins $ \(i,x) => writeBuffer i x buf
export
create : {o : _} -> ByteRep a => (s : Vect rk Nat) -> (Nat -> a) -> PrimArrayBytes o s
create @{br} s f = bufferAction @{br} s $ \buf => do
let size = product s
for_ [0..pred size] $ \i => writeBuffer i (f i) buf
export
index : ByteRep a => Nat -> PrimArrayBytes o s -> a
index i (MkPABytes _ buf) = unsafePerformIO $ readBuffer i buf
export
reorder : {o,o',s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBytes o' s
reorder @{br} arr@(MkPABytes sts buf) =
if o == o' then MkPABytes sts buf
else let sts' = calcStrides o' s
in unsafeFromIns @{br} s
((\is => (getLocation' sts' is, index @{br} (getLocation' sts is) arr))
<$> getAllCoords' s)
export
fromList : {o : _} -> ByteRep a => (s : Vect rk Nat) -> List a -> PrimArrayBytes o s
fromList @{br} s xs = bufferAction @{br} s $ go xs 0
where
go : List a -> Nat -> Buffer -> IO ()
go [] _ buf = pure ()
go (x :: xs) i buf = writeBuffer i x buf >> go xs (S i) buf

View file

@ -0,0 +1,23 @@
module Data.NumIdr.PrimArray.Delayed
import Data.Vect
import Data.NP
import Data.NumIdr.Array.Rep
import Data.NumIdr.Array.Coords
%default total
public export
PrimArrayDelayed : Vect rk Nat -> Type -> Type
PrimArrayDelayed s a = Coords s -> a
export
constant : (s : Vect rk Nat) -> a -> PrimArrayDelayed s a
constant s x _ = x
export
checkRange : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s)
checkRange [] [] = Just []
checkRange (d :: s) (i :: is) = (::) <$> natToFin i d <*> checkRange s is

View file

@ -0,0 +1,34 @@
module Data.NumIdr.PrimArray.Linked
import Data.Vect
import Data.NP
import Data.NumIdr.Array.Rep
import Data.NumIdr.Array.Coords
%default total
export
constant : (s : Vect rk Nat) -> a -> Vects s a
constant [] x = x
constant (d :: s) xs = replicate d (constant s xs)
export
index : Coords s -> Vects s a -> a
index [] x = x
index (c :: cs) xs = index cs (index c xs)
export
fromFunction : {s : _} -> (Coords s -> a) -> Vects s a
fromFunction {s = []} f = f []
fromFunction {s = d :: s} f = tabulate (\i => fromFunction (f . (i::)))
export
fromFunctionNB : {s : _} -> (Vect rk Nat -> a) -> Vects {rk} s a
fromFunctionNB {s = []} f = f []
fromFunctionNB {s = d :: s} f = tabulate' {n=d} (\i => fromFunctionNB (f . (i::)))
where
tabulate' : forall a. {n : _} -> (Nat -> a) -> Vect n a
tabulate' {n=Z} f = []
tabulate' {n=S _} f = f Z :: tabulate' (f . S)