dmx.compressor.modeling.nn.torch_modules.GroupNorm

class dmx.compressor.modeling.nn.torch_modules.GroupNorm(num_groups: int, num_channels: int, eps: float = 1e-05, affine: bool = True)

An extension of PyTorch’s GroupNorm layer that supports DmxModule configurations. The module applies group normalization to an input tensor. It is parameterized by the number of groups, the number of channels, an epsilon value for numerical stability, and an option to apply a learnable affine transformation.

Parameters:
  • num_groups (int) – Number of groups to separate the channels into.

  • num_channels (int) – Number of channels in the input tensor.

  • eps (float, optional) – A small constant added to the denominator for numerical stability. Defaults to 1e-5.

  • affine (bool, optional) – Whether to include learnable affine parameters for this layer. Defaults to True.

_forward(_input: Tensor) -> Tensor: Computes the forward pass of the group normalization.
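To make the roles of num_groups, eps, and the affine parameters concrete, the following is a minimal reference sketch of the group normalization computation in plain Python. It is not the DMX or PyTorch implementation: the function name group_norm, its weight and bias arguments, and the list-of-lists input layout are assumptions made for illustration. It assumes the number of channels is divisible by num_groups and uses the biased variance estimate, as PyTorch's GroupNorm does.

```python
import math


def group_norm(x, num_groups, eps=1e-5, weight=None, bias=None):
    """Group-normalize one sample.

    x is a list of channels, each channel a list of floats (hypothetical layout).
    weight and bias are optional per-channel affine parameters.
    """
    num_channels = len(x)
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    per_group = num_channels // num_groups

    out = []
    for g in range(num_groups):
        # Gather every value belonging to this group of channels.
        channels = x[g * per_group:(g + 1) * per_group]
        values = [v for ch in channels for v in ch]

        # Statistics are shared by all channels in the group.
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
        inv_std = 1.0 / math.sqrt(var + eps)  # eps keeps the denominator nonzero

        for offset, ch in enumerate(channels):
            c = g * per_group + offset
            scale = weight[c] if weight is not None else 1.0
            shift = bias[c] if bias is not None else 0.0
            out.append([(v - mean) * inv_std * scale + shift for v in ch])
    return out


# Four channels split into two groups of two channels each.
sample = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
normalized = group_norm(sample, num_groups=2)
```

With affine disabled (no weight or bias), each group in normalized has approximately zero mean and unit variance, even though the two groups in sample differ in scale by a factor of ten.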

__init__(num_groups: int, num_channels: int, eps: float = 1e-05, affine: bool = True) -> None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Methods

__init__(num_groups, num_channels[, eps, affine])

Initialize internal Module state, shared by both nn.Module and ScriptModule.

add_module(name, module)

Add a child module to the current module.

align_device(_input, args, kwargs, _device)

apply(fn)

Apply fn recursively to every submodule (as returned by .children()) as well as self.

approx_forward(inputs, *args, **kwargs)

approximator_wrapper(inputs, approx_args, ...)

Override this in DMX modules to pre-process the inputs and the SIMD approximator arguments before calling the SIMD reference kernels.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

buffers([recurse])

Return an iterator over module buffers.

calibrating_quantizers(hyperparams)

calibrating_smoothquant(hyperparams)

check_format_dim_consistency()

check_input_format_dim_consistency()

check_residual_format_dim_consistency()

check_sparseness_dim_consistency()

check_weight_format_dim_consistency()

children()

Return an iterator over immediate children modules.

compile(*args, **kwargs)

Compile this Module's forward using torch.compile().

configure(config)

A function that changes the format of the ops and loads state dicts according to the config file.

count_flops(_input, _output)

counting_flops([zero])

cpu()

Move all model parameters and buffers to the CPU.

cuda([device])

Move all model parameters and buffers to the GPU.

dmx_config([freeze])

A function that returns the DmxModuleConfig object for the module.

double()

Casts all floating point parameters and buffers to double datatype.

enable_approximation_function_tuning(state, ...)

enable_flop_counter([state])

enable_optimal_brain_compression(state, ...)

enable_quantizer_calib(state, hyperparams)

enable_smoothquant_calib(state, hyperparams)

eval()

Set the module in evaluation mode.

extra_repr()

Return the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

fold_weight_and_bias()

A function that applies the ops to the weights and biases using the corresponding formats.

forward(input, *args, **kwargs)

Forward pass of the module with quantization ops applied.

get_buffer(target)

Return the buffer given by target if it exists, otherwise throw an error.

get_extra_state()

Return any extra state to include in the module's state_dict.

get_parameter(target)

Return the parameter given by target if it exists, otherwise throw an error.

get_submodule(target)

Return the submodule given by target if it exists, otherwise throw an error.

half()

Casts all floating point parameters and buffers to half datatype.

infer_ch_axis()

init_casts()

init_smoothquant([migration_strength, ...])

init_sparsifier()

ipu([device])

Move all model parameters and buffers to the IPU.

load_state_dict(state_dict, *[, strict, assign])

load_state_dict_and_register_url(url)

A function that loads a state dict from a URL and stores the URL in self.state_dict_url.

measuring_runtime(_records, device)

Context manager for monitoring runtime of DmxModule

modules()

Return an iterator over all modules in the network.

monitoring(_records)

Context manager for monitoring input/output to/from the DmxModule

mtia([device])

Move all model parameters and buffers to the MTIA.

named_buffers([prefix, recurse, ...])

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix, remove_duplicate])

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse, ...])

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

optimal_brain_compressing(hyperparams)

parameters([recurse])

Return an iterator over module parameters.

register_backward_hook(hook)

Register a backward hook on the module.

register_buffer(name, tensor[, persistent])

Add a buffer to the module.

register_forward_hook(hook, *[, prepend, ...])

Register a forward hook on the module.

register_forward_pre_hook(hook, *[, ...])

Register a forward pre-hook on the module.

register_full_backward_hook(hook[, prepend])

Register a backward hook on the module.

register_full_backward_pre_hook(hook[, prepend])

Register a backward pre-hook on the module.

register_load_state_dict_post_hook(hook)

Register a post-hook to be run after module's load_state_dict() is called.

register_load_state_dict_pre_hook(hook)

Register a pre-hook to be run before module's load_state_dict() is called.

register_module(name, module)

Alias for add_module().

register_parameter(name, param)

Add a parameter to the module.

register_state_dict_post_hook(hook)

Register a post-hook for the state_dict() method.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

reset_parameters()

save_state_dict_and_register_url(parent_dir)

A function that saves the current state dict of the module to a URL under a specified parent directory.

set_extra_state(state)

Set extra state contained in the loaded state_dict.

set_submodule(target, module[, strict])

Set the submodule given by target if it exists, otherwise throw an error.

share_memory()

See torch.Tensor.share_memory_().

slanc_tuning(hyperparams)

state_dict(*args[, destination, prefix, ...])

Return a dictionary containing references to the whole state of the module.

to(*args, **kwargs)

Move and/or cast the parameters and buffers.

to_compiler_graph()

Return a compiler-friendly graph.

to_empty(*, device[, recurse])

Move the parameters and buffers to the specified device without copying storage.

train([mode])

Set the module in training mode.

transform(config)

A function that changes the format of the ops and loads state dicts according to the config file.

tuning_approximation_function(hyperparams)

type(dst_type)

Casts all parameters and buffers to dst_type.

update_params_with_raw(raw)

Update parameters of a DmxModule from a torch.nn.Module.

update_smoothquant_scale(input)

xpu([device])

Move all model parameters and buffers to the XPU.

zero_flop_counter()

zero_grad([set_to_none])

Reset gradients of all model parameters.

Attributes

T_destination

accum_format

approximation_function

bias_format

bops

call_super_init

dump_patches

effective_weight

flop_counter

flop_counter_enabled

flops

functional_forward

input_formats

input_precision

is_compound

last_input_shape

last_output_shape

multiplier_format

output_formats

plugins

residual_format

weight_elem_count

weight_format

weight_hypernet

Returns a function that processes weight according to the ops format of the module

weight_precision

weight_scale

weight_size_in_bytes

weight_sparseness

weight_storage_format

weight_storage_precision

weight_storage_scale

weight_storage_zero_point

weight_zero_point

num_groups

num_channels

eps

affine

training