squeeze_op

Structs

Struct: Squeeze

Fields

Methods

squeezable_axis(mut curr: ArrayShape, args: List[ArrayShape])
Args
  • curr: ArrayShape

  • args: List[ArrayShape]

compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of an array after squeezing. This removes all dimensions of size 1.
Args
  • curr: ArrayShape The ArrayShape to store the result of the computation.

  • args: List[ArrayShape] A list containing the ArrayShape to squeeze.
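
The shape rule described above can be illustrated with a minimal Python sketch. It is not the library's own code: shapes are modelled as plain tuples, and `squeeze_shape` and `numel` are hypothetical helper names. The sketch drops every size-1 dimension and checks that the element count is unchanged.

```python
from typing import Tuple

def squeeze_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return `shape` with every dimension of size 1 removed (hypothetical helper)."""
    return tuple(d for d in shape if d != 1)

def numel(shape: Tuple[int, ...]) -> int:
    """Number of elements described by `shape` (1 for the empty shape)."""
    n = 1
    for d in shape:
        n *= d
    return n

in_shape = (1, 3, 1, 2)
out_shape = squeeze_shape(in_shape)
print(out_shape)                            # (3, 2)
print(numel(in_shape) == numel(out_shape))  # True: squeezing never changes the element count
```

Because only size-1 dimensions are removed, the product of the remaining dimensions is the same as before.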

__call__(mut curr: Array, args: List[Array])
Performs the forward pass for the squeeze operation. It sets the base of the current array to the base of the argument, so the result is a view onto the argument's data rather than a copy. It then computes the shape of the current array via its dedicated ArrayShape fwd function.
Args
  • curr: Array The current array to store the result (modified in-place).

  • args: List[Array] The array on which the squeeze view is created.

Note: The result of the shape computation is stored in the ArrayShape object of the curr array.
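
Why squeeze can be a view: an index along a size-1 axis is always 0, so that axis's stride never contributes to the memory offset. Dropping the (size, stride) pairs of size-1 axes therefore describes the same elements over the same storage. The following Python sketch illustrates this under stated assumptions. `View`, `contiguous_strides` and `squeeze_view` are hypothetical names, and a flat Python list stands in for the array's base; this is not the library's Array or ArrayShape API.

```python
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class View:
    base: List[float]           # shared storage (hypothetical stand-in for the array's base)
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]    # element strides into `base`
    offset: int = 0

    def get(self, *index: int) -> float:
        # Memory position = offset + sum(index_i * stride_i)
        pos = self.offset + sum(i * s for i, s in zip(index, self.strides))
        return self.base[pos]

def contiguous_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # Row-major strides: the last axis has stride 1.
    strides, acc = [], 1
    for d in reversed(shape):
        strides.append(acc)
        acc *= d
    return tuple(reversed(strides))

def squeeze_view(arg: View) -> View:
    # Keep only (size, stride) pairs whose size is not 1; the storage is shared, not copied.
    kept = [(d, s) for d, s in zip(arg.shape, arg.strides) if d != 1]
    return View(base=arg.base,
                shape=tuple(d for d, _ in kept),
                strides=tuple(s for _, s in kept),
                offset=arg.offset)

x = View([float(i) for i in range(6)], (1, 3, 1, 2), contiguous_strides((1, 3, 1, 2)))
y = squeeze_view(x)
print(y.shape, y.strides)  # (3, 2) (2, 1)
x.base[5] = 100.0          # a write through the original storage...
print(y.get(2, 1))         # 100.0  ...is visible through the squeezed view
```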

jvp(primals: List[Array], tangents: List[Array]) -> Array
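Computes the Jacobian-vector product for the squeeze operation.
Args
  • primals: List[Array] A list containing the primal input array.

  • tangents: List[Array] A list containing the tangent of the input array.

Returns
  • Array - The tangent of the output.
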
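
Squeeze is a linear map that only relabels the shape. Its Jacobian-vector product is therefore the same squeeze applied to the input tangent. The following Python sketch assumes a hypothetical dense representation `(row-major data, shape)` named `DenseArray`, with hypothetical functions `squeeze` and `squeeze_jvp`; it does not show the library's implementation. With contiguous row-major data, removing size-1 axes leaves the element order unchanged, so only the shape changes.

```python
from typing import List, Tuple

DenseArray = Tuple[List[float], Tuple[int, ...]]  # hypothetical: (row-major data, shape)

def squeeze(x: DenseArray) -> DenseArray:
    data, shape = x
    # Removing size-1 axes does not reorder row-major elements; only the shape changes.
    return list(data), tuple(d for d in shape if d != 1)

def squeeze_jvp(primals: List[DenseArray], tangents: List[DenseArray]) -> DenseArray:
    # The tangent must have the same shape as the primal input.
    assert primals[0][1] == tangents[0][1]
    # squeeze is linear, so J @ t == squeeze(t).
    return squeeze(tangents[0])

x = ([1.0, 2.0, 3.0], (1, 3, 1))
t = ([0.1, 0.2, 0.3], (1, 3, 1))
print(squeeze_jvp([x], [t]))  # ([0.1, 0.2, 0.3], (3,))
```
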
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the squeeze operation.
Args
  • primals: List[Array] A list containing the primal input array.

  • grad: Array The gradient of the output with respect to some scalar function.

  • out: Array The output of the forward pass (unused in this function).

Returns
  • List[Array] - A list containing the gradient with respect to the input.

Note: The vector-Jacobian product for squeeze is computed by unsqueezing the gradient along the axes that were squeezed.
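
The unsqueeze step in the note above can be sketched in Python. The sketch assumes that the squeezed axes are exactly the size-1 axes of the primal input's shape, and it reuses the hypothetical `(row-major data, shape)` representation. `squeezed_axes`, `unsqueeze_shape` and `squeeze_vjp` are illustrative names, not the library's functions. As in the documented signature, `out` is accepted but unused.

```python
from typing import List, Tuple

def squeezed_axes(in_shape: Tuple[int, ...]) -> List[int]:
    # The size-1 axes of the input are exactly the ones squeeze removed.
    return [i for i, d in enumerate(in_shape) if d == 1]

def unsqueeze_shape(shape: Tuple[int, ...], axes: List[int]) -> Tuple[int, ...]:
    # Re-insert a size-1 axis at each position; ascending order keeps the indices valid.
    out = list(shape)
    for ax in sorted(axes):
        out.insert(ax, 1)
    return tuple(out)

def squeeze_vjp(primals, grad, out):
    # `out` is unused, mirroring the documented vjp.
    in_shape = primals[0][1]
    data, shape = grad
    # Unsqueezing only changes the shape; the row-major gradient data is reused unchanged.
    return [(list(data), unsqueeze_shape(shape, squeezed_axes(in_shape)))]

x = ([0.0] * 6, (1, 3, 1, 2))
g = ([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], (3, 2))
print(squeeze_vjp([x], g, None)[0][1])  # (1, 3, 1, 2)
```

The returned gradient has the input's shape again, which is what the input gradient of any operation must satisfy.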

fwd(arg0: Array) -> Array
Squeezes the input array by removing axes of length 1.
Args
  • arg0: Array The input array.
Returns
  • Array - The squeezed array.

Functions

squeeze

squeeze(arg0: Array) -> Array
Squeezes the input array by removing axes of length 1.
Args
  • arg0: Array The input array.
Returns
  • Array - The squeezed array.