dmx.compressor.modeling.nn.torch_modules.GroupNorm
- class dmx.compressor.modeling.nn.torch_modules.GroupNorm(num_groups: int, num_channels: int, eps: float = 1e-05, affine: bool = True)
An extension of PyTorch’s GroupNorm layer to support DmxModule configurations. This module applies group normalization over an input tensor, suitable for use with various types of layers. The module is parameterized by the number of groups, number of channels, epsilon value for numerical stability, and an option to use affine transformation.
- Parameters:
num_groups (int) – Number of groups to separate the channels into.
num_channels (int) – Number of channels in the input tensor.
eps (float, optional) – A small constant added to the denominator for numerical stability. Defaults to 1e-5.
affine (bool, optional) – Whether to include learnable affine parameters for this layer. Defaults to True.
- _forward(_input: Tensor) -> Tensor: Computes the forward pass of the group normalization.
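As an illustration of what group normalization computes, here is a minimal pure-Python sketch for a single sample. It is not the DMX implementation: it follows the standard formulation used by PyTorch's GroupNorm (per-group mean and biased variance, optional per-channel scale and shift) and ignores DMX numerical formats entirely. The function name `group_norm` and the list-of-lists layout are hypothetical, chosen only for this example.

```python
import math


def group_norm(x, num_groups, eps=1e-5, weight=None, bias=None):
    """Group-normalize one sample (illustrative helper, not a DMX API).

    x is a list of num_channels lists, each holding one channel's values
    with any spatial dimensions flattened. weight and bias are optional
    per-channel lists standing in for the learnable affine parameters.
    """
    num_channels = len(x)
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    per_group = num_channels // num_groups

    out = []
    for g in range(num_groups):
        channels = x[g * per_group:(g + 1) * per_group]
        values = [v for ch in channels for v in ch]
        mean = sum(values) / len(values)
        # Biased (population) variance over every value in the group.
        var = sum((v - mean) ** 2 for v in values) / len(values)
        inv_std = 1.0 / math.sqrt(var + eps)
        for offset, ch in enumerate(channels):
            c = g * per_group + offset
            scale = 1.0 if weight is None else weight[c]
            shift = 0.0 if bias is None else bias[c]
            out.append([(v - mean) * inv_std * scale + shift for v in ch])
    return out


# Four channels with two values each, split into two groups of two channels.
sample = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
normalized = group_norm(sample, num_groups=2)
```

Leaving `weight` and `bias` as None corresponds to affine=False; passing per-channel lists mimics affine=True. Each group is normalized independently, so the two groups above end up with approximately zero mean and unit variance despite their different input scales.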
- __init__(num_groups: int, num_channels: int, eps: float = 1e-05, affine: bool = True) -> None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Methods
__init__(num_groups, num_channels[, eps, affine]) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
add_module(name, module) – Add a child module to the current module.
align_device(_input, args, kwargs, _device)
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
approx_forward(inputs, *args, **kwargs)
approximator_wrapper(inputs, approx_args, ...) – Override this in the DMX modules to enable pre-processing of the inputs and the SIMD approximator arguments before calling the SIMD reference kernels.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
calibrating_quantizers(hyperparams)
calibrating_smoothquant(hyperparams)
check_format_dim_consistency()
check_input_format_dim_consistency()
check_residual_format_dim_consistency()
check_sparseness_dim_consistency()
check_weight_format_dim_consistency()
children() – Return an iterator over immediate children modules.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
configure(config) – A function that changes the format of the ops and loads state dicts according to the config file.
count_flops(_input, _output)
counting_flops([zero])
cpu() – Move all model parameters and buffers to the CPU.
cuda([device]) – Move all model parameters and buffers to the GPU.
dmx_config([freeze]) – A function that returns the DmxModuleConfig object for the module.
double() – Casts all floating point parameters and buffers to double datatype.
enable_approximation_function_tuning(state, ...)
enable_flop_counter([state])
enable_optimal_brain_compression(state, ...)
enable_quantizer_calib(state, hyperparams)
enable_smoothquant_calib(state, hyperparams)
eval() – Set the module in evaluation mode.
extra_repr() – Return the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
fold_weight_and_bias() – A function that applies the ops to the weights and biases using the corresponding formats.
forward(input, *args, **kwargs) – Forward pass of the module with quantization ops applied.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
half() – Casts all floating point parameters and buffers to half datatype.
infer_ch_axis()
init_casts()
init_smoothquant([migration_strength, ...])
init_sparsifier()
ipu([device]) – Move all model parameters and buffers to the IPU.
load_state_dict(state_dict, *[, strict, assign])
load_state_dict_and_register_url(url) – A function that loads state dict from a url and sets url to self.state_dict_url.
measuring_runtime(_records, device) – Context manager for monitoring runtime of DmxModule.
modules() – Return an iterator over all modules in the network.
monitoring(_records) – Context manager for monitoring input/output to/from the DmxModule.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
optimal_brain_compressing(hyperparams)
parameters([recurse]) – Return an iterator over module parameters.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
reset_parameters()
save_state_dict_and_register_url(parent_dir) – A function that saves the current state dict of the module to a url under a specified parent directory.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module[, strict]) – Set the submodule given by target if it exists, otherwise throw an error.
share_memory() – See torch.Tensor.share_memory_().
slanc_tuning(hyperparams)
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
to(*args, **kwargs) – Move and/or cast the parameters and buffers.
Returns a compiler friendly graph
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
train([mode]) – Set the module in training mode.
transform(config) – A function that changes the format of the ops and loads state dicts according to the config file.
tuning_approximation_function(hyperparams)
type(dst_type) – Casts all parameters and buffers to dst_type.
update_params_with_raw(raw) – Update parameters of a DmxModule from a torch.nn.Module.
update_smoothquant_scale(input)
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_flop_counter()
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
T_destination
accum_format
approximation_function
bias_format
bops
call_super_init
dump_patches
effective_weight
flop_counter
flop_counter_enabled
flops
functional_forward
input_formats
input_precision
is_compound
last_input_shape
last_output_shape
multiplier_format
output_formats
plugins
residual_format
weight_elem_count
weight_format
weight_hypernet – Returns a function that processes weight according to the ops format of the module.
weight_precision
weight_scale
weight_size_in_bytes
weight_sparseness
weight_storage_format
weight_storage_precision
weight_storage_scale
weight_storage_zero_point
weight_zero_point
num_groups
num_channels
eps
affine