From eddb0e318c5d2d814d6615be66b3ff750dc6ddc6 Mon Sep 17 00:00:00 2001 From: Kiana Sheibani Date: Wed, 30 Nov 2022 22:02:56 -0500 Subject: [PATCH] Add more utility functions --- numidr.ipkg | 2 +- src/Data/NumIdr/Array/Array.idr | 19 +++++++++++++++++++ src/Data/NumIdr/Transform/Transform.idr | 13 +++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/numidr.ipkg b/numidr.ipkg index 7940648..2d60408 100644 --- a/numidr.ipkg +++ b/numidr.ipkg @@ -1,5 +1,5 @@ package numidr -version = 0.2.0 +version = 0.2.1 authors = "Kiana Sheibani" license = "MIT" diff --git a/src/Data/NumIdr/Array/Array.idr b/src/Data/NumIdr/Array/Array.idr index 2f0ef01..abb363c 100644 --- a/src/Data/NumIdr/Array/Array.idr +++ b/src/Data/NumIdr/Array/Array.idr @@ -516,6 +516,25 @@ stack axis arrs = rewrite sym (lengthCorrect arrs) in getAxisInd FZ (i :: is) = (i, is) getAxisInd {s=_::_} (FS ax) (i :: is) = mapSnd (i::) (getAxisInd ax is) +export +joinAxes : {s' : _} -> Array s (Array s' a) -> Array (s ++ s') a +joinAxes {s} arr with (viewShape arr) + _ | Shape s = fromFunctionNB (s ++ s') (\is => arr !# takeUpTo s is !# dropUpTo s is) + where + takeUpTo : Vect rk Nat -> Vect (rk + rk') Nat -> Vect rk Nat + takeUpTo [] ys = [] + takeUpTo (x::xs) (y::ys) = y :: takeUpTo xs ys + + dropUpTo : Vect rk Nat -> Vect (rk + rk') Nat -> Vect rk' Nat + dropUpTo [] ys = ys + dropUpTo (x::xs) (y::ys) = dropUpTo xs ys + + +export +splitAxes : (rk : Nat) -> {0 rk' : Nat} -> {s : _} -> + Array {rk=rk+rk'} s a -> Array (take {m=rk'} rk s) (Array (drop {m=rk'} rk s) a) +splitAxes _ {s} arr = fromFunctionNB _ (\is => fromFunctionNB _ (\is' => arr !# (is ++ is'))) + ||| Construct the transpose of an array by reversing the order of its axes. export diff --git a/src/Data/NumIdr/Transform/Transform.idr b/src/Data/NumIdr/Transform/Transform.idr index e60931c..cdb851b 100644 --- a/src/Data/NumIdr/Transform/Transform.idr +++ b/src/Data/NumIdr/Transform/Transform.idr @@ -120,6 +120,19 @@ setTranslation : Num a => Vector n a -> Transform ty n a setTranslation v (MkTrans _ mat) = MkTrans _ (hmatrix (getMatrix mat) v) +namespace Vector + export + applyInv : FieldCmp a => Transform ty n a -> Vector n a -> Vector n a + applyInv tr v = assert_total $ case solve (getMatrix tr) v of Just v' => v' + +namespace Point + export + applyInv : FieldCmp a => Transform ty n a -> Point n a -> Point n a + applyInv (MkTrans _ mat) p = assert_total $ + case solve (getMatrix mat) (zipWith (-) (toVector p) (getTranslationVector mat)) of + Just v => fromVector v + + -------------------------------------------------------------------------------- -- Interface implementations --------------------------------------------------------------------------------