
reduce_argmax_op


Structs

Struct: ReduceArgMax

Fields

Methods

compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of an array after reducing along a specific axis.
Args
  • curr: ArrayShape The ArrayShape to store the result of the computation.

  • args: List[ArrayShape] The ArrayShape to reduce, and the axis to reduce along encoded in an ArrayShape.

Constraints:

  • The axis must be a valid axis of the ArrayShape (args[0]).
  • The number of axes must not exceed the number of dimensions of the ArrayShape (args[0]).
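
The shape rule and the two constraints above can be illustrated with a small reference sketch. This is plain Python, not the library's implementation: `reduced_shape` is a hypothetical helper, and the `keepdims` flag is borrowed from the `reduce_argmax` function documented further down.

```python
def reduced_shape(shape, axes, keepdims=False):
    """Return the shape left after reducing `shape` along `axes`.

    Hypothetical helper for illustration only; it mirrors the documented
    constraints rather than the library's actual code.
    """
    ndim = len(shape)
    # Constraint: no more axes than the array has dimensions.
    if len(axes) > ndim:
        raise ValueError("more axes than dimensions")
    # Constraint: every axis must be a valid axis of the input shape.
    for ax in axes:
        if not 0 <= ax < ndim:
            raise ValueError(f"axis {ax} is out of range for {ndim} dimensions")
    reduced = set(axes)
    if keepdims:
        # Reduced dimensions stay in place with length 1.
        return [1 if i in reduced else d for i, d in enumerate(shape)]
    # Reduced dimensions are dropped.
    return [d for i, d in enumerate(shape) if i not in reduced]


print(reduced_shape([2, 3, 4], [1]))                 # [2, 4]
print(reduced_shape([2, 3, 4], [1], keepdims=True))  # [2, 1, 4]
```

Whether the library accepts negative axes is not stated in this page, so the sketch rejects them.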
__call__(mut curr: Array, args: List[Array])
Performs the forward pass of the argmax reduction of an array along the given axis.
Args
  • curr: Array The current array to store the result (modified in-place).

  • args: List[Array] A list containing the input arrays.

Computes the indices of the maximum values of the input array along the given axis and stores the result in the current array. Initializes the current array if it is not already set up.

Note: This function assumes that the shape and data of the args are already set up. If the current array (curr) is not initialized, it computes the shape based on the input array and the axis and sets up the data accordingly.

jvp(primals: List[Array], tangents: List[Array]) -> Array
Args
  • primals: List[Array]

  • tangents: List[Array]

Returns
  • Array
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Args
  • primals: List[Array]

  • grad: Array

  • out: Array

Returns
  • List[Array]
fwd(arg0: Array, axis: List[Int]) -> Array
Reduces the input array along the specified axis by taking the index of the maximum element.
Args
  • arg0: Array The input array.

  • axis: List[Int] The axis along which to reduce the array.

Returns
  • Array - An array containing the indices of the maximum values of the input array along the specified axis.

Examples:

a = Array([[1, 2], [3, 4]])
result = reduce_argmax(a, List(0))
print(result)

Note: This function supports:

  • Automatic differentiation (forward and reverse modes).
  • Complex-valued arguments.
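
To make the expected result of the example concrete, here is a minimal pure-Python sketch of an argmax reduction over a 2-D nested list. It is not the library's code: `argmax_2d` is a hypothetical name, only a single axis is handled, and ties are assumed to resolve to the first (lowest) index, which this page does not specify.

```python
def argmax_2d(rows, axis):
    """Indices of the maximum along `axis` (0 or 1) of a rectangular 2-D list.

    Hypothetical reference helper; tie-breaking to the first index is an
    assumption, not documented behavior of the library.
    """
    if axis == 1:
        lines = rows                                # reduce within each row
    elif axis == 0:
        lines = [list(col) for col in zip(*rows)]   # reduce within each column
    else:
        raise ValueError("axis must be 0 or 1")
    # max() returns the first maximal item, so ties pick the lowest index.
    return [max(range(len(line)), key=line.__getitem__) for line in lines]


a = [[1, 2], [3, 4]]
print(argmax_2d(a, 0))  # [1, 1]
print(argmax_2d(a, 1))  # [1, 1]
```

For the input used in the example, both columns have their maximum in row 1, so reducing along axis 0 gives index 1 for each column.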

Functions

reduce_argmax

reduce_argmax(arg0: Array, axis: List[Int], keepdims: Bool = False) -> Array
Reduces the input array along the specified axis by taking the index of the maximum element.
Args
  • arg0: Array The input array.

  • axis: List[Int] The axis along which to reduce the array.

  • keepdims: Bool (default: False) If True, retains the reduced dimensions with length 1.

Returns
  • Array - An array containing the indices of the maximum values of the input array along the specified axis.

Examples:

a = Array([[1, 2], [3, 4]])
result = reduce_argmax(a, List(0))
print(result)

Note: This function supports:

  • Automatic differentiation (forward and reverse modes).
  • Complex-valued arguments.
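
The effect of `keepdims` on the result layout can be sketched in plain Python as follows. This is an illustration under assumptions, not the library's implementation: `argmax_2d` is a hypothetical helper restricted to one axis of a 2-D nested list, with ties assumed to resolve to the first index.

```python
def argmax_2d(rows, axis, keepdims=False):
    """Argmax along `axis` (0 or 1) of a rectangular 2-D list.

    Hypothetical helper for illustration. With keepdims=True the reduced
    dimension is retained with length 1.
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1")
    # For axis 0, transpose so that each column becomes a line to reduce.
    lines = rows if axis == 1 else [list(col) for col in zip(*rows)]
    idx = [max(range(len(line)), key=line.__getitem__) for line in lines]
    if not keepdims:
        return idx                       # reduced dimension dropped
    if axis == 1:
        return [[i] for i in idx]        # shape (rows, 1)
    return [idx]                         # shape (1, cols)


a = [[1, 2], [3, 4]]
print(argmax_2d(a, 0))                 # [1, 1]
print(argmax_2d(a, 0, keepdims=True))  # [[1, 1]]
print(argmax_2d(a, 1, keepdims=True))  # [[1], [1]]
```

Keeping the reduced dimension makes the result broadcast-compatible with the input, which is the usual reason for offering such a flag.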