Add more array utility functions

This commit is contained in:
Kiana Sheibani 2022-06-14 20:21:37 -04:00
parent 03d06a42aa
commit acd0cb6aa5
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
4 changed files with 57 additions and 10 deletions

View file

@ -83,7 +83,7 @@ rank : Array s a -> Nat
rank = length . shape rank = length . shape
export
shapeEq : (arr : Array s a) -> s = shape arr shapeEq : (arr : Array s a) -> s = shape arr
shapeEq (MkArray _ _ _ _) = Refl shapeEq (MkArray _ _ _ _) = Refl
@ -302,10 +302,15 @@ concat axis a b = let sA = shape a
-- TODO: prove that the type-level shape and `s` are equivalent -- TODO: prove that the type-level shape and `s` are equivalent
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins) in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
export
stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
display : Show a => Array s a -> IO () stack axis arrs = rewrite sym (lengthCorrect arrs) in
display = printLn . PrimArray.toList . getPrim fromFunction _ (\is => case getAxisInd axis (rewrite sym (lengthCorrect arrs) in is) of
(i,is') => index is' (index i arrs))
where
getAxisInd : {0 rk : _} -> {s : _} -> (ax : Fin (S rk)) -> Coords (insertAt ax n s) -> (Fin n, Coords s)
getAxisInd FZ (i :: is) = (i, is)
getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -392,9 +397,8 @@ export
{s : _} -> Monoid a => Monoid (Array s a) where {s : _} -> Monoid a => Monoid (Array s a) where
neutral = repeat s neutral neutral = repeat s neutral
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger` -- The shape must be known at runtime here due to `fromInteger`. If `fromInteger`
-- were moved into its own interface, this constraint could be removed. -- were moved into its own interface, this constraint could be removed.
export export
{s : _} -> Num a => Num (Array s a) where {s : _} -> Num a => Num (Array s a) where
(+) = zipWith (+) (+) = zipWith (+)
@ -444,3 +448,30 @@ Show a => Show (Array s a) where
then List.replicate d ("[]" :: Prelude.Nil) then List.replicate d ("[]" :: Prelude.Nil)
else map (insertPunct s) $ splitWindow (length strs `div` d) strs else map (insertPunct s) $ splitWindow (length strs `div` d) strs
in "[" :: (concat $ intersperse [", "] secs) `snoc` "]" in "[" :: (concat $ intersperse [", "] secs) `snoc` "]"
--------------------------------------------------------------------------------
-- Numeric array operations
--------------------------------------------------------------------------------
export
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
export
normSq : Num a => Array s a -> a
normSq arr = sum $ zipWith (*) arr arr
export
norm : Array s Double -> Double
norm = sqrt . normSq
export
normalize : Array s Double -> Array s Double
normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
export
pnorm : (p : Double) -> Array s Double -> Double
pnorm p = (`pow` recip p) . sum . map (`pow` p)

View file

@ -102,3 +102,10 @@ export
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
hconcat = concat 1 hconcat = concat 1
export
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
kronecker a b = rewrite dimEq a in rewrite dimEq b in
fromFunction [dim a, dim b]
(\[i,j] => Vector.index (rewrite dimEq a in i) a *
Vector.index (rewrite dimEq b in j) b)

View file

@ -3,7 +3,7 @@ module Data.NumIdr.Multiply
%default total %default total
infixr 7 *. infixr 9 *.
||| A generalized multiplication/transformation operator. This interface is ||| A generalized multiplication/transformation operator. This interface is
||| necessary since the standard multiplication operator is homogenous. ||| necessary since the standard multiplication operator is homogenous.

View file

@ -11,6 +11,14 @@ Vector : Nat -> Type -> Type
Vector n = Array [n] Vector n = Array [n]
public export
dim : Vector n a -> Nat
dim = head . shape
export
dimEq : (v : Vector n a) -> n = dim v
dimEq v = cong head $ shapeEq v
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Vector constructors -- Vector constructors
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
@ -107,3 +115,4 @@ export
perp : Neg a => Vector 2 a -> Vector 2 a -> a perp : Neg a => Vector 2 a -> Vector 2 a -> a
perp a b = a.x * b.y - a.y * b.x perp a b = a.x * b.y - a.y * b.x