Add more array utility functions

This commit is contained in:
Kiana Sheibani 2022-06-14 20:21:37 -04:00
parent 03d06a42aa
commit acd0cb6aa5
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
4 changed files with 57 additions and 10 deletions

View file

@ -83,7 +83,7 @@ rank : Array s a -> Nat
rank = length . shape
export
shapeEq : (arr : Array s a) -> s = shape arr
shapeEq (MkArray _ _ _ _) = Refl
@ -273,8 +273,8 @@ reorder ord' arr = let s = shape arr
sts = calcStrides ord' s
in rewrite shapeEq arr
in MkArray ord' sts _ (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
export
@ -302,10 +302,15 @@ concat axis a b = let sA = shape a
-- TODO: prove that the type-level shape and `s` are equivalent
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
display : Show a => Array s a -> IO ()
display = printLn . PrimArray.toList . getPrim
export
stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
stack axis arrs = rewrite sym (lengthCorrect arrs) in
fromFunction _ (\is => case getAxisInd axis (rewrite sym (lengthCorrect arrs) in is) of
(i,is') => index is' (index i arrs))
where
getAxisInd : {0 rk : _} -> {s : _} -> (ax : Fin (S rk)) -> Coords (insertAt ax n s) -> (Fin n, Coords s)
getAxisInd FZ (i :: is) = (i, is)
getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
--------------------------------------------------------------------------------
@ -392,9 +397,8 @@ export
{s : _} -> Monoid a => Monoid (Array s a) where
neutral = repeat s neutral
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
-- The shape must be known at runtime here due to `fromInteger`. If `fromInteger`
-- were moved into its own interface, this constraint could be removed.
export
{s : _} -> Num a => Num (Array s a) where
(+) = zipWith (+)
@ -444,3 +448,30 @@ Show a => Show (Array s a) where
then List.replicate d ("[]" :: Prelude.Nil)
else map (insertPunct s) $ splitWindow (length strs `div` d) strs
in "[" :: (concat $ intersperse [", "] secs) `snoc` "]"
--------------------------------------------------------------------------------
-- Numeric array operations
--------------------------------------------------------------------------------
export
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
export
normSq : Num a => Array s a -> a
normSq arr = sum $ zipWith (*) arr arr
export
norm : Array s Double -> Double
norm = sqrt . normSq
export
normalize : Array s Double -> Array s Double
normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
export
pnorm : (p : Double) -> Array s Double -> Double
pnorm p = (`pow` recip p) . sum . map (`pow` p)