as_strided_op
Structs
Struct: AsStrided
Fields
Methods
compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of an array after striding.
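The output shape of a strided view is simply the requested `shape`. What the stride and storage offset add is the range of storage positions the view reaches. The plain-Python sketch below computes that range and rejects views that fall outside the storage. `strided_extent` and `check_view_fits` are hypothetical helpers, not Endia functions. Whether Endia's `compute_shape` performs such a check is not stated in this reference and is an assumption.

```python
from math import prod

def strided_extent(shape, stride, storage_offset):
    """Return (numel, lowest, highest) storage positions a strided view touches.

    Negative strides pull the lowest position down; positive ones push the
    highest position up. An empty view touches nothing.
    """
    numel = prod(shape)
    if numel == 0:
        return 0, None, None
    lowest = storage_offset + sum((n - 1) * s for n, s in zip(shape, stride) if s < 0)
    highest = storage_offset + sum((n - 1) * s for n, s in zip(shape, stride) if s > 0)
    return numel, lowest, highest

def check_view_fits(storage_size, shape, stride, storage_offset):
    """Hypothetical validation: return the output shape if every element of the
    view addresses a valid position in storage of length `storage_size`."""
    if len(shape) != len(stride):
        raise ValueError("shape and stride must have the same length")
    numel, lowest, highest = strided_extent(shape, stride, storage_offset)
    if numel > 0 and (lowest < 0 or highest >= storage_size):
        raise ValueError("view reaches outside the underlying storage")
    return list(shape)
```

For instance, a `[3, 4]` view with strides `[2, 1]` and offset 0 reaches positions 0 through 7, so it fits in storage of length 10.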
__call__(mut curr: Array, args: List[Array])
Performs the forward pass for the as_strided operation. It sets the base of the current array to the base of the argument, so the result shares the argument's underlying storage, and computes the shape of the current array via its dedicated ArrayShape fwd function.
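Sharing the base is what makes the result a view rather than a copy: both arrays index the same storage, so a write through one is visible through the other. The toy sketch below illustrates this. `StridedArray` and this `as_strided` are hypothetical stand-ins, not Endia's `Array` API. Treating `storage_offset` as absolute with respect to the base (as PyTorch does) is also an assumption.

```python
class StridedArray:
    """Toy array: a shape/stride/offset window over a shared flat `base` list."""

    def __init__(self, base, shape, stride, storage_offset=0):
        self.base = base  # shared storage; views alias it and never copy it
        self.shape = list(shape)
        self.stride = list(stride)
        self.storage_offset = storage_offset

    def _pos(self, index):
        # Flat position of a multi-index: offset + sum(i_k * stride_k)
        return self.storage_offset + sum(i * s for i, s in zip(index, self.stride))

    def __getitem__(self, index):
        return self.base[self._pos(index)]

    def __setitem__(self, index, value):
        self.base[self._pos(index)] = value

def as_strided(arg, shape, stride, storage_offset):
    # The result reuses arg.base: no data is copied, only the indexing changes.
    # Assumption: storage_offset is measured from the start of the base.
    return StridedArray(arg.base, shape, stride, storage_offset)
```

Usage: for `a = StridedArray(list(range(6)), [2, 3], [3, 1])`, the call `t = as_strided(a, [3, 2], [1, 3], 0)` is the transpose of `a`. Assigning `t[2, 0] = 99` also changes `a[0, 2]`, because both address storage position 2.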
jvp(primals: List[Array], tangents: List[Array]) -> Array
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the as_strided operation.
Args
- primals (List[Array]): A list containing the primal input array.
- grad (Array): The gradient of some scalar function with respect to the output.
- out (Array): The output of the forward pass.
Returns
- List[Array]: A list containing the gradient with respect to the input.
Note: The vector-Jacobian product for as_strided is computed by calling the inverse operation as_strided_inv.
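Because as_strided only reads input elements, its VJP sends each incoming gradient entry back to the input position it was read from. Where positions overlap, as with sliding windows, those contributions add up. The sketch below shows this on flat Python lists. `strided_positions` and `as_strided_vjp` are hypothetical names, and the assumption is that as_strided_inv accumulates overlapping positions in this way.

```python
from itertools import product

def strided_positions(shape, stride, storage_offset):
    """Storage positions visited by a strided view, in row-major order of the view."""
    return [storage_offset + sum(i * s for i, s in zip(index, stride))
            for index in product(*(range(n) for n in shape))]

def as_strided_vjp(grad, input_size, shape, stride, storage_offset):
    """Hypothetical VJP of as_strided for a flat input of length `input_size`.

    `grad` holds dL/d(out) in row-major order of the view. Each output element
    read one input position, so its gradient is added back at that position;
    positions read by several output elements accumulate.
    """
    grad_input = [0.0] * input_size
    for g, pos in zip(grad, strided_positions(shape, stride, storage_offset)):
        grad_input[pos] += g
    return grad_input
```

With overlapping length-3 windows over a length-5 input (`shape=[3, 3]`, `stride=[1, 1]`), an all-ones gradient produces `[1.0, 2.0, 3.0, 2.0, 1.0]`. Each entry counts how many windows contain that position.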
fwd(arg0: Array, shape: List[Int], stride: List[Int], storage_offset: Int) -> Array
Creates a view of the input array with the given shape, stride, and storage offset.
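The index arithmetic behind a strided view fits in a few lines. Element $(i_0, \dots, i_{k-1})$ of the view lives at storage position $\text{storage\_offset} + \sum_j i_j \cdot \text{stride}_j$. The plain-Python sketch below works over a flat list rather than Endia's `Array`, and it gathers into a new list only so that the mapping is visible. A real view copies nothing.

```python
from itertools import product

def as_strided(storage, shape, stride, storage_offset):
    """Plain-Python sketch of a strided view over flat `storage`.

    Returns the view's elements in row-major order. Each element is read from
    storage_offset + i_0*stride_0 + ... + i_{k-1}*stride_{k-1}.
    """
    if len(shape) != len(stride):
        raise ValueError("shape and stride must have the same length")
    out = []
    for index in product(*(range(n) for n in shape)):
        pos = storage_offset + sum(i * s for i, s in zip(index, stride))
        if not 0 <= pos < len(storage):
            raise IndexError(f"view position {pos} is outside the storage")
        out.append(storage[pos])
    return out
```

For example, `as_strided(list(range(10)), [2, 3], [1, 2], 1)` reads positions 1, 3, 5 and then 2, 4, 6, returning `[1, 3, 5, 2, 4, 6]`.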
Struct: AsStridedInv
Fields
Methods
compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of the output array of the as_strided_inv operation, inverting the shape computation of the as_strided_shape function.
__call__(mut curr: Array, args: List[Array])
Performs the forward pass for the as_strided_inv operation. It sets the base of the argument to be the base of the current array and computes the shape of the current array via its dedicated ArrayShape fwd function.
jvp(primals: List[Array], tangents: List[Array]) -> Array
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
Computes the vector-Jacobian product for the as_strided_inv operation.
Args
- primals (List[Array]): A list containing the primal input array.
- grad (Array): The gradient of some scalar function with respect to the output.
- out (Array): The output of the forward pass.
Returns
- List[Array]: A list containing the gradient with respect to the input.
Note: The vector-Jacobian product for as_strided_inv is computed by calling the as_strided operation.
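The two operations can serve as each other's VJP because, for a linear map, the VJP is the transpose. Reading positions out of storage (gather) and adding values back at those positions (scatter-add) are transposes of each other. The dot-product test below checks the defining identity $\langle \mathrm{gather}(x), y \rangle = \langle x, \mathrm{scatter\_add}(y) \rangle$ numerically. All helper names are hypothetical, and modelling as_strided_inv as a scatter-add is an assumption.

```python
import random
from itertools import product

def strided_positions(shape, stride, storage_offset):
    return [storage_offset + sum(i * s for i, s in zip(index, stride))
            for index in product(*(range(n) for n in shape))]

def gather(x, positions):
    # Forward map of as_strided on flat storage: out[j] = x[positions[j]]
    return [x[p] for p in positions]

def scatter_add(y, positions, size):
    # Transposed map: add y[j] at positions[j] in a zero array of length `size`
    out = [0.0] * size
    for v, p in zip(y, positions):
        out[p] += v
    return out

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

def adjoint_gap(size, shape, stride, storage_offset, seed=0):
    """|<gather(x), y> - <x, scatter_add(y)>| for random x, y; ~0 means adjoint."""
    rng = random.Random(seed)
    positions = strided_positions(shape, stride, storage_offset)
    x = [rng.uniform(-1, 1) for _ in range(size)]
    y = [rng.uniform(-1, 1) for _ in range(len(positions))]
    return abs(dot(gather(x, positions), y) - dot(x, scatter_add(y, positions, size)))
```

The gap stays at floating-point rounding level even for overlapping views. Both sides expand to the same sum $\sum_j x_{p_j} y_j$, just grouped differently.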
fwd(arg0: Array, target_shape: ArrayShape, shape: List[Int], stride: List[Int], storage_offset: Int) -> Array
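Maps the input array, whose shape matches the strided view described by shape, stride, and storage_offset, back onto an array of shape target_shape. This is the inverse of as_strided.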
Functions
as_strided
as_strided(arg0: Array, shape: List[Int], stride: List[Int], storage_offset: Int) -> Array
Creates a view of the input array with the given shape, stride, and storage offset.
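Many familiar views are single shape/stride/offset triples: a transpose, overlapping sliding windows, and a row selected by offset. The examples below show each of these. They use a plain-Python stand-in over a flat row-major list that returns nested lists, not Endia's `as_strided`. Unlike a real view, the stand-in copies the data.

```python
def as_strided(storage, shape, stride, storage_offset=0):
    """Plain-Python stand-in: return the strided view as nested lists (a copy)."""
    def build(dim, pos):
        if dim == len(shape):
            return storage[pos]  # all indices fixed: read one element
        return [build(dim + 1, pos + i * stride[dim]) for i in range(shape[dim])]
    return build(0, storage_offset)

storage = [0, 1, 2, 3, 4, 5]  # a 2x3 row-major matrix has strides [3, 1]

transposed = as_strided(storage, [3, 2], [1, 3])    # swap shape and strides
windows = as_strided(storage, [4, 3], [1, 1])       # length-3 windows, step 1
second_row = as_strided(storage, [3], [1], 3)       # skip the first row via offset
```

The transpose swaps the shape and strides. The windows overlap because the outer stride (1) is smaller than the window length (3). The offset of 3 skips exactly one row of a 2x3 matrix.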
as_strided_inv
as_strided_inv(arg0: Array, target_shape: ArrayShape, shape: List[Int], stride: List[Int], storage_offset: Int) -> Array
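Maps the input array, whose shape matches the strided view described by shape, stride, and storage_offset, back onto an array of shape target_shape. This is the inverse of as_strided.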