max_pool2d_op
Structs
Struct: MaxPool2d
Namespace for 2D max pooling operations.
Fields
Methods
compute_shape(mut curr: ArrayShape, args: List[ArrayShape])
Computes the shape of an array after a 2-dimensional max pooling operation with dilation.
__call__(mut curr: Array, args: List[Array])
vjp(primals: List[Array], grad: Array, out: Array) -> List[Array]
jvp(primals: List[Array], tangents: List[Array]) -> Array
fwd(arg0: Array, kernel_size: Tuple[Int, Int], stride: Tuple[Int, Int] = (1, 1), padding: Tuple[Int, Int] = (0, 0), dilation: Tuple[Int, Int] = (1, 1)) -> Array
Args
- arg0 (Array)
- kernel_size (Tuple[Int, Int])
- stride (Tuple[Int, Int], default: (1, 1))
- padding (Tuple[Int, Int], default: (0, 0))
- dilation (Tuple[Int, Int], default: (1, 1))
Returns
Array
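The `compute_shape` description above mentions dilation but does not spell out the arithmetic. The sketch below uses the conventional floor-mode formula for each spatial axis: out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. It is written in plain Python, not Mojo, and rests on three assumptions: the last two axes are height and width, padding is added to both ends of each axis, and there is no ceil mode. None of these is confirmed behavior of `MaxPool2d.compute_shape`. The helper names `pool_output_size` and `max_pool2d_output_shape` are hypothetical.

```python
def pool_output_size(size: int, kernel: int, stride: int = 1,
                     padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial axis (floor mode).

    Assumption: `padding` is added to both ends of the axis.
    """
    # Span actually covered by a dilated window: dilation * (kernel - 1) + 1
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


def max_pool2d_output_shape(shape, kernel_size, stride=(1, 1),
                            padding=(0, 0), dilation=(1, 1)):
    """Replace the last two axes (assumed to be H, W) with their pooled sizes."""
    *leading, h, w = shape
    out_h = pool_output_size(h, kernel_size[0], stride[0], padding[0], dilation[0])
    out_w = pool_output_size(w, kernel_size[1], stride[1], padding[1], dilation[1])
    return (*leading, out_h, out_w)


# Batch of 8 single-channel 28x28 inputs, 2x2 window, stride 2
print(max_pool2d_output_shape((8, 1, 28, 28), (2, 2), stride=(2, 2)))  # (8, 1, 14, 14)
```

A dilation of 2 with a 3-wide kernel makes each window span 5 input positions. On a 28-wide axis with stride 1, this yields 24 output positions instead of 26.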
Functions
max_pool2d
max_pool2d(arg0: Array, kernel_size: Tuple[Int, Int], stride: Tuple[Int, Int] = (1, 1), padding: Tuple[Int, Int] = (0, 0), dilation: Tuple[Int, Int] = (1, 1)) -> Array
Applies a 2D max pooling operation over an input tensor.
Args
- arg0 (Array): The input tensor.
- kernel_size (Tuple[Int, Int]): The size of the pooling window over which the maximum is taken.
- stride (Tuple[Int, Int], default: (1, 1)): The step between successive window positions.
- padding (Tuple[Int, Int], default: (0, 0)): The padding applied to the input tensor before pooling.
- dilation (Tuple[Int, Int], default: (1, 1)): The spacing between elements inside the pooling window; (1, 1) means a contiguous window.
Returns
Array: The output tensor.
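To make the arguments concrete, here is a naive single-channel reference in plain Python. It is not the library's Mojo implementation. It shows how stride, padding, and dilation decide which input elements each window reads. It also sketches the standard vector-Jacobian product for max pooling, in which each output gradient flows back only to the input element that won the max. The sketch makes these assumptions:

- Padded positions act as -inf.
- Ties go to the first element scanned.
- The names `max_pool2d_naive` and `max_pool2d_naive_vjp` are hypothetical.

`MaxPool2d.vjp` may differ in details such as tie-breaking.

```python
NEG_INF = float("-inf")


def max_pool2d_naive(x, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
    """Max-pool one 2-D grid (list of lists).

    Returns the pooled grid plus, per output cell, the (row, col) of the input
    element that produced the max (None if the window covered only padding).
    Padded positions are never read, which is equivalent to treating them as -inf.
    """
    h, w = len(x), len(x[0])
    kh, kw = kernel_size
    sh, sw = stride
    ph, pw = padding
    dh, dw = dilation
    out_h = (h + 2 * ph - (dh * (kh - 1) + 1)) // sh + 1
    out_w = (w + 2 * pw - (dw * (kw - 1) + 1)) // sw + 1

    out = [[NEG_INF] * out_w for _ in range(out_h)]
    argmax = [[None] * out_w for _ in range(out_h)]
    for i in range(out_h):
        for j in range(out_w):
            for a in range(kh):
                for b in range(kw):
                    # Input coordinates of this window tap (may land in padding)
                    r = i * sh - ph + a * dh
                    c = j * sw - pw + b * dw
                    # Strict '>' keeps the first maximum found on ties
                    if 0 <= r < h and 0 <= c < w and x[r][c] > out[i][j]:
                        out[i][j] = x[r][c]
                        argmax[i][j] = (r, c)
    return out, argmax


def max_pool2d_naive_vjp(x_shape, argmax, grad):
    """Vector-Jacobian product: route each output gradient to its argmax input."""
    h, w = x_shape
    grad_x = [[0.0] * w for _ in range(h)]
    for i, row in enumerate(argmax):
        for j, pos in enumerate(row):
            if pos is not None:
                r, c = pos
                # Overlapping windows can share a winner, so accumulate
                grad_x[r][c] += grad[i][j]
    return grad_x


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
out, idx = max_pool2d_naive(x, (2, 2), stride=(2, 2))
print(out)  # [[6, 8], [14, 16]]
grad_x = max_pool2d_naive_vjp((4, 4), idx, [[1.0, 1.0], [1.0, 1.0]])
```

With `dilation=(2, 2)` and a 2x2 kernel, each window reads every other row and column. For example, the first window reads positions (0, 0), (0, 2), (2, 0), and (2, 2) of `x`.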