From acd0cb6aa56d86fcef950f5dc0855517e4183afa Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Tue, 14 Jun 2022 20:21:37 -0400 Subject: [PATCH] Add more array utility functions --- src/Data/NumIdr/Array/Array.idr | 49 +++++++++++++++++++++++++++------ src/Data/NumIdr/Matrix.idr | 7 +++++ src/Data/NumIdr/Multiply.idr | 2 +- src/Data/NumIdr/Vector.idr | 9 ++++++ 4 files changed, 57 insertions(+), 10 deletions(-) diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 2519a9d..d92844f 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -83,7 +83,7 @@ rank : Array s a -> Nat rank = length . shape - +export shapeEq : (arr : Array s a) -> s = shape arr shapeEq (MkArray _ _ _ _) = Refl @@ -273,8 +273,8 @@ reorder ord' arr = let s = shape arr sts = calcStrides ord' s in rewrite shapeEq arr in MkArray ord' sts _ (unsafeFromIns (product s) $ - map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $ - getAllCoords' s) + map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $ + getAllCoords' s) export @@ -302,10 +302,15 @@ concat axis a b = let sA = shape a -- TODO: prove that the type-level shape and `s` are equivalent in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins) - - -display : Show a => Array s a -> IO () -display = printLn . PrimArray.toList . getPrim +export +stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a +stack axis arrs = rewrite sym (lengthCorrect arrs) in + fromFunction _ (\is => case getAxisInd axis (rewrite sym (lengthCorrect arrs) in is) of + (i,is') => index is' (index i arrs)) + where + getAxisInd : {0 rk : _} -> {s : _} -> (ax : Fin (S rk)) -> Coords (insertAt ax n s) -> (Fin n, Coords s) + getAxisInd FZ (i :: is) = (i, is) + getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is) -------------------------------------------------------------------------------- @@ -392,9 +397,8 @@ export {s : _} -> Monoid a => Monoid (Array s a) where neutral = repeat s neutral --- the shape must be known at runtime due to `fromInteger`. If `fromInteger` +-- The shape must be known at runtime here due to `fromInteger`. If `fromInteger` -- were moved into its own interface, this constraint could be removed. - export {s : _} -> Num a => Num (Array s a) where (+) = zipWith (+) @@ -444,3 +448,30 @@ Show a => Show (Array s a) where then List.replicate d ("[]" :: Prelude.Nil) else map (insertPunct s) $ splitWindow (length strs `div` d) strs in "[" :: (concat $ intersperse [", "] secs) `snoc` "]" + + +-------------------------------------------------------------------------------- +-- Numeric array operations +-------------------------------------------------------------------------------- + + +export +lerp : Neg a => a -> Array s a -> Array s a -> Array s a +lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t) + + +export +normSq : Num a => Array s a -> a +normSq arr = sum $ zipWith (*) arr arr + +export +norm : Array s Double -> Double +norm = sqrt . normSq + +export +normalize : Array s Double -> Array s Double +normalize arr = if all (==0) arr then arr else map (/ norm arr) arr + +export +pnorm : (p : Double) -> Array s Double -> Double +pnorm p = (`pow` recip p) . sum . map (`pow` p) diff --git a/src/Data/NumIdr/Matrix.idr b/src/Data/NumIdr/Matrix.idr index 2011a83..c1559da 100644 --- a/src/Data/NumIdr/Matrix.idr +++ b/src/Data/NumIdr/Matrix.idr @@ -102,3 +102,10 @@ export hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a hconcat = concat 1 + +export +kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a +kronecker a b = rewrite dimEq a in rewrite dimEq b in + fromFunction [dim a, dim b] + (\[i,j] => Vector.index (rewrite dimEq a in i) a * + Vector.index (rewrite dimEq b in j) b) diff --git a/src/Data/NumIdr/Multiply.idr b/src/Data/NumIdr/Multiply.idr index fb828b9..63d3308 100644 --- a/src/Data/NumIdr/Multiply.idr +++ b/src/Data/NumIdr/Multiply.idr @@ -3,7 +3,7 @@ module Data.NumIdr.Multiply %default total -infixr 7 *. +infixr 9 *. ||| A generalized multiplication/transformation operator. This interface is ||| necessary since the standard multiplication operator is homogenous. diff --git a/src/Data/NumIdr/Vector.idr b/src/Data/NumIdr/Vector.idr index c768dc2..9746271 100644 --- a/src/Data/NumIdr/Vector.idr +++ b/src/Data/NumIdr/Vector.idr @@ -11,6 +11,14 @@ Vector : Nat -> Type -> Type Vector n = Array [n] +public export +dim : Vector n a -> Nat +dim = head . shape + +export +dimEq : (v : Vector n a) -> n = dim v +dimEq v = cong head $ shapeEq v + -------------------------------------------------------------------------------- -- Vector constructors -------------------------------------------------------------------------------- @@ -107,3 +115,4 @@ export perp : Neg a => Vector 2 a -> Vector 2 a -> a perp a b = a.x * b.y - a.y * b.x +