Add more array utility functions
This commit is contained in:
parent
03d06a42aa
commit
acd0cb6aa5
|
@ -83,7 +83,7 @@ rank : Array s a -> Nat
|
|||
rank = length . shape
|
||||
|
||||
|
||||
|
||||
export
|
||||
shapeEq : (arr : Array s a) -> s = shape arr
|
||||
shapeEq (MkArray _ _ _ _) = Refl
|
||||
|
||||
|
@ -273,8 +273,8 @@ reorder ord' arr = let s = shape arr
|
|||
sts = calcStrides ord' s
|
||||
in rewrite shapeEq arr
|
||||
in MkArray ord' sts _ (unsafeFromIns (product s) $
|
||||
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||
getAllCoords' s)
|
||||
map (\is => (getLocation' sts is, index (getLocation' (strides arr) is) (getPrim arr))) $
|
||||
getAllCoords' s)
|
||||
|
||||
|
||||
export
|
||||
|
@ -302,10 +302,15 @@ concat axis a b = let sA = shape a
|
|||
-- TODO: prove that the type-level shape and `s` are equivalent
|
||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||
|
||||
|
||||
|
||||
display : Show a => Array s a -> IO ()
|
||||
display = printLn . PrimArray.toList . getPrim
|
||||
export
|
||||
stack : {s : _} -> (axis : Fin (S rk)) -> Vect n (Array {rk} s a) -> Array (insertAt axis n s) a
|
||||
stack axis arrs = rewrite sym (lengthCorrect arrs) in
|
||||
fromFunction _ (\is => case getAxisInd axis (rewrite sym (lengthCorrect arrs) in is) of
|
||||
(i,is') => index is' (index i arrs))
|
||||
where
|
||||
getAxisInd : {0 rk : _} -> {s : _} -> (ax : Fin (S rk)) -> Coords (insertAt ax n s) -> (Fin n, Coords s)
|
||||
getAxisInd FZ (i :: is) = (i, is)
|
||||
getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is)
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
|
@ -392,9 +397,8 @@ export
|
|||
{s : _} -> Monoid a => Monoid (Array s a) where
|
||||
neutral = repeat s neutral
|
||||
|
||||
-- the shape must be known at runtime due to `fromInteger`. If `fromInteger`
|
||||
-- The shape must be known at runtime here due to `fromInteger`. If `fromInteger`
|
||||
-- were moved into its own interface, this constraint could be removed.
|
||||
|
||||
export
|
||||
{s : _} -> Num a => Num (Array s a) where
|
||||
(+) = zipWith (+)
|
||||
|
@ -444,3 +448,30 @@ Show a => Show (Array s a) where
|
|||
then List.replicate d ("[]" :: Prelude.Nil)
|
||||
else map (insertPunct s) $ splitWindow (length strs `div` d) strs
|
||||
in "[" :: (concat $ intersperse [", "] secs) `snoc` "]"
|
||||
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Numeric array operations
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
|
||||
export
|
||||
lerp : Neg a => a -> Array s a -> Array s a -> Array s a
|
||||
lerp t a b = zipWith (+) (a *. (1 - t)) (b *. t)
|
||||
|
||||
|
||||
export
|
||||
normSq : Num a => Array s a -> a
|
||||
normSq arr = sum $ zipWith (*) arr arr
|
||||
|
||||
export
|
||||
norm : Array s Double -> Double
|
||||
norm = sqrt . normSq
|
||||
|
||||
export
|
||||
normalize : Array s Double -> Array s Double
|
||||
normalize arr = if all (==0) arr then arr else map (/ norm arr) arr
|
||||
|
||||
export
|
||||
pnorm : (p : Double) -> Array s Double -> Double
|
||||
pnorm p = (`pow` recip p) . sum . map (`pow` p)
|
||||
|
|
|
@ -102,3 +102,10 @@ export
|
|||
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
||||
hconcat = concat 1
|
||||
|
||||
|
||||
export
|
||||
kronecker : Num a => Vector m a -> Vector n a -> Matrix m n a
|
||||
kronecker a b = rewrite dimEq a in rewrite dimEq b in
|
||||
fromFunction [dim a, dim b]
|
||||
(\[i,j] => Vector.index (rewrite dimEq a in i) a *
|
||||
Vector.index (rewrite dimEq b in j) b)
|
||||
|
|
|
@ -3,7 +3,7 @@ module Data.NumIdr.Multiply
|
|||
%default total
|
||||
|
||||
|
||||
infixr 7 *.
|
||||
infixr 9 *.
|
||||
|
||||
||| A generalized multiplication/transformation operator. This interface is
|
||||
||| necessary since the standard multiplication operator is homogenous.
|
||||
|
|
|
@ -11,6 +11,14 @@ Vector : Nat -> Type -> Type
|
|||
Vector n = Array [n]
|
||||
|
||||
|
||||
public export
|
||||
dim : Vector n a -> Nat
|
||||
dim = head . shape
|
||||
|
||||
export
|
||||
dimEq : (v : Vector n a) -> n = dim v
|
||||
dimEq v = cong head $ shapeEq v
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Vector constructors
|
||||
--------------------------------------------------------------------------------
|
||||
|
@ -107,3 +115,4 @@ export
|
|||
perp : Neg a => Vector 2 a -> Vector 2 a -> a
|
||||
perp a b = a.x * b.y - a.y * b.x
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue