dmx.compressor.modeling.nn.torch_modules.MaxPool2d
- class dmx.compressor.modeling.nn.torch_modules.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
An extension of PyTorch’s MaxPool2d layer to support DmxModule configurations. This module applies a 2D max pooling over an input signal composed of several input planes.
- Parameters:
kernel_size (int or tuple) – Size of the window to take a max over.
stride (int or tuple, optional) – Stride of the window. Defaults to None, in which case kernel_size is used as the stride (as in PyTorch's MaxPool2d).
padding (int or tuple, optional) – Implicit padding added to both sides of the input; PyTorch's max pooling pads with negative infinity rather than zeros. Defaults to 0.
dilation (int or tuple, optional) – Spacing between kernel elements. Defaults to 1.
return_indices (bool, optional) – If True, will return the max indices in a second tensor. Defaults to False.
ceil_mode (bool, optional) – If True, will use ceil instead of floor to compute the output shape. Defaults to False.
- Attributes:
None specific to this subclass. Inherits attributes from parent classes.
- Methods:
_forward(_input: Tensor) -> Tensor: Computes the forward pass of the 2D max pooling.
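The way kernel_size, stride, padding, dilation and ceil_mode determine the output size can be sketched in plain Python. This is a minimal sketch of the formula documented for PyTorch's MaxPool2d, on the assumption that this subclass leaves the pooling arithmetic unchanged; the helper name pool_output_size is hypothetical and not part of dmx or PyTorch.

```python
def pool_output_size(size, kernel_size, stride=None, padding=0,
                     dilation=1, ceil_mode=False):
    """Output length along one spatial dimension (hypothetical helper).

    Follows the formula documented for torch.nn.MaxPool2d:
    out = round((size + 2*padding - dilation*(kernel_size-1) - 1) / stride) + 1
    where round is floor by default and ceil when ceil_mode is True.
    """
    if stride is None:
        # A stride of None means "use the kernel size".
        stride = kernel_size
    span = size + 2 * padding - dilation * (kernel_size - 1) - 1
    if ceil_mode:
        out = -(-span // stride) + 1  # integer ceiling division
        # PyTorch drops the last window if it would start beyond the
        # input plus its left padding.
        if (out - 1) * stride >= size + padding:
            out -= 1
    else:
        out = span // stride + 1
    return out


print(pool_output_size(32, 2))                             # 16
print(pool_output_size(8, 3, stride=2))                    # 3
print(pool_output_size(8, 3, stride=2, ceil_mode=True))    # 4
```

For a 2D input the same calculation is applied independently to the height and the width, using the matching element of each tuple-valued argument.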
- __init__(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False) -> None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
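To make the constructor arguments padding and return_indices concrete, here is a naive pure-Python reference for max pooling over a single plane. It is an illustrative sketch, not the dmx or PyTorch implementation: the function name max_pool2d_reference is hypothetical, it handles only integer arguments and floor mode, and it assumes PyTorch's convention that padded positions never win the max and that indices are flat positions within the input plane.

```python
def max_pool2d_reference(plane, kernel_size, stride=None, padding=0, dilation=1):
    """Naive max pooling over one 2D plane given as a list of rows.

    Returns (values, indices); each index is row * width + col of the
    selected element in the unpadded input plane.
    """
    stride = kernel_size if stride is None else stride
    height, width = len(plane), len(plane[0])
    extent = dilation * (kernel_size - 1) + 1  # footprint of a dilated window
    out_h = (height + 2 * padding - extent) // stride + 1
    out_w = (width + 2 * padding - extent) // stride + 1

    values, indices = [], []
    for i in range(out_h):
        row_values, row_indices = [], []
        for j in range(out_w):
            best, best_index = float("-inf"), -1
            for ki in range(kernel_size):
                for kj in range(kernel_size):
                    r = i * stride - padding + ki * dilation
                    c = j * stride - padding + kj * dilation
                    # Positions that fall in the padding are skipped, which
                    # is equivalent to padding with negative infinity.
                    if 0 <= r < height and 0 <= c < width and plane[r][c] > best:
                        best, best_index = plane[r][c], r * width + c
            row_values.append(best)
            row_indices.append(best_index)
        values.append(row_values)
        indices.append(row_indices)
    return values, indices


plane = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12],
         [13, 14, 15, 16]]
values, indices = max_pool2d_reference(plane, kernel_size=2)
print(values)   # [[6, 8], [14, 16]]
print(indices)  # [[5, 7], [13, 15]]
```

With an all-negative input and padding=1, every output stays negative, which is the observable difference from zero-padding.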
Methods
__init__(kernel_size[, stride, padding, ...]) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
add_module(name, module) – Add a child module to the current module.
align_device(_input, args, kwargs, _device)
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
approx_forward(inputs, *args, **kwargs)
approximator_wrapper(inputs, approx_args, ...) – Override this in the DMX modules to enable pre-processing of the inputs, and the SIMD approximator arguments before calling the SIMD reference kernels.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
calibrating_quantizers(hyperparams)
calibrating_smoothquant(hyperparams)
check_format_dim_consistency()
check_input_format_dim_consistency()
check_residual_format_dim_consistency()
check_sparseness_dim_consistency()
check_weight_format_dim_consistency()
children() – Return an iterator over immediate children modules.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
configure(config) – A function that changes the format of the ops and loading state dicts according to the config file.
count_flops(_input, _output)
counting_flops([zero])
cpu() – Move all model parameters and buffers to the CPU.
cuda([device]) – Move all model parameters and buffers to the GPU.
dmx_config([freeze]) – A function that returns the DmxModuleConfig object for the module.
double() – Casts all floating point parameters and buffers to double datatype.
enable_approximation_function_tuning(state, ...)
enable_flop_counter([state])
enable_optimal_brain_compression(state, ...)
enable_quantizer_calib(state, hyperparams)
enable_smoothquant_calib(state, hyperparams)
eval() – Set the module in evaluation mode.
extra_repr() – Return the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
fold_weight_and_bias() – A function that applies the ops to the weights and biases using the corresponding formats.
forward(input, *args, **kwargs) – Forward pass of the module with quantization ops applied.
from_raw(raw) – Creates a new MaxPool2d object (DmxModule) from a given PyTorch MaxPool2d layer.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – Casts all floating point parameters and buffers to half datatype.
infer_ch_axis()
init_casts()
init_smoothquant([migration_strength, ...])
init_sparsifier()
ipu([device]) – Move all model parameters and buffers to the IPU.
load_state_dict(state_dict, *[, strict, assign])
load_state_dict_and_register_url(url) – A function that loads a state dict from a url and sets url to self.state_dict_url.
measuring_runtime(_records, device) – Context manager for monitoring runtime of DmxModule.
modules() – Return an iterator over all modules in the network.
monitoring(_records) – Context manager for monitoring input/output to/from the DmxModule.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
optimal_brain_compressing(hyperparams)
parameters([recurse]) – Return an iterator over module parameters.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
save_state_dict_and_register_url(parent_dir) – A function that saves the current state dict of the module to a url under a specified parent directory.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module[, strict]) – Set the submodule given by target if it exists, otherwise throw an error.
share_memory() – See torch.Tensor.share_memory_().
slanc_tuning(hyperparams)
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
to(*args, **kwargs) – Move and/or cast the parameters and buffers.
to_compiler_graph() – Returns a compiler friendly graph.
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
train([mode]) – Set the module in training mode.
transform(config) – A function that changes the format of the ops and loading state dicts according to the config file.
tuning_approximation_function(hyperparams)
type(dst_type) – Casts all parameters and buffers to dst_type.
update_params_with_raw(raw) – Update parameters of a DmxModule from a torch.nn.Module.
update_smoothquant_scale(input)
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_flop_counter()
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
T_destination
accum_format
approximation_function
bias_format
bops
call_super_init
dump_patches
effective_weight
flop_counter
flop_counter_enabled
flops
functional_forward
input_formats
input_precision
is_compound
last_input_shape
last_output_shape
multiplier_format
output_formats
plugins
residual_format
weight_elem_count
weight_format
weight_hypernet – Returns a function that processes weight according to the ops format of the module.
weight_precision
weight_scale
weight_size_in_bytes
weight_sparseness
weight_storage_format
weight_storage_precision
weight_storage_scale
weight_storage_zero_point
weight_zero_point
kernel_size
stride
padding
dilation
return_indices
ceil_mode