dmx.compressor.fx.tracer

Functions

disable_hooked_forward(model)

Context manager that disables the hooks inserted by accelerate for the duration of fx tracing and re-attaches them to the original forward on exit.

hf_symbolic_trace(model[, input_names, ...])

Performs symbolic tracing on a huggingface model.

symbolic_trace(root[, concrete_args])

Symbolic tracing API

Classes

DmxHFTracer([autowrap_modules, ...])

Custom HFTracer with a customized definition of leaf modules.

HFQuantTracer()

Custom tracer with a scope manager for HuggingFace models.

QuantTracer()

Custom tracer with a scope manager that returns a flat GraphModule.

Scope(module_path, module_type)

Scope object that records the module path and the module type of a module.

ScopeContextManager(scope, current_module, ...)

A context manager to track the Scope of Node during symbolic tracing.

class dmx.compressor.fx.tracer.DmxHFTracer(autowrap_modules=(math,), autowrap_functions=(apply_rotary_pos_emb, apply_rotary_pos_emb, apply_rotary_pos_emb))

Bases: HFTracer

Custom HFTracer with a customized definition of leaf modules.

is_leaf_module(m: Module, module_qualified_name: str) bool

A method to specify whether a given nn.Module is a “leaf” module.

Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.

Parameters:
  • m (Module) – The module being queried about

  • module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.

Returns:

True if m is a leaf module
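The leaf-module rule above can be illustrated with a plain torch.fx.Tracer subclass. This is a minimal sketch, not the DmxHFTracer implementation: the classes Block, Net, and BlockLeafTracer are hypothetical names introduced only for illustration, and the rule "treat Block as a leaf" is an assumption chosen to make the effect visible.

```python
import torch
from torch import fx, nn


class Block(nn.Module):
    """Hypothetical user-defined module that we want to keep opaque."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x) + 1


class BlockLeafTracer(fx.Tracer):
    """Tracer that treats Block as a leaf in addition to the default rule."""

    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, Block):
            return True
        # Fall back to the default rule (torch.nn modules are leaves).
        return super().is_leaf_module(m, module_qualified_name)


graph = BlockLeafTracer().trace(Net())
# Collect (opcode, target) pairs to see how the call to Block was recorded.
ops = [(node.op, str(node.target)) for node in graph.nodes]
```

Because Block is reported as a leaf, the graph records a single call_module node targeting block, and nothing inside Block.forward is traced.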

class dmx.compressor.fx.tracer.HFQuantTracer

Bases: HFTracer

Custom tracer with a scope manager for HuggingFace models.

scope

Scope object to record the path and type of a module

Type:

Scope

node_name_to_scope

Dictionary that maps node name to scope

Type:

Dict

record_stack_traces

Not in use yet

Type:

bool

call_module(m: Module, forward: Callable[[...], Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) Any

Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.

By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.

Parameters:
  • m (Module) – The module for which a call is being emitted

  • forward (Callable) – The forward() method of the Module to be invoked

  • args (Tuple) – args of the module callsite

  • kwargs (Dict) – kwargs of the module callsite

Returns:

The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.

create_node(kind: str, target: Callable[[...], Any] | str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str | None = None, type_expr: Any | None = None) Node

Inserts a graph node given target, args, kwargs, and name.

Parameters:
  • kind (str) – the opcode for this Node. One of ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’, or ‘output’. The semantics of these opcodes are described in the Graph docstring.

  • args (Optional[Tuple[Argument, ...]]) – a tuple of arguments to this node.

  • kwargs (Optional[Dict[str, Argument]]) – the kwargs of this Node

  • name (Optional[str]) – an optional string name for the Node. This will influence the name of the value assigned to in the Python generated code.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns:

The newly-created and inserted node.
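How call_module and create_node can cooperate to map node names to scopes is sketched below on top of torch.fx.Tracer. This is an illustration under stated assumptions, not the HFQuantTracer source: ScopeRecordingTracer, Block, Net, and the attributes current_path, current_type, and scopes are hypothetical names, and the tuple (module path, module type) stands in for a Scope object.

```python
import torch
from torch import fx, nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x) + 1


class ScopeRecordingTracer(fx.Tracer):
    """Hypothetical tracer that remembers which module created each node."""

    def __init__(self):
        super().__init__()
        self.current_path = ""    # path of the module being traced
        self.current_type = None  # type of the module being traced
        self.scopes = {}          # node name -> (module path, module type)

    def call_module(self, m, forward, args, kwargs):
        previous = (self.current_path, self.current_type)
        # Enter the scope of the called module.
        self.current_path, self.current_type = self.path_of_module(m), type(m)
        try:
            return super().call_module(m, forward, args, kwargs)
        finally:
            # Restore the enclosing scope, even if tracing raised.
            self.current_path, self.current_type = previous

    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        self.scopes[node.name] = (self.current_path, self.current_type)
        return node


tracer = ScopeRecordingTracer()
graph = tracer.trace(Net())
```

After tracing, tracer.scopes associates the relu call with the block scope, the Linear call_module node with block.linear, and the top-level addition with the root scope (empty path).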

is_leaf_module(m: Module, module_qualified_name: str) bool

A method to specify whether a given nn.Module is a “leaf” module.

Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.

Parameters:
  • m (Module) – The module being queried about

  • module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.

Returns:

True if m is a leaf module

class dmx.compressor.fx.tracer.QuantTracer

Bases: Tracer

Custom tracer with a scope manager that returns a flat GraphModule.

scope

Scope object to record the path and type of a module

Type:

Scope

node_name_to_scope

Dictionary that maps node name to scope

Type:

Dict

record_stack_traces

Not in use yet

Type:

bool

call_module(m: Module, forward: Callable[[...], Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) Any

Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.

By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.

Parameters:
  • m (Module) – The module for which a call is being emitted

  • forward (Callable) – The forward() method of the Module to be invoked

  • args (Tuple) – args of the module callsite

  • kwargs (Dict) – kwargs of the module callsite

Returns:

The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.

create_node(kind: str, target: Callable[[...], Any] | str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str | None = None, type_expr: Any | None = None) Node

Inserts a graph node given target, args, kwargs, and name.

Parameters:
  • kind (str) – the opcode for this Node. One of ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’, or ‘output’. The semantics of these opcodes are described in the Graph docstring.

  • args (Optional[Tuple[Argument, ...]]) – a tuple of arguments to this node.

  • kwargs (Optional[Dict[str, Argument]]) – the kwargs of this Node

  • name (Optional[str]) – an optional string name for the Node. This will influence the name of the value assigned to in the Python generated code.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns:

The newly-created and inserted node.

is_leaf_module(m: Module, module_qualified_name: str) bool

A method to specify whether a given nn.Module is a “leaf” module.

Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this parameter.

Parameters:
  • m (Module) – The module being queried about

  • module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.

Returns:

True if m is a leaf module

class dmx.compressor.fx.tracer.Scope(module_path: str, module_type: Any)

Bases: object

Scope object that records the module path and the module type of a module. Scope is used to track the information of the module that contains a Node in a Graph of GraphModule.

Example

import torch

class Sub(torch.nn.Module):
    def forward(self, x):
        # This will be a call_method Node in GraphModule,
        # scope for this would be (module_path="sub", module_type=Sub)
        return x.transpose(1, 2)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        # This will be a call_method Node as well,
        # scope for this would be (module_path="", None)
        x = x.transpose(1, 2)
        x = self.sub(x)
        return x

Parameters:
  • module_path (str) – String describing the path to the module

  • module_type (Any) – type of the module

module_path

String describing the path to the module

Type:

str

module_type

type of the module

Type:

Any

class dmx.compressor.fx.tracer.ScopeContextManager(scope: Scope, current_module: Module, current_module_path: str)

Bases: object

A context manager to track the Scope of Node during symbolic tracing. When entering a forward function of a Module, we’ll update the scope information of the current module, and when we exit, we’ll restore the previous scope information.

Parameters:
  • scope (Scope) – Scope object to store the module details

  • current_module (torch.nn.Module) – Current module object

  • current_module_path (str) – String path to current module

prev_module_type

Type of the previous module

Type:

Any

prev_module_path

String path to previous module

Type:

str

scope

Scope object to store the module details

Type:

Scope
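The save-and-restore behavior described for ScopeContextManager can be sketched in plain Python. The classes SketchScope and SketchScopeContextManager below are hypothetical stand-ins that only mirror the documented attributes. Whether the real class updates the scope at construction time or on entering the with block is not stated here, so the sketch assumes the latter.

```python
class SketchScope:
    """Hypothetical stand-in for Scope: a mutable (path, type) record."""

    def __init__(self, module_path, module_type):
        self.module_path = module_path
        self.module_type = module_type


class SketchScopeContextManager:
    """Hypothetical stand-in for ScopeContextManager."""

    def __init__(self, scope, current_module, current_module_path):
        # Remember the enclosing module so it can be restored on exit.
        self.prev_module_type = scope.module_type
        self.prev_module_path = scope.module_path
        self.scope = scope
        self.current_module = current_module
        self.current_module_path = current_module_path

    def __enter__(self):
        # Point the shared scope at the module whose forward is being traced.
        self.scope.module_path = self.current_module_path
        self.scope.module_type = type(self.current_module)
        return self.scope

    def __exit__(self, exc_type, exc, tb):
        # Restore the previous scope information.
        self.scope.module_path = self.prev_module_path
        self.scope.module_type = self.prev_module_type
        return False


class Sub:
    """Placeholder for a submodule; a real tracer would pass an nn.Module."""


scope = SketchScope("", None)
history = [(scope.module_path, scope.module_type)]
with SketchScopeContextManager(scope, Sub(), "sub"):
    history.append((scope.module_path, scope.module_type))
history.append((scope.module_path, scope.module_type))
```

The history list shows the scope moving from the root (empty path, no type) to sub and back again once the block exits.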

dmx.compressor.fx.tracer.disable_hooked_forward(model)

Context manager that disables the hooks inserted by accelerate for the duration of fx tracing and re-attaches them to the original forward on exit.
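The swap-and-restore pattern behind such a context manager can be sketched without torch or accelerate. Everything below is hypothetical: TinyModule, add_hook, and unhooked_forward are illustrative names, and the sketch assumes the hook wrapper keeps the original method in an _old_forward attribute. It is not the body of disable_hooked_forward.

```python
from contextlib import contextmanager


class TinyModule:
    """Stand-in for a module; not a real torch.nn.Module."""

    def __init__(self):
        self.children = []

    def forward(self, x):
        return x + 1

    def modules(self):
        # Yield this module and all descendants, like nn.Module.modules().
        yield self
        for child in self.children:
            yield from child.modules()


def add_hook(module):
    """Mimic a hook wrapper: keep the original forward, install a wrapper."""
    original = module.forward
    module._old_forward = original

    def hooked(x):
        return original(x) * 10  # visible side effect of the "hook"

    module.forward = hooked


@contextmanager
def unhooked_forward(model):
    """Temporarily swap every hooked forward for its original."""
    saved = []
    for m in model.modules():
        if hasattr(m, "_old_forward"):
            saved.append((m, m.forward))
            m.forward = m._old_forward
    try:
        yield model
    finally:
        # Put the hooked forwards back, even if the body raised.
        for m, hooked in saved:
            m.forward = hooked


model = TinyModule()
add_hook(model)
hooked_result = model.forward(1)
with unhooked_forward(model):
    plain_result = model.forward(1)
restored_result = model.forward(1)
```

Inside the with block the original forward runs, so a tracer would see only the module's own operations. Outside it, the hooked behavior is back in place.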

dmx.compressor.fx.tracer.hf_symbolic_trace(model: PreTrainedModel, input_names: List[str] | None = None, concrete_args: Dict[str, Any] | None = None, tracer_cls: Type[DmxHFTracer] = DmxHFTracer, dummy_inputs: Dict[str, Any] | None = None) GraphModule

Performs symbolic tracing on a huggingface model.

Parameters:
  • model (PreTrainedModel) – The model to trace.

  • input_names (List[str], optional) – The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.

  • disable_check (bool, optional, defaults to False) – If True, no check is done before trying to trace the model; this is mostly useful for debugging purposes.

  • tracer_cls (Type[DmxHFTracer], optional, defaults to DmxHFTracer) – The tracer class to use for instantiating the tracer. If unset, DmxHFTracer is used instead.

Returns:

torch.fx.GraphModule: A GraphModule constructed by recording operations seen while tracing the model.

torch.fx.Tracer: The tracer used for tracing the model.

Return type:

torch.fx.GraphModule

Example

```python
from dmx.compressor.fx.tracer import hf_symbolic_trace

traced_model, tracer = hf_symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```

dmx.compressor.fx.tracer.symbolic_trace(root: Module | Callable[[...], Any], concrete_args: Dict[str, Any] | None = None) GraphModule

Symbolic tracing API

Given an nn.Module or function instance root, this function will return a GraphModule constructed by recording operations seen while tracing through root and the tracer used to trace the model.

concrete_args allows you to partially specialize your function, whether it’s to remove control flow or data structures.

For example:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

FX can typically not trace through this due to the presence of control flow. However, we can use concrete_args to specialize on the value of b to trace through this:

f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False)  == 6

Note that although you can still pass in different values of b, they will be ignored.

We can also use concrete_args to eliminate data-structure handling from our function. This will use pytrees to flatten your input. To avoid overspecializing, pass in fx.PH for values that shouldn’t be specialized. For example:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7

Parameters:
  • root (Union[torch.nn.Module, Callable]) – Module or function to be traced and converted into a Graph representation.

  • concrete_args (Optional[Dict[str, Any]]) – Inputs to be partially specialized

Returns:

GraphModule: a Module created from the recorded operations from root.

Tracer: the tracer used for tracing the model.

Return type:

GraphModule