dmx.compressor.modeling.nn.core

Functions

is_configurable(m)

Classes

DmxGraph([owning_module, tracer_cls, ...])

DmxModule(*args[, state_dict_url])

Extended torch.nn.Module for Dmx to support quantization.

DmxModuleConfig

DmxModuleType

class dmx.compressor.modeling.nn.core.DmxGraph(owning_module=None, tracer_cls=None, tracer_extras=None)

Bases: Graph

call_function(the_function, args=None, kwargs=None, type_expr=None, cast_name=None, cast_format=None)

Insert a call_function Node into the Graph. A call_function node represents a call to a Python callable, specified by the_function.

Parameters:
  • the_function (Callable[..., Any]) – The function to be called. Can be any PyTorch operator, Python function, or member of the builtins or operator namespaces.

  • args (Optional[Tuple[Argument, ...]]) – The positional arguments to be passed to the called function.

  • kwargs (Optional[Dict[str, Argument]]) – The keyword arguments to be passed to the called function.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

  • name (Optional[str]) – The name of the node. If not specified, set to None

Returns:

The newly created and inserted call_function node.

Note

The same insertion point and type expression rules apply for this method as Graph.create_node().

Note

Backwards-compatibility for this API is guaranteed.
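As a minimal sketch of how call_function nodes compose, the example below builds relu(x + 1) with the stock torch.fx.Graph, which DmxGraph subclasses. It assumes DmxGraph behaves the same way when cast_name and cast_format are left at None; that is an assumption, not something this reference states.

```python
import operator

import torch
from torch import fx

# Stock torch.fx.Graph is used here as a stand-in for DmxGraph (assumption:
# the cast_name/cast_format extras do not change this behavior when unset).
graph = fx.Graph()
x = graph.placeholder("x")

# Each call_function node records a callable plus its arguments.
added = graph.call_function(operator.add, args=(x, 1))
activated = graph.call_function(torch.relu, args=(added,))
graph.output(activated)

# Wrap the graph in a GraphModule so it can be executed.
module = fx.GraphModule(torch.nn.Module(), graph)
result = module(torch.tensor([-3.0, 0.5]))
ops = [node.op for node in graph.nodes]
```

Nodes are appended in the order they are created, so the resulting graph lists the placeholder first, then the two function calls, then the output.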

call_method(method_name, args=None, kwargs=None, type_expr=None, cast_name=None, cast_format=None)

Insert a call_method Node into the Graph. A call_method node represents a call to a given method on the 0th element of args.

Parameters:
  • method_name (str) – The name of the method to apply to the self argument. For example, if args[0] is a Node representing a Tensor, then to call relu() on that Tensor, pass relu to method_name.

  • args (Optional[Tuple[Argument, ...]]) – The positional arguments to be passed to the called method. Note that this should include a self argument.

  • kwargs (Optional[Dict[str, Argument]]) – The keyword arguments to be passed to the called method.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns:

The newly created and inserted call_method node.

Note

The same insertion point and type expression rules apply for this method as Graph.create_node().

Note

Backwards-compatibility for this API is guaranteed.
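The role of the self argument can be illustrated with a short sketch. It uses the base torch.fx.Graph rather than DmxGraph, on the assumption that the inherited behavior is unchanged when the cast arguments are omitted.

```python
import torch
from torch import fx

# Base torch.fx.Graph as a stand-in for DmxGraph (assumption).
graph = fx.Graph()
x = graph.placeholder("x")

# args[0] plays the role of `self`: this records x.clamp(min=0.0, max=1.0).
clamped = graph.call_method("clamp", args=(x,), kwargs={"min": 0.0, "max": 1.0})
graph.output(clamped)

module = fx.GraphModule(torch.nn.Module(), graph)
result = module(torch.tensor([-1.0, 0.25, 2.0]))
```

The method is looked up by name on the runtime value of args[0], so the same node works for any object that exposes a clamp method with these keyword arguments.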

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None, cast_name=None, cast_format=None)

Create a Node and add it to the Graph at the current insert-point. Note that the current insert-point can be set via Graph.inserting_before() and Graph.inserting_after().

Parameters:
  • op (str) – the opcode for this Node. One of ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’, or ‘output’. The semantics of these opcodes are described in the Graph docstring.

  • args (Optional[Tuple[Argument, ...]]) – is a tuple of arguments to this node.

  • kwargs (Optional[Dict[str, Argument]]) – the kwargs of this Node

  • name (Optional[str]) – an optional string name for the Node. This will influence the name of the value assigned to in the Python generated code.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns:

The newly-created and inserted node.

Note

Backwards-compatibility for this API is guaranteed.
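The insert-point behavior can be sketched as follows. The example uses the base torch.fx.Graph and assumes DmxGraph.create_node places nodes the same way when cast_name and cast_format are not supplied.

```python
import torch
from torch import fx

# Base torch.fx.Graph as a stand-in for DmxGraph (assumption).
graph = fx.Graph()
x = graph.placeholder("x")
out = graph.output(x)

# Move the insert-point so the new node lands before the output node.
with graph.inserting_before(out):
    negated = graph.create_node(
        "call_function", torch.neg, args=(x,), name="negated"
    )

# Re-route the output through the newly created node and validate the graph.
out.args = (negated,)
graph.lint()

module = fx.GraphModule(torch.nn.Module(), graph)
result = module(torch.tensor([1.0, -2.0]))
names = [node.name for node in graph.nodes]
```

Without the inserting_before context, the node would be appended after the output node and graph.lint() would reject the graph, because the output would use a value defined later.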

create_placeholders(names, cast_names=None, cast_formats=None)

get_attr(qualified_name, cast_name=None, cast_format=None, optional_arg=True, type_expr=None)

Parameters:

optional_arg – Controls whether None is returned instead of a Node. For example, passing optional_arg=linear_mod.bias returns None when the module does not have a bias term. Defaults to True.
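The optional_arg check described here can be pictured with a small stand-in. The helper get_attr_sketch below is hypothetical and only mimics the documented None-handling; it does not reproduce DmxGraph.get_attr or create real graph nodes.

```python
from types import SimpleNamespace


def get_attr_sketch(qualified_name, optional_arg=True):
    """Hypothetical stand-in for the optional_arg check; not DmxGraph code."""
    if optional_arg is None:
        # The attribute is absent on the module, so no node is created.
        return None
    # A real graph would create and return a get_attr Node here.
    return ("get_attr", qualified_name)


# Stand-ins for a linear module with and without a bias term.
with_bias = SimpleNamespace(bias=[0.0, 0.0])
without_bias = SimpleNamespace(bias=None)

node = get_attr_sketch("linear.bias", optional_arg=with_bias.bias)
missing = get_attr_sketch("linear.bias", optional_arg=without_bias.bias)
default = get_attr_sketch("linear.weight")
```

Passing the attribute itself as optional_arg lets the caller skip an explicit "does this module have a bias" branch before building the graph.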

placeholder(name, cast_name=None, cast_format=None, type_expr=None, default_value)

Insert a placeholder node into the Graph. A placeholder represents a function input.

Parameters:
  • name (str) – A name for the input value. This corresponds to the name of the positional argument to the function this Graph represents.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have. This is needed in some cases for proper code generation (e.g. when the function is used subsequently in TorchScript compilation).

  • default_value (Any) – The default value this function argument should take on. NOTE: to allow for None as a default value, inspect.Signature.empty should be passed as this argument to specify that the parameter does not have a default value.

Note

The same insertion point and type expression rules apply for this method as Graph.create_node.

Note

Backwards-compatibility for this API is guaranteed.
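The reason for using inspect.Signature.empty as the "no default" marker can be shown with the standard library alone. The helper describe_default is illustrative and not part of DmxGraph.

```python
import inspect


def describe_default(default_value=inspect.Signature.empty):
    """Report whether a parameter has a default, treating None as a real value."""
    if default_value is inspect.Signature.empty:
        return "no default"
    return f"default={default_value!r}"


def example(a, b=None):
    return a, b


# inspect reports a missing default with the same sentinel.
params = inspect.signature(example).parameters
summary = {name: describe_default(p.default) for name, p in params.items()}
```

Because None is a legitimate default, a separate sentinel is the only way to tell "defaults to None" apart from "has no default".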

qdq_node(node, cast_name, cast_format)

class dmx.compressor.modeling.nn.core.DmxModule(*args, state_dict_url: str | None = None, **kwargs)

Bases: ApproximationMixin, PerformanceProxyMixin, LayerReconstructionMixin, NumericalCastMixin, WeightSparseMixin, Module

Extended torch.nn.Module for Dmx to support quantization.

Parameters:
  • *args (Optional[Tuple]) – Variable-length positional arguments.

  • state_dict_url (Optional[str]) – URL for loading the state dicts. Defaults to None.

  • **kwargs (Optional[Dict]) – Variable-length keyword arguments.

state_dict_url

URL for loading the module state dicts.

Type:

str

align_device(_input, args, kwargs, _device)

configure(config) → None

Changes the formats of the module's ops and loads state dicts according to the given config.

Parameters:

config (DmxModuleConfig) – Config that sets the new formats and specifies the state dicts to load.

dmx_config(freeze=False)

Returns the DmxModuleConfig object for the module.

Parameters:

freeze (bool) – If True, both the state dict and the ops formats are included in the returned DmxModuleConfig. If False, only the state dict is included.

Returns:

A DmxModuleConfig object for the module

fold_weight_and_bias() → None

Applies the ops to the weights and biases using the corresponding formats.

forward(input: Tensor, *args, **kwargs) → Tensor

Forward pass of the module with quantization ops applied.

Parameters:
  • input (Tensor) – Input tensor to be passed through the module.

  • *args (Optional[Tuple]) – Variable-length positional arguments.

  • **kwargs (Optional[Dict]) – Variable-length keyword arguments.

functional_forward = None

is_compound = False

load_state_dict_and_register_url(url: str) → None

Loads the state dict from a URL and stores that URL in self.state_dict_url.

Parameters:

url (str) – URL for loading the state dict.

measuring_runtime(_records: list, device: device)

Context manager for monitoring runtime of DmxModule

monitoring(_records: list)

Context manager for monitoring input/output to/from the DmxModule
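The record-collecting pattern behind these context managers can be sketched with contextlib. The function measuring_runtime_sketch is hypothetical: it times a block with time.perf_counter and ignores the device argument, whereas the real DmxModule method may rely on module hooks or device-side timers.

```python
import time
from contextlib import contextmanager


@contextmanager
def measuring_runtime_sketch(records):
    """Hypothetical stand-in: append the elapsed wall-clock seconds to records."""
    start = time.perf_counter()
    try:
        yield records
    finally:
        # Record the duration even if the body raises.
        records.append(time.perf_counter() - start)


records = []
with measuring_runtime_sketch(records):
    total = sum(range(1000))  # placeholder for a module forward pass
```

Passing the list in from the caller keeps ownership of the measurements outside the module, so several with-blocks can accumulate into the same list.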

plugins: List[PluginBase] = []

save_state_dict_and_register_url(parent_dir: str) → None

Saves the current state dict of the module to a URL under the specified parent directory.

Parameters:

parent_dir (str) – parent directory for the url

abstractmethod to_compiler_graph() → Graph

Returns a compiler-friendly graph.

transform(config) → None

Changes the formats of the module's ops and loads state dicts according to the given config.

Parameters:

config (DmxModuleConfig) – Config that sets the new formats and specifies the state dicts to load.

update_params_with_raw(raw: Module) → None

Update parameters of a DmxModule from a torch.nn.Module.

Parameters:

raw (torch.nn.Module) – the torch module to copy parameters from.
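One plausible way to copy parameters between two modules is sketched below using only public torch.nn.Module calls. A second torch.nn.Linear stands in for the DmxModule; this is not the DmxModule implementation, which may handle buffers, shapes, or devices differently.

```python
import torch

raw = torch.nn.Linear(4, 2)
target = torch.nn.Linear(4, 2)  # stand-in for a DmxModule (assumption)

# Copy values in place so the target keeps its own Parameter objects.
with torch.no_grad():
    for name, param in raw.named_parameters():
        target.get_parameter(name).copy_(param)
```

Copying in place, rather than reassigning attributes, leaves the two modules independent: later updates to raw do not affect target.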

property weight_hypernet

Returns a function that processes weight according to the ops format of the module

class dmx.compressor.modeling.nn.core.DmxModuleConfig

Bases: dict

classmethod from_module(module: DmxModule, freeze=False)

Stores the state and ops format of the module in a DmxModuleConfig object.

Parameters:

module (DmxModule) – Target module for creating the DmxModuleConfig

Returns:

A DmxModuleConfig object that stores the state and ops format of the module.
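Since DmxModuleConfig derives from dict, the from_module pattern can be pictured with a small dict subclass. Everything in this sketch is hypothetical: the class names, the key names state_dict_url and formats, and the placeholder values are illustrative and are not taken from the dmx source. The freeze handling follows the description given for DmxModule.dmx_config.

```python
class ModuleConfigSketch(dict):
    """Hypothetical dict-based config; key names are illustrative only."""

    @classmethod
    def from_module(cls, module, freeze=False):
        # Always record where the state dict lives.
        config = cls(state_dict_url=module.state_dict_url)
        if freeze:
            # Also capture the ops formats when a frozen snapshot is requested.
            config["formats"] = dict(module.formats)
        return config


class FakeModule:
    """Stand-in for a DmxModule with made-up attribute values."""

    state_dict_url = "example/state_dict.pt"
    formats = {"input": "format-a", "weight": "format-b"}


plain = ModuleConfigSketch.from_module(FakeModule())
frozen = ModuleConfigSketch.from_module(FakeModule(), freeze=True)
```

Subclassing dict keeps the config trivially serializable while still allowing a named constructor.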

class dmx.compressor.modeling.nn.core.DmxModuleType

Bases: type

dmx.compressor.modeling.nn.core.is_configurable(m)