Add more array utility functions
This commit is contained in:
parent
03d06a42aa
commit
acd0cb6aa5
|
@ -83,7 +83,7 @@ rank : Array s a -> Nat
|
||||||
rank = length . shape
|
rank = length . shape
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
shapeEq : (arr : Array s a) -> s = shape arr
|
shapeEq : (arr : Array s a) -> s = shape arr
|
||||||
shapeEq (MkArray _ _ _ _) = Refl
|
shapeEq (MkArray _ _ _ _) = Refl
|
||||||
|
|
||||||
|
@ -302,10 +302,15 @@ concat axis a b = let sA = shape a
|
||||||
-- TODO: prove that the type-level shape and `s` are equivalent
|
-- TODO: prove that the type-level shape and `s` are equivalent
|
||||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||||
|
|
||||||
|
export
|
||||||
|
stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
|
||||||
display : Show a => Array s a -> IO ()
|
stack axis arrs = rewrite sym (lengthCorrect arrs) in
|
||||||
display = printLn . PrimArray.toList . getPrim
|
fromFunction _ (\is => case getAxisInd axis (rewrite sym (lengthCorrect arrs) in is) of
|
||||||
|
(i,is') => index is' (index i arrs))
|
||||||
|
where
|
||||||
|
getAxisInd : {0 rk : _} -> {s : _} -> (ax : Fin (S rk)) -> Coords (insertAt ax n s) -> (Fin n, Coords s)
|
||||||
|
getAxisInd FZ (i :: is) = (i, is)
|
||||||
|
getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
|
||||||
|
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -392,9 +397,8 @@ export
|
||||||
{s : _} -> Monoid a => Monoid (Array s a) where
|
{s : _} -> Monoid a => Monoid (Array s a) where
|
||||||
neutral = repeat s neutral
|
neutral = repeat s neutral
|
||||||
|
|
||||||
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
|
-- The shape must be known at runtime here due to `fromInteger`. If `fromInteger`
|
||||||
-- were moved into its own interface, this constraint could be removed.
|
-- were moved into its own interface, this constraint could be removed.
|
||||||
|
|
||||||
export
|
export
|
||||||
{s : _} -> Num a => Num (Array s a) where
|
{s : _} -> Num a => Num (Array s a) where
|
||||||
(+) = zipWith (+)
|
(+) = zipWith (+)
|
||||||
|
@ -444,3 +448,30 @@ Show a => Show (Array s a) where
|
||||||
then List.replicate d ("[]" :: Prelude.Nil)
|
then List.replicate d ("[]" :: Prelude.Nil)
|
||||||
else map (insertPunct s) $ splitWindow (length strs `div` d) strs
|
else map (insertPunct s) $ splitWindow (length strs `div` d) strs
|
||||||
in "[" :: (concat $ intersperse [", "] secs) `snoc` "]"
|
in "[" :: (concat $ intersperse [", "] secs) `snoc` "]"
|
||||||
|
|
||||||
|
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
-- Numeric array operations
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
|
||||||
|
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
normSq : Num a => Array s a -> a
|
||||||
|
normSq arr = sum $ zipWith (*) arr arr
|
||||||
|
|
||||||
|
export
|
||||||
|
norm : Array s Double -> Double
|
||||||
|
norm = sqrt . normSq
|
||||||
|
|
||||||
|
export
|
||||||
|
normalize : Array s Double -> Array s Double
|
||||||
|
normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
|
||||||
|
|
||||||
|
export
|
||||||
|
pnorm : (p : Double) -> Array s Double -> Double
|
||||||
|
pnorm p = (`pow` recip p) . sum . map (`pow` p)
|
||||||
|
|
|
@ -102,3 +102,10 @@ export
|
||||||
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
||||||
hconcat = concat 1
|
hconcat = concat 1
|
||||||
|
|
||||||
|
|
||||||
|
export
|
||||||
|
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
|
||||||
|
kronecker a b = rewrite dimEq a in rewrite dimEq b in
|
||||||
|
fromFunction [dim a, dim b]
|
||||||
|
(\[i,j] => Vector.index (rewrite dimEq a in i) a *
|
||||||
|
Vector.index (rewrite dimEq b in j) b)
|
||||||
|
|
|
@ -3,7 +3,7 @@ module Data.NumIdr.Multiply
|
||||||
%default total
|
%default total
|
||||||
|
|
||||||
|
|
||||||
infixr 7 *.
|
infixr 9 *.
|
||||||
|
|
||||||
||| A generalized multiplication/transformation operator. This interface is
|
||| A generalized multiplication/transformation operator. This interface is
|
||||||
||| necessary since the standard multiplication operator is homogenous.
|
||| necessary since the standard multiplication operator is homogenous.
|
||||||
|
|
|
@ -11,6 +11,14 @@ Vector : Nat -> Type -> Type
|
||||||
Vector n = Array [n]
|
Vector n = Array [n]
|
||||||
|
|
||||||
|
|
||||||
|
public export
|
||||||
|
dim : Vector n a -> Nat
|
||||||
|
dim = head . shape
|
||||||
|
|
||||||
|
export
|
||||||
|
dimEq : (v : Vector n a) -> n = dim v
|
||||||
|
dimEq v = cong head $ shapeEq v
|
||||||
|
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
-- Vector constructors
|
-- Vector constructors
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
|
@ -107,3 +115,4 @@ export
|
||||||
perp : Neg a => Vector 2 a -> Vector 2 a -> a
|
perp : Neg a => Vector 2 a -> Vector 2 a -> a
|
||||||
perp a b = a.x * b.y - a.y * b.x
|
perp a b = a.x * b.y - a.y * b.x
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue