Rename stack to concat
This commit is contained in:
parent
87d8814c38
commit
03d06a42aa
|
@ -289,18 +289,18 @@ enumerate arr = map (\is => (is, index is arr))
|
|||
(rewrite shapeEq arr in getAllCoords $ shape arr)
|
||||
|
||||
export
|
||||
stack : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||
concat : (axis : Fin rk) -> Array {rk} s a -> Array (replaceAt axis d s) a ->
|
||||
Array (updateAt axis (+d) s) a
|
||||
stack axis a b = let sA = shape a
|
||||
sB = shape b
|
||||
dA = index axis sA
|
||||
dB = index axis sB
|
||||
s = replaceAt axis (dA + dB) sA
|
||||
sts = calcStrides COrder s
|
||||
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
|
||||
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
|
||||
-- TODO: prove that the type-level shape and `s` are equivalent
|
||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||
concat axis a b = let sA = shape a
|
||||
sB = shape b
|
||||
dA = index axis sA
|
||||
dB = index axis sB
|
||||
s = replaceAt axis (dA + dB) sA
|
||||
sts = calcStrides COrder s
|
||||
ins = map (mapFst $ getLocation' sts . toNats) (enumerate a)
|
||||
++ map (mapFst $ getLocation' sts . updateAt axis (+dA) . toNats) (enumerate b)
|
||||
-- TODO: prove that the type-level shape and `s` are equivalent
|
||||
in believe_me $ MkArray COrder sts s (unsafeFromIns (product s) ins)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -95,10 +95,10 @@ getColumn c mat = rewrite sym (minusZeroRight m) in indexRange [All, One c] mat
|
|||
--------------------------------------------------------------------------------
|
||||
|
||||
export
|
||||
vstack : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
|
||||
vstack = stack 0
|
||||
vconcat : Matrix m n a -> Matrix m' n a -> Matrix (m + m') n a
|
||||
vconcat = concat 0
|
||||
|
||||
export
|
||||
hstack : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
||||
hstack = stack 1
|
||||
hconcat : Matrix m n a -> Matrix m n' a -> Matrix m (n + n') a
|
||||
hconcat = concat 1
|
||||
|
||||
|
|
|
@ -94,8 +94,8 @@ swizzle p v = rewrite sym (lengthCorrect p)
|
|||
|
||||
|
||||
export
|
||||
concat : Vector m a -> Vector n a -> Vector (m + n) a
|
||||
concat = stack 0
|
||||
(++) : Vector m a -> Vector n a -> Vector (m + n) a
|
||||
(++) = concat 0
|
||||
|
||||
|
||||
export
|
||||
|
|
Loading…
Reference in a new issue