Define functions for permuting arrays

This commit is contained in:
Kiana Sheibani 2022-09-01 18:25:07 -04:00
parent 0177781c74
commit b74734fbc1
Signed by: toki
GPG key ID: 6CB106C25E86A9F7

View file

@ -5,6 +5,7 @@ import Data.List1
import Data.Vect
import Data.Zippable
import Data.NP
import Data.Permutation
import Data.NumIdr.Multiply
import Data.NumIdr.PrimArray
import Data.NumIdr.Array.Order
@ -494,6 +495,27 @@ export
(.T) = transpose
export
swapAxes : (i,j : Fin rk) -> Array s a -> Array (swapElems i j s) a
swapAxes i j arr with (viewShape arr)
_ | Shape s = fromFunctionNB _ (\is => arr !# swapElems i j is)
export
permuteAxes : (p : Permutation rk) -> Array s a -> Array (permuteVect p s) a
permuteAxes p arr with (viewShape arr)
_ | Shape s = fromFunctionNB _ (\is => arr !# permuteVect p s)
export
swapInAxis : (ax : Fin rk) -> (i,j : Fin (index ax s)) -> Array s a -> Array s a
swapInAxis ax i j arr with (viewShape arr)
_ | Shape s = fromFunctionNB _ (\is => arr !# updateAt ax (swapValues i j) is)
export
permuteInAxis : (ax : Fin rk) -> Permutation (index ax s) -> Array s a -> Array s a
permuteInAxis ax p arr with (viewShape arr)
_ | Shape s = fromFunctionNB _ (\is => arr !# updateAt ax (permuteValues p) is)
--------------------------------------------------------------------------------
-- Implementations
--------------------------------------------------------------------------------