Fix primitive array functions handling empty arrays

This commit is contained in:
Kiana Sheibani 2023-09-26 00:39:47 -04:00
parent b924d960b5
commit c3c14eba04
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
3 changed files with 13 additions and 108 deletions

View file

@ -512,9 +512,12 @@ decompLUP : FieldCmp a => (mat : Matrix m n a) -> DecompLUP mat
decompLUP {m,n} mat with (viewShape mat) decompLUP {m,n} mat with (viewShape mat)
decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0 decompLUP {m=0,n} mat | Shape [0,n] = MkLUP mat identity 0
decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0 decompLUP {m=S m,n=0} mat | Shape [S m,0] = MkLUP mat identity 0
decompLUP {m=S m,n=S n} mat | Shape [S m,S n] = decompLUP {m=S m,n=S n} mat | Shape [S m,S n] = undelay $
iterateN (S $ minimum m n) gaussStepSwap (MkLUP mat identity 0) iterateN (S $ minimum m n) gaussStepSwap (MkLUP (convertRep Delayed mat) identity 0)
where where
undelay : DecompLUP mat -> DecompLUP mat
undelay (MkLUP mat' p sw) = MkLUP (convertRep _ @{getRepC mat} mat') p sw
maxIndex : (s,a) -> List (s,a) -> (s,a) maxIndex : (s,a) -> List (s,a) -> (s,a)
maxIndex x [] = x maxIndex x [] = x
maxIndex _ [x] = x maxIndex _ [x] = x

View file

