diff --git a/numidr.ipkg b/numidr.ipkg index 40eb48c..9ba64da 100644 --- a/numidr.ipkg +++ b/numidr.ipkg @@ -12,6 +12,11 @@ readme = "README.md" modules = Data.NP, Data.Permutation, Data.NumIdr, + Data.NumIdr.PrimArray, + Data.NumIdr.PrimArray.Bytes, + Data.NumIdr.PrimArray.Boxed, + Data.NumIdr.PrimArray.Linked, + Data.NumIdr.PrimArray.Delayed, Data.NumIdr.Array, Data.NumIdr.Array.Array, Data.NumIdr.Array.Coords, @@ -19,7 +24,6 @@ modules = Data.NP, Data.NumIdr.Homogeneous, Data.NumIdr.Matrix, Data.NumIdr.Interfaces, - Data.NumIdr.PrimArray, Data.NumIdr.Scalar, Data.NumIdr.Vector, Data.NumIdr.LArray, diff --git a/src/Data/NumIdr/PrimArray/Boxed.idr b/src/Data/NumIdr/PrimArray/Boxed.idr new file mode 100644 index 0000000..9b038bd --- /dev/null +++ b/src/Data/NumIdr/PrimArray/Boxed.idr @@ -0,0 +1,210 @@ +module Data.NumIdr.PrimArray.Boxed + +import Data.Buffer +import System +import Data.IOArray.Prims +import Data.Vect +import Data.NumIdr.Array.Rep +import Data.NumIdr.Array.Coords +import Data.NumIdr.PrimArray.Bytes + +%default total + + +public export +data PrimArrayBoxed : Order -> Vect rk Nat -> Type -> Type where + MkPABoxed : (sts : Vect rk Nat) -> ArrayData a -> PrimArrayBoxed {rk} o s a + + +-- Private helper functions for ArrayData primitives +export +newArrayData : Nat -> a -> IO (ArrayData a) +newArrayData n x = fromPrim $ prim__newArray (cast n) x + +export +arrayDataGet : Nat -> ArrayData a -> IO a +arrayDataGet n arr = fromPrim $ prim__arrayGet arr (cast n) + +export +arrayDataSet : Nat -> a -> ArrayData a -> IO () +arrayDataSet n x arr = fromPrim $ prim__arraySet arr (cast n) x + + +fromArray : {o : _} -> (s : Vect rk Nat) -> ArrayData a -> PrimArrayBoxed o s a +fromArray s = MkPABoxed (calcStrides o s) + +arrayAction : {o : _} -> (s : Vect rk Nat) -> (ArrayData a -> IO ()) -> PrimArrayBoxed o s a +arrayAction s action = fromArray s $ unsafePerformIO $ do + arr <- newArrayData (product s) (believe_me ()) + action arr + pure arr + +||| Construct an array with a constant value. +export +constant : {o : _} -> (s : Vect rk Nat) -> a -> PrimArrayBoxed o s a +constant s x = fromArray s $ unsafePerformIO $ newArrayData (product s) x + + +export +unsafeDoIns : List (Nat, a) -> PrimArrayBoxed o s a -> IO () +unsafeDoIns ins (MkPABoxed _ arr) = for_ ins $ \(i,x) => arrayDataSet i x arr + +||| Construct an array from a list of "instructions" to write a value to a +||| particular index. +export +unsafeFromIns : {o : _} -> (s : Vect rk Nat) -> List (Nat, a) -> PrimArrayBoxed o s a +unsafeFromIns s ins = arrayAction s $ \arr => + for_ ins $ \(i,x) => arrayDataSet i x arr + +||| Create an array given its size and a function to generate its elements by +||| its index. +export +create : {o : _} -> (s : Vect rk Nat) -> (Nat -> a) -> PrimArrayBoxed o s a +create s f = arrayAction s $ \arr => + for_ [0..pred (product s)] $ \i => arrayDataSet i (f i) arr + +||| Index into a primitive array. This function is unsafe, as it performs no +||| boundary check on the index given. +export +index : Nat -> PrimArrayBoxed o s a -> a +index n (MkPABoxed _ arr) = unsafePerformIO $ arrayDataGet n arr + +export +copy : {o,s : _} -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a +copy arr = create s (\n => index n arr) + + +export +reorder : {o,o',s : _} -> PrimArrayBoxed o s a -> PrimArrayBoxed o' s a +reorder arr@(MkPABoxed sts ar) = + if o == o' then MkPABoxed sts ar + else let sts' = calcStrides o' s + in unsafeFromIns s + ((\is => (getLocation' sts' is, index (getLocation' sts is) arr)) + <$> getAllCoords' s) + +export +bytesToBoxed : {s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBoxed o s a +bytesToBoxed (MkPABytes sts buf) = MkPABoxed sts $ unsafePerformIO $ do + arr <- newArrayData (product s) (believe_me ()) + for_ [0..pred (product s)] $ \i => arrayDataSet i !(readBuffer i buf) arr + pure arr + +export +boxedToBytes : {s : _} -> ByteRep a => PrimArrayBoxed o s a -> PrimArrayBytes o s +boxedToBytes @{br} (MkPABoxed sts arr) = MkPABytes sts $ unsafePerformIO $ do + Just buf <- newBuffer $ cast (product s * bytes @{br}) + | Nothing => die "Cannot create array" + for_ [0..pred (product s)] $ \i => writeBuffer i !(arrayDataGet i arr) buf + pure buf + + +export +fromList : {o : _} -> (s : Vect rk Nat) -> List a -> PrimArrayBoxed o s a +fromList s xs = arrayAction s $ go xs 0 + where + go : List a -> Nat -> ArrayData a -> IO () + go [] _ buf = pure () + go (x :: xs) i buf = arrayDataSet i x buf >> go xs (S i) buf + + +{- +export +updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a +updateAt n f arr = if n >= length arr then arr else + unsafePerformIO $ do + let cpy = copy arr + x <- arrayDataGet n cpy.content + arrayDataSet n (f x) cpy.content + pure cpy + +export +unsafeUpdateInPlace : Nat -> (a -> a) -> PrimArrayBoxed a -> PrimArrayBoxed a +unsafeUpdateInPlace n f arr = unsafePerformIO $ do + x <- arrayDataGet n arr.content + arrayDataSet n (f x) arr.content + pure arr + +||| Convert a primitive array to a list. +export +toList : PrimArrayBoxed a -> List a +toList arr = iter (length arr) [] + where + iter : Nat -> List a -> List a + iter Z acc = acc + iter (S n) acc = let el = index n arr + in iter n (el :: acc) + +||| Construct a primitive array from a list. +export +fromList : List a -> PrimArrayBoxed a +fromList xs = create (length xs) + (\n => assert_total $ fromJust $ getAt n xs) + where + partial + fromJust : Maybe a -> a + fromJust (Just x) = x + +||| Map a function over a primitive array. +export +map : (a -> b) -> PrimArrayBoxed a -> PrimArrayBoxed b +map f arr = create (length arr) (\n => f $ index n arr) + + +export +unsafeZipWith : (a -> b -> c) -> PrimArrayBoxed a -> PrimArrayBoxed b -> PrimArrayBoxed c +unsafeZipWith f a b = create (length a) (\n => f (index n a) (index n b)) + +export +unsafeZipWith3 : (a -> b -> c -> d) -> + PrimArrayBoxed a -> PrimArrayBoxed b -> PrimArrayBoxed c -> PrimArrayBoxed d +unsafeZipWith3 f a b c = create (length a) (\n => f (index n a) (index n b) (index n c)) + +export +unzipWith : (a -> (b, c)) -> PrimArrayBoxed a -> (PrimArrayBoxed b, PrimArrayBoxed c) +unzipWith f arr = (map (fst . f) arr, map (snd . f) arr) + +export +unzipWith3 : (a -> (b, c, d)) -> PrimArrayBoxed a -> (PrimArrayBoxed b, PrimArrayBoxed c, PrimArrayBoxed d) +unzipWith3 f arr = (map ((\(x,_,_) => x) . f) arr, + map ((\(_,y,_) => y) . f) arr, + map ((\(_,_,z) => z) . f) arr) + + +export +foldl : (b -> a -> b) -> b -> PrimArrayBoxed a -> b +foldl f z (MkPABoxed size arr) = + if size == 0 then z + else unsafePerformIO $ do + ref <- newIORef z + for_ [0..pred size] $ \n => do + x <- readIORef ref + y <- arrayDataGet n arr + writeIORef ref (f x y) + readIORef ref + +export +foldr : (a -> b -> b) -> b -> PrimArrayBoxed a -> b +foldr f z (MkPABoxed size arr) = + if size == 0 then z + else unsafePerformIO $ do + ref <- newIORef z + for_ [pred size..0] $ \n => do + x <- arrayDataGet n arr + y <- readIORef ref + writeIORef ref (f x y) + readIORef ref + +export +traverse : Applicative f => (a -> f b) -> PrimArrayBoxed a -> f (PrimArrayBoxed b) +traverse f = map fromList . traverse f . toList + + +||| Compares two primitive arrays for equal elements. This function assumes the +||| arrays have the same length; it must not be used in any other case. +export +unsafeEq : Eq a => PrimArrayBoxed a -> PrimArrayBoxed a -> Bool +unsafeEq a b = unsafePerformIO $ + map (concat @{All}) $ for [0..pred (arraySize a)] $ + \n => (==) <$> arrayDataGet n (content a) <*> arrayDataGet n (content b) +-} diff --git a/src/Data/NumIdr/PrimArray/Bytes.idr b/src/Data/NumIdr/PrimArray/Bytes.idr new file mode 100644 index 0000000..4c0ce72 --- /dev/null +++ b/src/Data/NumIdr/PrimArray/Bytes.idr @@ -0,0 +1,244 @@ +module Data.NumIdr.PrimArray.Bytes + +import System +import Data.IORef +import Data.Bits +import Data.So +import Data.Buffer +import Data.Vect +import Data.NumIdr.Array.Coords +import Data.NumIdr.Array.Rep + +%default total + + +public export +interface ByteRep a where + bytes : Nat + + toBytes : a -> Vect bytes Bits8 + fromBytes : Vect bytes Bits8 -> a + +export +ByteRep Bits8 where + bytes = 1 + + toBytes x = [x] + fromBytes [x] = x + +export +ByteRep Bits16 where + bytes = 2 + + toBytes x = [cast (x `shiftR` 8), cast x] + fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2 + +export +ByteRep Bits32 where + bytes = 4 + + toBytes x = [cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4] = + cast b1 `shiftL` 24 .|. + cast b2 `shiftL` 16 .|. + cast b3 `shiftL` 8 .|. + cast b4 + +export +ByteRep Bits64 where + bytes = 8 + + toBytes x = [cast (x `shiftR` 56), + cast (x `shiftR` 48), + cast (x `shiftR` 40), + cast (x `shiftR` 32), + cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = + cast b1 `shiftL` 56 .|. + cast b2 `shiftL` 48 .|. + cast b3 `shiftL` 40 .|. + cast b4 `shiftL` 32 .|. + cast b5 `shiftL` 24 .|. + cast b6 `shiftL` 16 .|. + cast b7 `shiftL` 8 .|. + cast b8 + +export +ByteRep Int where + bytes = 8 + + toBytes x = [cast (x `shiftR` 56), + cast (x `shiftR` 48), + cast (x `shiftR` 40), + cast (x `shiftR` 32), + cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = + cast b1 `shiftL` 56 .|. + cast b2 `shiftL` 48 .|. + cast b3 `shiftL` 40 .|. + cast b4 `shiftL` 32 .|. + cast b5 `shiftL` 24 .|. + cast b6 `shiftL` 16 .|. + cast b7 `shiftL` 8 .|. + cast b8 + +export +ByteRep Int8 where + bytes = 1 + + toBytes x = [cast x] + fromBytes [x] = cast x + +export +ByteRep Int16 where + bytes = 2 + + toBytes x = [cast (x `shiftR` 8), cast x] + fromBytes [b1,b2] = cast b1 `shiftL` 8 .|. cast b2 + +export +ByteRep Int32 where + bytes = 4 + + toBytes x = [cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4] = + cast b1 `shiftL` 24 .|. + cast b2 `shiftL` 16 .|. + cast b3 `shiftL` 8 .|. + cast b4 + +export +ByteRep Int64 where + bytes = 8 + + toBytes x = [cast (x `shiftR` 56), + cast (x `shiftR` 48), + cast (x `shiftR` 40), + cast (x `shiftR` 32), + cast (x `shiftR` 24), + cast (x `shiftR` 16), + cast (x `shiftR` 8), + cast x] + fromBytes [b1,b2,b3,b4,b5,b6,b7,b8] = + cast b1 `shiftL` 56 .|. + cast b2 `shiftL` 48 .|. + cast b3 `shiftL` 40 .|. + cast b4 `shiftL` 32 .|. + cast b5 `shiftL` 24 .|. + cast b6 `shiftL` 16 .|. + cast b7 `shiftL` 8 .|. + cast b8 + +export +ByteRep Bool where + bytes = 1 + + toBytes b = [if b then 1 else 0] + fromBytes [x] = x /= 0 + +export +ByteRep a => ByteRep b => ByteRep (a, b) where + bytes = bytes {a} + bytes {a=b} + + toBytes (x,y) = toBytes x ++ toBytes y + fromBytes = bimap fromBytes fromBytes . splitAt _ + +export +{n : _} -> ByteRep a => ByteRep (Vect n a) where + bytes = n * bytes {a} + + toBytes xs = concat $ map toBytes xs + fromBytes {n = 0} bs = [] + fromBytes {n = S n} bs = + let (bs1, bs2) = splitAt _ bs + in fromBytes bs1 :: fromBytes bs2 + + +export +readBuffer : ByteRep a => (i : Nat) -> Buffer -> IO a +readBuffer i buf = do + let offset : Int = cast (i * bytes {a}) + map (fromBytes {a}) $ for range $ \n => do + getBits8 buf $ offset + cast (finToInteger n) + +export +writeBuffer : ByteRep a => (i : Nat) -> a -> Buffer -> IO () +writeBuffer i x buf = do + let offset = cast (i * bytes {a}) + bs = toBytes x + n <- newIORef 0 + for_ bs $ \b => setBits8 buf (offset + !(readIORef n)) b <* modifyIORef n (+1) + + +public export +data PrimArrayBytes : Order -> Vect rk Nat -> Type where + MkPABytes : (sts : Vect rk Nat) -> Buffer -> PrimArrayBytes {rk} o s + +export +getBuffer : PrimArrayBytes o s -> Buffer +getBuffer (MkPABytes _ buf) = buf + + +fromBuffer : {o : _} -> (s : Vect rk Nat) -> Buffer -> PrimArrayBytes o s +fromBuffer s buf = MkPABytes (calcStrides o s) buf + +bufferAction : {o : _} -> ByteRep a => (s : Vect rk Nat) -> (Buffer -> IO ()) -> PrimArrayBytes o s +bufferAction @{br} s action = fromBuffer s $ unsafePerformIO $ do + Just buf <- newBuffer $ cast (product s * bytes @{br}) + | Nothing => die "Cannot create array" + action buf + pure buf + +export +constant : {o : _} -> ByteRep a => (s : Vect rk Nat) -> a -> PrimArrayBytes o s +constant @{br} s x = bufferAction @{br} s $ \buf => do + let size = product s + for_ [0..pred size] $ \i => writeBuffer i x buf + +export +unsafeDoIns : ByteRep a => List (Nat, a) -> PrimArrayBytes o s -> IO () +unsafeDoIns ins (MkPABytes _ buf) = for_ ins $ \(i,x) => writeBuffer i x buf + +export +unsafeFromIns : {o : _} -> ByteRep a => (s : Vect rk Nat) -> List (Nat, a) -> PrimArrayBytes o s +unsafeFromIns @{br} s ins = bufferAction @{br} s $ \buf => + for_ ins $ \(i,x) => writeBuffer i x buf + +export +create : {o : _} -> ByteRep a => (s : Vect rk Nat) -> (Nat -> a) -> PrimArrayBytes o s +create @{br} s f = bufferAction @{br} s $ \buf => do + let size = product s + for_ [0..pred size] $ \i => writeBuffer i (f i) buf + +export +index : ByteRep a => Nat -> PrimArrayBytes o s -> a +index i (MkPABytes _ buf) = unsafePerformIO $ readBuffer i buf + +export +reorder : {o,o',s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBytes o' s +reorder @{br} arr@(MkPABytes sts buf) = + if o == o' then MkPABytes sts buf + else let sts' = calcStrides o' s + in unsafeFromIns @{br} s + ((\is => (getLocation' sts' is, index @{br} (getLocation' sts is) arr)) + <$> getAllCoords' s) + +export +fromList : {o : _} -> ByteRep a => (s : Vect rk Nat) -> List a -> PrimArrayBytes o s +fromList @{br} s xs = bufferAction @{br} s $ go xs 0 + where + go : List a -> Nat -> Buffer -> IO () + go [] _ buf = pure () + go (x :: xs) i buf = writeBuffer i x buf >> go xs (S i) buf diff --git a/src/Data/NumIdr/PrimArray/Delayed.idr b/src/Data/NumIdr/PrimArray/Delayed.idr new file mode 100644 index 0000000..c7a5825 --- /dev/null +++ b/src/Data/NumIdr/PrimArray/Delayed.idr @@ -0,0 +1,23 @@ +module Data.NumIdr.PrimArray.Delayed + +import Data.Vect +import Data.NP +import Data.NumIdr.Array.Rep +import Data.NumIdr.Array.Coords + +%default total + +public export +PrimArrayDelayed : Vect rk Nat -> Type -> Type +PrimArrayDelayed s a = Coords s -> a + + +export +constant : (s : Vect rk Nat) -> a -> PrimArrayDelayed s a +constant s x _ = x + + +export +checkRange : (s : Vect rk Nat) -> Vect rk Nat -> Maybe (Coords s) +checkRange [] [] = Just [] +checkRange (d :: s) (i :: is) = (::) <$> natToFin i d <*> checkRange s is diff --git a/src/Data/NumIdr/PrimArray/Linked.idr b/src/Data/NumIdr/PrimArray/Linked.idr new file mode 100644 index 0000000..54af7c0 --- /dev/null +++ b/src/Data/NumIdr/PrimArray/Linked.idr @@ -0,0 +1,34 @@ +module Data.NumIdr.PrimArray.Linked + +import Data.Vect +import Data.NP +import Data.NumIdr.Array.Rep +import Data.NumIdr.Array.Coords + +%default total + + +export +constant : (s : Vect rk Nat) -> a -> Vects s a +constant [] x = x +constant (d :: s) xs = replicate d (constant s xs) + +export +index : Coords s -> Vects s a -> a +index [] x = x +index (c :: cs) xs = index cs (index c xs) + + +export +fromFunction : {s : _} -> (Coords s -> a) -> Vects s a +fromFunction {s = []} f = f [] +fromFunction {s = d :: s} f = tabulate (\i => fromFunction (f . (i::))) + +export +fromFunctionNB : {s : _} -> (Vect rk Nat -> a) -> Vects {rk} s a +fromFunctionNB {s = []} f = f [] +fromFunctionNB {s = d :: s} f = tabulate' {n=d} (\i => fromFunctionNB (f . (i::))) + where + tabulate' : forall a. {n : _} -> (Nat -> a) -> Vect n a + tabulate' {n=Z} f = [] + tabulate' {n=S _} f = f Z :: tabulate' (f . S)