concat_op
Structs
Functions
concat_shape
concat_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of the array produced by concatenating the shapes in args and stores it in curr.
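The usual concatenation shape rule is shown below as a minimal Python sketch. Python stands in for Mojo here. `concat_shape_sketch` is a hypothetical helper and is not part of the library. The axis is passed explicitly as an assumption, because the signature above does not show where it comes from. All inputs must have the same rank and agree in every dimension except the concatenation axis. The output's size along that axis is the sum of the inputs' sizes.

```python
from typing import Sequence, Tuple


def concat_shape_sketch(shapes: Sequence[Tuple[int, ...]], axis: int) -> Tuple[int, ...]:
    """Hypothetical helper: shape of concatenating `shapes` along `axis`."""
    if not shapes:
        raise ValueError("need at least one shape")
    first = shapes[0]
    ndim = len(first)
    if not 0 <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for rank {ndim}")
    total = 0
    for shape in shapes:
        if len(shape) != ndim:
            raise ValueError("all shapes must have the same rank")
        for d in range(ndim):
            # Every dimension except the concatenation axis must match.
            if d != axis and shape[d] != first[d]:
                raise ValueError(f"mismatch in dimension {d}: {shape} vs {first}")
        total += shape[axis]
    # Only the concatenation axis grows: its size is the sum over all inputs.
    return first[:axis] + (total,) + first[axis + 1:]


print(concat_shape_sketch([(2, 3), (2, 1)], axis=1))  # (2, 4)
```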
concat_fwd
concat_fwd(mut curr: Array, args: List[Array])
Performs the forward pass for the concat operation. It sets the base of the argument to the base of the current array and computes the shape of the current array via its dedicated ArrayShape forward function.
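The forward function above works on bases and shapes. The layout of the concatenated data can be sketched in plain Python, assuming contiguous row-major buffers. `concat_fwd_sketch` is hypothetical. It makes an eager copy, and the library may materialize the result differently. For each index over the dimensions before the axis, it appends each input's contiguous run of `shape[axis] * inner` elements in turn.

```python
from math import prod
from typing import List, Sequence, Tuple


def concat_fwd_sketch(
    datas: Sequence[Sequence[float]],
    shapes: Sequence[Tuple[int, ...]],
    axis: int,
) -> Tuple[List[float], Tuple[int, ...]]:
    """Hypothetical sketch: concatenate row-major flat buffers along `axis`."""
    ref = shapes[0]
    outer = prod(ref[:axis])      # number of blocks before the axis
    inner = prod(ref[axis + 1:])  # elements per step along the axis
    for data, shape in zip(datas, shapes):
        if len(data) != prod(shape):
            raise ValueError("buffer length does not match shape")
        if shape[:axis] != ref[:axis] or shape[axis + 1:] != ref[axis + 1:]:
            raise ValueError("non-axis dimensions must match")
    out: List[float] = []
    for o in range(outer):
        for data, shape in zip(datas, shapes):
            chunk = shape[axis] * inner  # contiguous run of this input for block o
            out.extend(data[o * chunk:(o + 1) * chunk])
    out_shape = ref[:axis] + (sum(s[axis] for s in shapes),) + ref[axis + 1:]
    return out, out_shape
```

For example, concatenating `[1, 2, 3, 4, 5, 6]` with shape `(2, 3)` and `[7, 8]` with shape `(2, 1)` along axis 1 interleaves rows. The result is `[1, 2, 3, 7, 4, 5, 6, 8]` with shape `(2, 4)`.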
concat_vjp
concat_vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the concat operation.
Args
- primals (List[Array]): A list containing the primal input arrays.
- grad (Array): The gradient of some scalar function with respect to the output of the forward pass.
- out (Array): The output of the forward pass.
Returns
List[Array]: A list containing the gradients with respect to the input arrays.
Note: In the current implementation, the vector-Jacobian product for concat returns an empty list, so no gradients are propagated to the input arrays.
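The standard vector-Jacobian product of concatenation is a split. Each output element is a copy of exactly one input element, so each input's gradient is the slice of grad that covers that input's position along the axis. The hypothetical `concat_vjp_sketch` below shows this for row-major buffers in Python. It is not the library's implementation, which, as noted, returns an empty list.

```python
from math import prod
from typing import List, Sequence, Tuple


def concat_vjp_sketch(
    grad: Sequence[float],
    primal_shapes: Sequence[Tuple[int, ...]],
    axis: int,
) -> List[List[float]]:
    """Hypothetical sketch: split a row-major output gradient into per-input gradients."""
    ref = primal_shapes[0]
    outer = prod(ref[:axis])
    inner = prod(ref[axis + 1:])
    grads: List[List[float]] = [[] for _ in primal_shapes]
    pos = 0
    for _ in range(outer):
        for i, shape in enumerate(primal_shapes):
            chunk = shape[axis] * inner
            # Each input receives exactly the slice of the gradient it contributed to.
            grads[i].extend(grad[pos:pos + chunk])
            pos += chunk
    if pos != len(grad):
        raise ValueError("gradient length does not match the concatenated shape")
    return grads
```

Take a gradient laid out like the result of concatenating a `(2, 3)` and a `(2, 1)` array along axis 1. The sketch returns per-input gradients of lengths 6 and 2, which inverts the forward copy.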
concat
concat(args: List[Array], axis: Int) -> Array
Concatenates the arrays in args along the given axis and returns the resulting array.