@ -62,7 +62,7 @@ unsafeFromIns s ins = arrayAction s $ \arr =>
export export
create : {o : _} -> (s : Vect rk Nat) -> (Nat -> a) -> PrimArrayBoxed o s a create : {o : _} -> (s : Vect rk Nat) -> (Nat -> a) -> PrimArrayBoxed o s a
create s f = arrayAction s $ \arr => create s f = arrayAction s $ \arr =>
for_ [0..pred (product s)] $ \i => arrayDataSet i (f i) arr for_ (range 0 (product s)) $ \i => arrayDataSet i (f i) arr
||| Index into a primitive array. This function is unsafe, as it performs no ||| Index into a primitive array. This function is unsafe, as it performs no
||| boundary check on the index given. ||| boundary check on the index given.
@ -88,7 +88,7 @@ export
bytesToBoxed : {s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBoxed o s a bytesToBoxed : {s : _} -> ByteRep a => PrimArrayBytes o s -> PrimArrayBoxed o s a
bytesToBoxed (MkPABytes sts buf) = MkPABoxed sts $ unsafePerformIO $ do bytesToBoxed (MkPABytes sts buf) = MkPABoxed sts $ unsafePerformIO $ do
arr <- newArrayData (product s) (believe_me ()) arr <- newArrayData (product s) (believe_me ())
for_ [0..pred (product s)] $ \i => arrayDataSet i !(readBuffer i buf) arr for_ (range 0 (product s)) $ \i => arrayDataSet i !(readBuffer i buf) arr
pure arr pure arr
export export
@ -96,7 +96,7 @@ boxedToBytes : {s : _} -> ByteRep a => PrimArrayBoxed o s a -> PrimArrayBytes o
boxedToBytes @{br} (MkPABoxed sts arr) = MkPABytes sts $ unsafePerformIO $ do boxedToBytes @{br} (MkPABoxed sts arr) = MkPABytes sts $ unsafePerformIO $ do
Just buf <- newBuffer $ cast (product s * bytes @{br}) Just buf <- newBuffer $ cast (product s * bytes @{br})
| Nothing => die "Cannot create array" | Nothing => die "Cannot create array"
for_ [0..pred (product s)] $ \i => writeBuffer i !(arrayDataGet i arr) buf for_ (range 0 (product s)) $ \i => writeBuffer i !(arrayDataGet i arr) buf
pure buf pure buf
@ -115,7 +115,7 @@ foldl f z (MkPABoxed sts arr) =
if product s == 0 then z if product s == 0 then z
else unsafePerformIO $ do else unsafePerformIO $ do
ref <- newIORef z ref <- newIORef z
for_ [0..pred $ product s] $ \n => do for_ (range 0 (product s)) $ \n => do
x <- readIORef ref x <- readIORef ref
y <- arrayDataGet n arr y <- arrayDataGet n arr
writeIORef ref (f x y) writeIORef ref (f x y)
@ -127,7 +127,7 @@ foldr f z (MkPABoxed sts arr) =
if product s == 0 then z if product s == 0 then z
else unsafePerformIO $ do else unsafePerformIO $ do
ref <- newIORef z ref <- newIORef z
for_ [pred $ product s..0] $ \n => do for_ (range 0 (product s)) $ \n => do
x <- arrayDataGet n arr x <- arrayDataGet n arr
y <- readIORef ref y <- readIORef ref
writeIORef ref (f x y) writeIORef ref (f x y)
@ -137,99 +137,3 @@ export
traverse : {o,s : _} -> Applicative f => (a -> f b) -> PrimArrayBoxed o s a -> f (PrimArrayBoxed o s b) traverse : {o,s : _} -> Applicative f => (a -> f b) -> PrimArrayBoxed o s a -> f (PrimArrayBoxed o s b)
traverse f = map (Boxed.fromList _) . traverse f . foldr (::) [] traverse f = map (Boxed.fromList _) . traverse f . foldr (::) []
{-
export
updateAt : Nat -> (a -> a) -> PrimArrayBoxed o s a -> PrimArrayBoxed o s a
updateAt n f arr = if n >= length arr then arr else
unsafePerformIO $ do
let cpy = copy arr
x <- arrayDataGet n cpy.content
arrayDataSet n (f x) cpy.content
pure cpy
export
unsafeUpdateInPlace : Nat -> (a -> a) -> PrimArrayBoxed a -> PrimArrayBoxed a
unsafeUpdateInPlace n f arr = unsafePerformIO $ do
x <- arrayDataGet n arr.content
arrayDataSet n (f x) arr.content
pure arr
||| Convert a primitive array to a list.
export
toList : PrimArrayBoxed a -> List a
toList arr = iter (length arr) []
where
iter : Nat -> List a -> List a
iter Z acc = acc
iter (S n) acc = let el = index n arr
in iter n (el :: acc)
||| Construct a primitive array from a list.
export
fromList : List a -> PrimArrayBoxed a
fromList xs = create (length xs)
(\n => assert_total $ fromJust $ getAt n xs)
where
partial
fromJust : Maybe a -> a
fromJust (Just x) = x
export
unsafeZipWith : (a -> b -> c) -> PrimArrayBoxed a -> PrimArrayBoxed b -> PrimArrayBoxed c
unsafeZipWith f a b = create (length a) (\n => f (index n a) (index n b))
export
unsafeZipWith3 : (a -> b -> c -> d) ->
PrimArrayBoxed a -> PrimArrayBoxed b -> PrimArrayBoxed c -> PrimArrayBoxed d
unsafeZipWith3 f a b c = create (length a) (\n => f (index n a) (index n b) (index n c))
export
unzipWith : (a -> (b, c)) -> PrimArrayBoxed a -> (PrimArrayBoxed b, PrimArrayBoxed c)
unzipWith f arr = (map (fst . f) arr, map (snd . f) arr)
export
unzipWith3 : (a -> (b, c, d)) -> PrimArrayBoxed a -> (PrimArrayBoxed b, PrimArrayBoxed c, PrimArrayBoxed d)
unzipWith3 f arr = (map ((\(x,_,_) => x) . f) arr,
map ((\(_,y,_) => y) . f) arr,
map ((\(_,_,z) => z) . f) arr)
export
foldl : (b -> a -> b) -> b -> PrimArrayBoxed a -> b
foldl f z (MkPABoxed size arr) =
if size == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [0..pred size] $ \n => do
x <- readIORef ref
y <- arrayDataGet n arr
writeIORef ref (f x y)
readIORef ref
export
foldr : (a -> b -> b) -> b -> PrimArrayBoxed a -> b
foldr f z (MkPABoxed size arr) =
if size == 0 then z
else unsafePerformIO $ do
ref <- newIORef z
for_ [pred size..0] $ \n => do
x <- arrayDataGet n arr
y <- readIORef ref
writeIORef ref (f x y)
readIORef ref
export
traverse : Applicative f => (a -> f b) -> PrimArrayBoxed a -> f (PrimArrayBoxed b)
traverse f = map fromList . traverse f . toList
||| Compares two primitive arrays for equal elements. This function assumes the
||| arrays have the same length; it must not be used in any other case.
export
unsafeEq : Eq a => PrimArrayBoxed a -> PrimArrayBoxed a -> Bool
unsafeEq a b = unsafePerformIO $
map (concat @{All}) $ for [0..pred (arraySize a)] $
\n => (==) <$> arrayDataGet n (content a) <*> arrayDataGet n (content b)
-}

View file

@ -48,8 +48,7 @@ bufferAction @{br} s action = fromBuffer s $ unsafePerformIO $ do
export export
constant : {o : _} -> ByteRep a => (s : Vect rk Nat) -> a -> PrimArrayBytes o s constant : {o : _} -> ByteRep a => (s : Vect rk Nat) -> a -> PrimArrayBytes o s
constant @{br} s x = bufferAction @{br} s $ \buf => do constant @{br} s x = bufferAction @{br} s $ \buf => do
let size = product s for_ (range 0 (product s)) $ \i => writeBuffer i x buf
for_ [0..pred size] $ \i => writeBuffer i x buf
export export
unsafeDoIns : ByteRep a => List (Nat, a) -> PrimArrayBytes o s -> IO () unsafeDoIns : ByteRep a => List (Nat, a) -> PrimArrayBytes o s -> IO ()
@ -63,8 +62,7 @@ unsafeFromIns @{br} s ins = bufferAction @{br} s $ \buf =>
export export
create : {o : _} -> ByteRep a => (s : Vect rk Nat) -> (Nat -> a) -> PrimArrayBytes o s create : {o : _} -> ByteRep a => (s : Vect rk Nat) -> (Nat -> a) -> PrimArrayBytes o s
create @{br} s f = bufferAction @{br} s $ \buf => do create @{br} s f = bufferAction @{br} s $ \buf => do
let size = product s for_ (range 0 (product s)) $ \i => writeBuffer i (f i) buf
for_ [0..pred size] $ \i => writeBuffer i (f i) buf
export export
index : ByteRep a => Nat -> PrimArrayBytes o s -> a index : ByteRep a => Nat -> PrimArrayBytes o s -> a
@ -102,7 +100,7 @@ foldl f z (MkPABytes sts buf) =
if product s == 0 then z if product s == 0 then z
else unsafePerformIO $ do else unsafePerformIO $ do
ref <- newIORef z ref <- newIORef z
for_ [0..pred $ product s] $ \n => do for_ (range 0 (product s)) $ \n => do
x <- readIORef ref x <- readIORef ref
y <- readBuffer n buf y <- readBuffer n buf
writeIORef ref (f x y) writeIORef ref (f x y)