permute_op
Structs
Struct: Permute
Fields
Methods
compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Permutes the dimensions of an array shape given a list of axes.
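To illustrate the shape rule, here is a minimal plain-Python sketch. It assumes the usual convention (as in NumPy and PyTorch) that output dimension i takes the size of input dimension axes[i], and that negative axes count from the end. The helpers `normalize_axes` and `permuted_shape` are hypothetical and are not part of this library.

```python
# Hypothetical helpers illustrating a permute shape rule; not this library's API.

def normalize_axes(axes, ndim):
    """Map negative axis indices into [0, ndim) and check that axes is a permutation."""
    norm = [a + ndim if a < 0 else a for a in axes]
    if sorted(norm) != list(range(ndim)):
        raise ValueError(f"{list(axes)} is not a permutation of {ndim} axes")
    return norm


def permuted_shape(shape, axes):
    """Assumed convention: output dimension i has the size of input dimension axes[i]."""
    axes = normalize_axes(axes, len(shape))
    return [shape[a] for a in axes]


print(permuted_shape([2, 3, 4], [2, 0, 1]))  # [4, 2, 3]
print(permuted_shape([2, 3], [-1, -2]))      # [3, 2]
```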
__call__(mut curr: Array, args: List[Array])
Permutes the input array according to the given axes and stores the result in the current array (curr). The first argument is set as the base of the current array.
jvp(primals: List[Array], tangents: List[Array]) -> Array
Compute Jacobian-vector product for array permutation.
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Compute vector-Jacobian product for array permutation.
Args
- primals (List[Array]): Primal input arrays.
- grad (Array): Gradient of the output with respect to some scalar function.
- out (Array): The output of the forward pass.
Returns
- List[Array]: Gradients with respect to each input.
Note: Implements reverse-mode automatic differentiation for permutation. Returns arrays with shape zero for inputs that do not require gradients.
See Also:
permute_jvp: Forward-mode autodiff for permutation.
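Permutation is linear, so its JVP is the same permutation applied to the tangent, and its VJP applies the transpose of that linear map, which for a permutation is its inverse. The plain-Python sketch below checks this with the identity dot(permute(x), g) == dot(x, permute_inv(g)) on row-major flat buffers. The helpers `permute_flat` and `permute_inv_flat` are hypothetical stand-ins, not this library's `permute_jvp` or `vjp`, and they assume output dimension i is input dimension axes[i].

```python
from itertools import product

# Hypothetical dense helpers on flat row-major lists; not this library's API.


def row_major_strides(shape):
    """Element strides of a contiguous row-major buffer with the given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.insert(0, step)
        step *= dim
    return strides


def permute_flat(data, shape, axes):
    """Dense permute: output dimension i is input dimension axes[i]. Returns (data, shape)."""
    n = len(shape)
    axes = [a + n if a < 0 else a for a in axes]
    in_strides = row_major_strides(shape)
    out_shape = [shape[a] for a in axes]
    out = []
    for idx in product(*(range(d) for d in out_shape)):  # row-major walk of the output
        out.append(data[sum(idx[i] * in_strides[axes[i]] for i in range(n))])
    return out, out_shape


def permute_inv_flat(data, shape, axes):
    """Undo permute_flat by applying the inverse permutation q, where q[axes[i]] = i."""
    n = len(shape)
    axes = [a + n if a < 0 else a for a in axes]
    inverse = [0] * n
    for i, a in enumerate(axes):
        inverse[a] = i
    return permute_flat(data, shape, inverse)


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


shape, axes = [2, 3, 4], [2, 0, 1]
x = [float(i) for i in range(24)]
y, y_shape = permute_flat(x, shape, axes)                # forward pass (also the JVP of tangent x)
g = [float((7 * i) % 5) for i in range(24)]              # arbitrary output cotangent
grad_x, grad_shape = permute_inv_flat(g, y_shape, axes)  # VJP: inverse permutation of g

print(grad_shape == shape)           # True
print(dot(y, g) == dot(x, grad_x))   # True
```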
fwd(arg0: Array, axis: ArrayShape) -> Array
Creates a view of the input array with its dimensions permuted according to the given axes.
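A permuted view needs no data movement: it can reuse the input's buffer and reorder only its shape and strides. The minimal plain-Python sketch below illustrates that idea under this assumption; the `View` class and the `contiguous` and `permute_view` functions are hypothetical and do not describe how this library actually stores arrays.

```python
from dataclasses import dataclass

# Hypothetical strided-view model; not this library's Array type.


@dataclass
class View:
    data: list      # flat buffer, shared with the base array
    shape: list
    strides: list   # element step per dimension

    def __getitem__(self, idx):
        return self.data[sum(i * s for i, s in zip(idx, self.strides))]


def contiguous(data, shape):
    """Wrap a flat row-major buffer as a view."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.insert(0, step)
        step *= dim
    return View(data, list(shape), strides)


def permute_view(v, axes):
    """Reorder shape and strides only; the buffer is shared, not copied."""
    n = len(v.shape)
    axes = [a + n if a < 0 else a for a in axes]
    return View(v.data, [v.shape[a] for a in axes], [v.strides[a] for a in axes])


a = contiguous([1, 2, 3, 4, 5, 6], [2, 3])
t = permute_view(a, [1, 0])
print(t.shape, t.strides)  # [3, 2] [1, 3]
print(t.data is a.data)    # True: no copy was made
print([[t[i, j] for j in range(t.shape[1])] for i in range(t.shape[0])])  # [[1, 4], [2, 5], [3, 6]]
```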
Struct: InvPermute
Fields
Methods
compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Permutes the dimensions of an array shape by the inverse of the given list of axes, undoing the permute_shape function.
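Assuming permute uses the convention that output dimension i is input dimension axes[i], the inverse sends input dimension i to output position axes[i], which is the same as permuting by the inverse permutation. A minimal plain-Python sketch; `inverse_axes` and `inverse_permuted_shape` are hypothetical helpers, not part of this library.

```python
# Hypothetical helpers illustrating an inverse-permute shape rule.


def inverse_axes(axes):
    """Return q with q[axes[i]] = i, after normalizing negative axes."""
    n = len(axes)
    axes = [a + n if a < 0 else a for a in axes]
    q = [0] * n
    for i, a in enumerate(axes):
        q[a] = i
    return q


def inverse_permuted_shape(shape, axes):
    """Input dimension i lands at output position axes[i]."""
    n = len(shape)
    out = [0] * n
    for i, a in enumerate(axes):
        out[a + n if a < 0 else a] = shape[i]
    return out


shape, axes = [4, 2, 3], [2, 0, 1]
print(inverse_axes(axes))                   # [1, 2, 0]
print(inverse_permuted_shape(shape, axes))  # [2, 3, 4]
```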
__call__(mut curr: Array, args: List[Array])
Applies the inverse of the permutation given by the axes to the input array and stores the result in the current array (curr). The first argument is set as the base of the current array.
jvp(primals: List[Array], tangents: List[Array]) -> Array
Compute Jacobian-vector product for inverse array permutation.
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Compute vector-Jacobian product for inverse array permutation.
Args
- primals (List[Array]): Primal input arrays.
- grad (Array): Gradient of the output with respect to some scalar function.
- out (Array): The output of the forward pass.
Returns
- List[Array]: Gradients with respect to each input.
Note: Implements reverse-mode automatic differentiation for permutation. Returns arrays with shape zero for inputs that do not require gradients.
See Also:
permute_inv_jvp: Forward-mode autodiff for inverse permutation.
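For the inverse permutation the roles flip: its transpose is the forward permutation with the same axes, so its VJP can be computed by permuting the incoming gradient. The plain-Python sketch below checks this with the identity dot(permute_inv(x), g) == dot(x, permute(g)). Here `permute_inv_flat` is written as a scatter and `permute_flat` as a gather; both are hypothetical helpers that assume non-negative axes for brevity and the convention that output dimension i of permute is input dimension axes[i].

```python
from itertools import product

# Hypothetical dense helpers on flat row-major lists (non-negative axes only).


def strides_of(shape):
    strides, step = [], 1
    for dim in reversed(shape):
        strides.insert(0, step)
        step *= dim
    return strides


def permute_flat(data, shape, axes):
    """Gather: output dimension i is input dimension axes[i]."""
    st = strides_of(shape)
    out_shape = [shape[a] for a in axes]
    out = [data[sum(i * st[a] for i, a in zip(idx, axes))]
           for idx in product(*(range(d) for d in out_shape))]
    return out, out_shape


def permute_inv_flat(data, shape, axes):
    """Scatter: input dimension i is written to output position axes[i]."""
    out_shape = [0] * len(shape)
    for i, a in enumerate(axes):
        out_shape[a] = shape[i]
    out_st = strides_of(out_shape)
    out = [0.0] * len(data)
    for flat, idx in enumerate(product(*(range(d) for d in shape))):
        out[sum(i * out_st[a] for i, a in zip(idx, axes))] = data[flat]
    return out, out_shape


shape, axes = [2, 3, 4], [2, 0, 1]
x = [float(i) for i in range(24)]
y, y_shape = permute_inv_flat(x, shape, axes)         # forward pass of the inverse permutation
g = [float((5 * i) % 7) for i in range(24)]           # arbitrary output cotangent
grad_x, grad_shape = permute_flat(g, y_shape, axes)   # VJP: forward permute with the same axes

print(grad_shape == shape)                                                         # True
print(sum(a * b for a, b in zip(y, g)) == sum(a * b for a, b in zip(x, grad_x)))  # True
```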
fwd(arg0: Array, axis: ArrayShape) -> Array
Creates a view of the input array with its dimensions permuted by the inverse of the given axes, undoing a permute with the same axes.
Args
- arg0 (Array): The input array.
- axis (ArrayShape): The axis permutation to invert.
Returns
- Array: A view of the input array with its dimensions permuted.
Examples:
a = Array([[1, 2], [3, 4]])
result = permute_inv(a, axis=List(-1,-2))
print(result)
This function supports:
- Automatic differentiation (forward and reverse modes).
- Complex-valued arguments.
Functions
permute
permute(arg0: Array, axis: ArrayShape) -> Array
Creates a view of the input array with its dimensions permuted according to the given axes.
Args
- arg0 (Array): The input array.
- axis (ArrayShape): The axes to permute, given as a permutation of the array's dimensions.
Returns
- Array: A view of the input array with its dimensions permuted.
Examples:
a = Array([[1, 2], [3, 4]])
result = permute(a, axis=List(-1,-2))
print(result)
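To make the example concrete, here is a plain-Python sketch on nested lists, assuming the common convention that output dimension i walks input dimension axis[i] and that negative entries count from the end. With axis=List(-1,-2) on a 2-D input this is an ordinary matrix transpose, so the values in result are [[1, 3], [2, 4]]; how Array prints them is not reproduced here. `permute_nested` and its helpers are hypothetical.

```python
# Hypothetical nested-list permute; illustrates the assumed semantics only.


def shape_of(x):
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return shape


def get(x, idx):
    for i in idx:
        x = x[i]
    return x


def permute_nested(x, axes):
    shape = shape_of(x)
    n = len(shape)
    axes = [a + n if a < 0 else a for a in axes]
    out_shape = [shape[a] for a in axes]

    def build(prefix):
        if len(prefix) == n:
            src = [0] * n
            for i, a in enumerate(axes):
                src[a] = prefix[i]  # output index i walks input axis axes[i]
            return get(x, src)
        return [build(prefix + [k]) for k in range(out_shape[len(prefix)])]

    return build([])


a = [[1, 2], [3, 4]]
print(permute_nested(a, [-1, -2]))  # [[1, 3], [2, 4]]
```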
This function supports:
- Automatic differentiation (forward and reverse modes).
- Complex-valued arguments.
transpose
transpose(arg0: Array, axis1: Int, axis2: Int) -> Array
Transposes the input array by swapping axes axis1 and axis2.
swapaxes
swapaxes(arg0: Array, axis1: Int, axis2: Int) -> Array
Swaps axes axis1 and axis2 of the input array.
swapdims
swapdims(arg0: Array, axis1: Int, axis2: Int) -> Array
Swaps dimensions axis1 and axis2 of the input array.
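transpose, swapaxes, and swapdims all take two axis indices. Assuming each exchanges exactly those two dimensions and leaves the rest in place, as the signatures suggest, the result has the same layout as a permute whose axis list is the identity with two entries swapped. The plain-Python helper `swap_permutation` below is hypothetical and only shows how such a list is formed; a single swap is its own inverse, so permute and permute_inv agree on it.

```python
# Hypothetical helpers; they only build axis lists and are not this library's API.


def swap_permutation(ndim, axis1, axis2):
    """Identity axis order with axis1 and axis2 exchanged (negative indices allowed)."""
    axes = list(range(ndim))
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    return axes


def inverse(axes):
    """Return q with q[axes[i]] = i."""
    q = [0] * len(axes)
    for i, a in enumerate(axes):
        q[a] = i
    return q


p = swap_permutation(4, 1, -1)
print(p)                # [0, 3, 2, 1]
print(inverse(p) == p)  # True: a single swap undoes itself
```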
permute_inv
permute_inv(arg0: Array, axis: ArrayShape) -> Array
Creates a view of the input array with its dimensions permuted by the inverse of the given axes, undoing permute(arg0, axis).
Args
- arg0 (Array): The input array.
- axis (ArrayShape): The axis permutation to invert.
Returns
- Array: A view of the input array with its dimensions permuted.
Examples:
a = Array([[1, 2], [3, 4]])
result = permute_inv(a, axis=List(-1,-2))
print(result)
This function supports:
- Automatic differentiation (forward and reverse modes).
- Complex-valued arguments.