Add more array utility functions

This commit is contained in:
Kiana Sheibani 2022-06-14 20:21:37 -04:00
parent 03d06a42aa
commit acd0cb6aa5
Signed by: toki
GPG key ID: 6CB106C25E86A9F7
4 changed files with 57 additions and 10 deletions

View file

@ -83,7 +83,7 @@ rank : Array s a -> Nat
rank = length . shape
export
shapeEq : (arr : Array s a) -> s = shape arr
shapeEq (MkArray _ _ _ _) = Refl
@ -273,8 +273,8 @@ reorder ord' arr = let s = shape arr
sts = calcStrides ord' s
in rewrite shapeEq arr
in MkArray ord' sts _ (unsafeFromIns (product s) $
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
getAllCoords' s)
export
@ -302,10 +302,15 @@ concat axis a b = let sA = shape a
-- TODO: prove that the type-level shape and `s` are equivalent
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
display : Show a => Array s a -> IO ()
display = printLn . PrimArray.toList . getPrim
export
stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
stack axis arrs = rewrite sym (lengthCorrect arrs) in
fromFunction _ (\is => case getAxisInd axis (rewrite sym (lengthCorrect arrs) in is) of
(i,is') => index is' (index i arrs))
where
getAxisInd : {0 rk : _} -> {s : _} -> (ax : Fin (S rk)) -> Coords (insertAt ax n s) -> (Fin n, Coords s)
getAxisInd FZ (i :: is) = (i, is)
getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
--------------------------------------------------------------------------------
@ -392,9 +397,8 @@ export
{s : _} -> Monoid a => Monoid (Array s a) where
neutral = repeat s neutral
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
-- The shape must be known at runtime here due to `fromInteger`. If `fromInteger`
-- were moved into its own interface, this constraint could be removed.
export
{s : _} -> Num a => Num (Array s a) where
(+) = zipWith (+)
@ -444,3 +448,30 @@ Show a => Show (Array s a) where
then List.replicate d ("[]" :: Prelude.Nil)
else map (insertPunct s) $ splitWindow (length strs `div` d) strs
in "[" :: (concat $ intersperse [", "] secs) `snoc` "]"
--------------------------------------------------------------------------------
-- Numeric array operations
--------------------------------------------------------------------------------
export
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
export
normSq : Num a => Array s a -> a
normSq arr = sum $ zipWith (*) arr arr
export
norm : Array s Double -> Double
norm = sqrt . normSq
export
normalize : Array s Double -> Array s Double
normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
export
pnorm : (p : Double) -> Array s Double -> Double
pnorm p = (`pow` recip p) . sum . map (`pow` p)

View file

@ -102,3 +102,10 @@ export
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
hconcat = concat 1
export
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
kronecker a b = rewrite dimEq a in rewrite dimEq b in
fromFunction [dim a, dim b]
(\[i,j] => Vector.index (rewrite dimEq a in i) a *
Vector.index (rewrite dimEq b in j) b)

View file

@ -3,7 +3,7 @@ module Data.NumIdr.Multiply
%default total
infixr 7 *.
infixr 9 *.
||| A generalized multiplication/transformation operator. This interface is
||| necessary since the standard multiplication operator is homogenous.

View file

@ -11,6 +11,14 @@ Vector : Nat -> Type -> Type
Vector n = Array [n]
public export
dim : Vector n a -> Nat
dim = head . shape
export
dimEq : (v : Vector n a) -> n = dim v
dimEq v = cong head $ shapeEq v
--------------------------------------------------------------------------------
-- Vector constructors
--------------------------------------------------------------------------------
@ -107,3 +115,4 @@ export
perp : Neg a => Vector 2 a -> Vector 2 a -> a
perp a b = a.x * b.y - a.y * b.x