
concat_op


Functions

concat_shape

concat_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of an array after concatenation.
Args
  • curr: ArrayShape The ArrayShape to store the result of the computation.

  • args: List[ArrayShape] The ArrayShapes to concatenate, together with the axis to concatenate along, which is itself encoded in an ArrayShape.
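The shape rule itself is simple: every dimension except the concatenation axis must match across inputs, and the output size along that axis is the sum of the inputs' sizes. A minimal Python sketch of that rule, assuming shapes are plain tuples and the axis is passed as a separate int rather than encoded in an ArrayShape; `concat_result_shape` is a hypothetical helper, not part of the library.

```python
def concat_result_shape(shapes: list[tuple[int, ...]], axis: int) -> tuple[int, ...]:
    """Hypothetical helper: shape produced by concatenating `shapes` along `axis`."""
    if not shapes:
        raise ValueError("need at least one shape to concatenate")
    ndim = len(shapes[0])
    # Normalize a negative axis, e.g. -1 refers to the last dimension.
    if axis < 0:
        axis += ndim
    if not 0 <= axis < ndim:
        raise ValueError(f"axis out of range for {ndim}-d shapes")
    out = list(shapes[0])
    for shape in shapes[1:]:
        if len(shape) != ndim:
            raise ValueError("all shapes must have the same rank")
        # Every dimension other than the concatenation axis must agree.
        for d in range(ndim):
            if d != axis and shape[d] != out[d]:
                raise ValueError(f"dimension {d} mismatch: {shape[d]} vs {out[d]}")
        # Only the concatenation axis grows: it is the sum of the inputs' sizes.
        out[axis] += shape[axis]
    return tuple(out)
```

For example, `concat_result_shape([(2, 3), (4, 3)], 0)` returns `(6, 3)`, while concatenating `(2, 3)` with `(4, 5)` along axis 0 raises because dimension 1 differs.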

concat_fwd

concat_fwd(mut curr: Array, args: List[Array])
Performs the forward pass for the concat operation. It sets the base of the argument to be the base of the current array and computes the shape of the current array via its dedicated ArrayShape forward function.
Args
  • curr: Array The current array to store the result (modified in-place).

  • args: List[Array] The arrays to concatenate.

Note: The result of the shape computation is stored in the ArrayShape object of the curr array.
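To illustrate what concatenation along an axis means for the underlying data, here is a minimal sketch assuming each array is a flat, row-major (C-order) Python list paired with a shape tuple. It does not model the base-sharing mechanism described above; `concat_flat` is a hypothetical helper. The idea: the dimensions before the axis define `outer` independent blocks, and for each block every input contributes one contiguous chunk of `shape[axis] * inner` elements.

```python
from math import prod


def concat_flat(buffers, shapes, axis):
    """Hypothetical helper: concatenate row-major flat buffers along `axis`.

    Assumes all shapes have the same rank and agree on every non-axis dimension.
    Returns (flat_data, output_shape).
    """
    ndim = len(shapes[0])
    if axis < 0:
        axis += ndim
    # Number of independent blocks: product of the dimensions before the axis.
    outer = prod(shapes[0][:axis])
    # Elements per single step along the axis: product of the dimensions after it.
    inner = prod(shapes[0][axis + 1:])
    out = []
    for o in range(outer):
        for buf, shape in zip(buffers, shapes):
            # Contiguous run this input contributes to block `o`.
            chunk = shape[axis] * inner
            out.extend(buf[o * chunk:(o + 1) * chunk])
    out_shape = list(shapes[0])
    out_shape[axis] = sum(s[axis] for s in shapes)
    return out, tuple(out_shape)
```

With `a = [1, 2, 3, 4, 5, 6]` of shape `(2, 3)` and `b = [7, 8, 9, 10]` of shape `(2, 2)`, concatenating along axis 1 interleaves rows and yields `[1, 2, 3, 7, 8, 4, 5, 6, 9, 10]` with shape `(2, 5)`; along axis 0 the buffers are simply appended.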

concat_vjp

concat_vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the concat operation.
Args
  • primals: List[Array] A list containing the primal input arrays.

  • grad: Array The gradient of the output with respect to some scalar function.

  • out: Array The output of the forward pass.

Returns
  • List[Array] - A list containing the gradients with respect to the input arrays.

Note: The vector-Jacobian product for concat is computed by returning an empty list.
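For reference, the mathematical VJP of concatenation is the inverse of the forward data movement: the incoming gradient is split along the concatenation axis into slices whose sizes match each input. The note above states that concat_vjp returns an empty list, so the sketch below is not what that function does; it is a hypothetical Python helper (`concat_vjp_split`) showing the split, assuming flat row-major gradients and shape tuples.

```python
from math import prod


def concat_vjp_split(grad, grad_shape, input_shapes, axis):
    """Hypothetical helper: split a flat row-major gradient into per-input gradients."""
    ndim = len(grad_shape)
    if axis < 0:
        axis += ndim
    # Blocks before the axis, and elements per step along the axis.
    outer = prod(grad_shape[:axis])
    inner = prod(grad_shape[axis + 1:])
    row = grad_shape[axis] * inner  # elements of `grad` in each outer block
    grads = []
    offset = 0  # where the current input's slice starts inside each block
    for shape in input_shapes:
        chunk = shape[axis] * inner
        g = []
        for o in range(outer):
            start = o * row + offset
            g.extend(grad[start:start + chunk])
        grads.append(g)
        offset += chunk
    return grads
```

Splitting a `(2, 5)` gradient `[1, 2, 3, 7, 8, 4, 5, 6, 9, 10]` for inputs of shapes `(2, 3)` and `(2, 2)` along axis 1 recovers `[1, 2, 3, 4, 5, 6]` and `[7, 8, 9, 10]`.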

concat

concat(args: List[Array], axis: Int) -> Array
Concatenates the input arrays along the given axis.
Args
  • args: List[Array] The arrays to concatenate.

  • axis: Int The axis along which to concatenate.

Returns
  • Array - The concatenated array.
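The semantics of concat can be sketched on nested Python lists: along axis 0 the inputs are joined end to end, and along a deeper axis the operation recurses element-wise over the leading dimension, which must therefore match. This is a hypothetical illustration (`concat_nested`), assuming equally deep nested lists and a non-negative axis; it does not use the library's Array type.

```python
def concat_nested(arrays, axis):
    """Hypothetical helper: concatenate equally nested lists along a non-negative `axis`."""
    if axis == 0:
        # Along the outermost axis, concatenation is plain list joining.
        out = []
        for a in arrays:
            out.extend(a)
        return out
    # Otherwise the leading dimension must match, and we recurse into each position.
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("non-concatenation dimensions must match")
    return [concat_nested([a[i] for a in arrays], axis - 1) for i in range(n)]
```

For instance, `concat_nested([[[1, 2], [3, 4]], [[5], [6]]], 1)` gives `[[1, 2, 5], [3, 4, 6]]`, and the same two-row matrix concatenated with `[[5, 6]]` along axis 0 gives `[[1, 2], [3, 4], [5, 6]]`.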