dmx.compressor.fx.tracer
Functions
- disable_hooked_forward: Context manager that disables the hooks inserted by accelerate for fx tracing and restores them on exit.
- hf_symbolic_trace: Performs symbolic tracing on a Hugging Face model.
- symbolic_trace: Symbolic tracing API.
Classes
- DmxHFTracer: Custom HFTracer with its own definition of leaf nodes.
- HFQuantTracer: Custom tracer with a scope manager for Hugging Face models.
- QuantTracer: Custom tracer with a scope manager that returns a flat GraphModule.
- Scope: Scope object that records the module path and the module type of a module.
- ScopeContextManager: A context manager to track the Scope of a Node during symbolic tracing.
- class dmx.compressor.fx.tracer.DmxHFTracer(autowrap_modules=(math,), autowrap_functions=(apply_rotary_pos_emb, apply_rotary_pos_emb, apply_rotary_pos_emb))
Bases: HFTracer
Custom HFTracer with its own definition of leaf nodes.
- is_leaf_module(m: Module, module_qualified_name: str) bool
A method to specify whether a given nn.Module is a “leaf” module.
Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this method.
- Parameters:
m (Module) – The module being queried about
module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.
- Returns:
True if m is a leaf module
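The leaf-module rule described above can be illustrated with plain torch.fx. This is a minimal sketch, not the DmxHFTracer implementation: `Block`, `Model`, and `BlockLeafTracer` are hypothetical names introduced only for this example, and the override simply treats one user-defined module type as a leaf while deferring to the default rule for everything else.

```python
import torch
from torch import fx


class Block(torch.nn.Module):
    """Hypothetical user-defined module; traced through by default."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()

    def forward(self, x):
        return self.block(x) + 1


class BlockLeafTracer(fx.Tracer):
    """Hypothetical tracer that keeps Block opaque in the graph."""

    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, Block):
            return True
        # Fall back to the default torch.nn-based rule for other modules.
        return super().is_leaf_module(m, module_qualified_name)


default_graph = fx.Tracer().trace(Model())
custom_graph = BlockLeafTracer().trace(Model())

# Collect the targets of call_module nodes produced by each tracer.
default_targets = [n.target for n in default_graph.nodes if n.op == "call_module"]
custom_targets = [n.target for n in custom_graph.nodes if n.op == "call_module"]
```

With the default tracer only the inner torch.nn.Linear shows up as a call_module node, while the custom tracer records a single call_module node for the whole `block` submodule.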
- class dmx.compressor.fx.tracer.HFQuantTracer
Bases: HFTracer
Custom tracer with a scope manager for Hugging Face models.
- node_name_to_scope
Dictionary that maps node name to scope
- Type:
Dict
- record_stack_traces
Not in use yet
- Type:
bool
- call_module(m: Module, forward: Callable[[...], Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) Any
Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.
By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.
- Parameters:
m (Module) – The module for which a call is being emitted
forward (Callable) – The forward() method of the Module to be invoked
args (Tuple) – args of the module callsite
kwargs (Dict) – kwargs of the module callsite
- Returns:
The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.
- create_node(kind: str, target: Callable[[...], Any] | str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str | None = None, type_expr: Any | None = None) Node
Inserts a graph node given target, args, kwargs, and name.
- Parameters:
kind (str) – the opcode for this Node. One of ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’, or ‘output’. The semantics of these opcodes are described in the Graph docstring.
args (Optional[Tuple[Argument, ...]]) – a tuple of arguments to this node.
kwargs (Optional[Dict[str, Argument]]) – the kwargs of this Node
name (Optional[str]) – an optional string name for the Node. This will influence the name of the value assigned to in the Python generated code.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.
- Returns:
The newly-created and inserted node.
- is_leaf_module(m: Module, module_qualified_name: str) bool
A method to specify whether a given nn.Module is a “leaf” module.
Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this method.
- Parameters:
m (Module) – The module being queried about
module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.
- Returns:
True if m is a leaf module
- class dmx.compressor.fx.tracer.QuantTracer
Bases: Tracer
Custom tracer with a scope manager that returns a flat GraphModule.
- node_name_to_scope
Dictionary that maps node name to scope
- Type:
Dict
- record_stack_traces
Not in use yet
- Type:
bool
- call_module(m: Module, forward: Callable[[...], Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) Any
Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.
By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.
- Parameters:
m (Module) – The module for which a call is being emitted
forward (Callable) – The forward() method of the Module to be invoked
args (Tuple) – args of the module callsite
kwargs (Dict) – kwargs of the module callsite
- Returns:
The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.
- create_node(kind: str, target: Callable[[...], Any] | str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str | None = None, type_expr: Any | None = None) Node
Inserts a graph node given target, args, kwargs, and name.
- Parameters:
kind (str) – the opcode for this Node. One of ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’, or ‘output’. The semantics of these opcodes are described in the Graph docstring.
args (Optional[Tuple[Argument, ...]]) – a tuple of arguments to this node.
kwargs (Optional[Dict[str, Argument]]) – the kwargs of this Node
name (Optional[str]) – an optional string name for the Node. This will influence the name of the value assigned to in the Python generated code.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.
- Returns:
The newly-created and inserted node.
- is_leaf_module(m: Module, module_qualified_name: str) bool
A method to specify whether a given nn.Module is a “leaf” module.
Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this method.
- Parameters:
m (Module) – The module being queried about
module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.
- Returns:
True if m is a leaf module
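How call_module, create_node, and node_name_to_scope might fit together in a scope-tracking tracer can be sketched with plain torch.fx. This is an illustration under stated assumptions, not the QuantTracer source: `ScopeRecordingTracer`, its `current_path` and `current_type` attributes, and the toy modules `Sub` and `M` are hypothetical, and the scope is stored as a simple `(module_path, module_type)` tuple.

```python
import torch
from torch import fx


class Sub(torch.nn.Module):
    def forward(self, x):
        return x.transpose(1, 2)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        x = x.transpose(1, 2)
        return self.sub(x)


class ScopeRecordingTracer(fx.Tracer):
    """Hypothetical tracer that remembers which module each node came from."""

    def __init__(self):
        super().__init__()
        self.current_path = ""      # scope of the root module
        self.current_type = None
        self.node_name_to_scope = {}

    def call_module(self, m, forward, args, kwargs):
        # Enter the scope of the called module, restore the old one afterwards.
        previous = (self.current_path, self.current_type)
        self.current_path, self.current_type = self.path_of_module(m), type(m)
        try:
            return super().call_module(m, forward, args, kwargs)
        finally:
            self.current_path, self.current_type = previous

    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        # Record the scope that was active when this node was created.
        self.node_name_to_scope[node.name] = (self.current_path, self.current_type)
        return node


tracer = ScopeRecordingTracer()
graph = tracer.trace(M())

# Module paths of the two transpose calls, in graph order.
method_paths = [
    tracer.node_name_to_scope[n.name][0]
    for n in graph.nodes
    if n.op == "call_method"
]
method_types = [
    tracer.node_name_to_scope[n.name][1]
    for n in graph.nodes
    if n.op == "call_method"
]
```

The first transpose is recorded under the root scope (empty path, no type) and the second under the `sub` submodule, matching the scopes described in the Scope example.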
- class dmx.compressor.fx.tracer.Scope(module_path: str, module_type: Any)
Bases: object
Scope object that records the module path and the module type of a module. Scope is used to track the information of the module that contains a Node in a Graph of GraphModule.
Example
    import torch

    class Sub(torch.nn.Module):
        def forward(self, x):
            # This will be a call_method Node in GraphModule,
            # scope for this would be (module_path="sub", module_type=Sub)
            return x.transpose(1, 2)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()  # required before assigning submodules
            self.sub = Sub()

        def forward(self, x):
            # This will be a call_method Node as well,
            # scope for this would be (module_path="", None)
            x = x.transpose(1, 2)
            x = self.sub(x)
            return x
- Parameters:
module_path (str) – String describing the path to the module
module_type (Any) – type of the module
- module_path
String describing the path to the module
- Type:
str
- module_type
type of the module
- Type:
Any
- class dmx.compressor.fx.tracer.ScopeContextManager(scope: Scope, current_module: Module, current_module_path: str)
Bases: object
A context manager to track the Scope of a Node during symbolic tracing. When entering the forward function of a Module, we’ll update the scope information of the current module, and when we exit, we’ll restore the previous scope information.
- Parameters:
scope (Scope) – Scope object to store the module details
current_module (torch.nn.Module) – Current module object
current_module_path (str) – String path to current module
- prev_module_type
Type of the previous module
- Type:
Any
- prev_module_path
String path to previous module
- Type:
str
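The save-and-restore behavior described for Scope and ScopeContextManager can be sketched in plain Python. The classes below are stand-ins written for this example, not the dmx implementation: `SketchScope`, `SketchScopeContextManager`, and `ToySub` are hypothetical names, and only the documented attributes (module_path, module_type, prev_module_path, prev_module_type) are modeled.

```python
class SketchScope:
    """Stand-in for Scope: holds the path and type of the current module."""

    def __init__(self, module_path, module_type):
        self.module_path = module_path
        self.module_type = module_type


class SketchScopeContextManager:
    """Stand-in for ScopeContextManager: swaps scope info in, then back out."""

    def __init__(self, scope, current_module, current_module_path):
        # Remember what the scope looked like before entering the module.
        self.prev_module_type = scope.module_type
        self.prev_module_path = scope.module_path
        self.scope = scope
        # Point the shared scope object at the module being entered.
        self.scope.module_path = current_module_path
        self.scope.module_type = type(current_module)

    def __enter__(self):
        return self.scope

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the previous scope even if tracing raised an exception.
        self.scope.module_path = self.prev_module_path
        self.scope.module_type = self.prev_module_type
        return False


class ToySub:
    """Hypothetical placeholder for a submodule."""


scope = SketchScope("", None)
with SketchScopeContextManager(scope, ToySub(), "sub"):
    inside = (scope.module_path, scope.module_type)
outside = (scope.module_path, scope.module_type)
```

Because the scope object is shared and mutated in place, any node created inside the `with` block can read the current module path from it without the scope being passed around explicitly.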
- dmx.compressor.fx.tracer.disable_hooked_forward(model)
This context manager disables the hooks inserted by accelerate for fx tracing and restores the hooked forward when the context exits.
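The general pattern of temporarily swapping a hooked forward for the original one can be sketched without any dependencies. This is a toy illustration, not the dmx implementation: `ToyModule`, `add_toy_hook`, and `toy_disable_hooked_forward` are hypothetical, and the sketch assumes the hooked object keeps its original forward in an attribute named `_old_forward`.

```python
from contextlib import contextmanager


class ToyModule:
    """Hypothetical stand-in for a model with a forward method."""

    def forward(self, x):
        return x + 1


def add_toy_hook(module):
    """Wrap forward and keep the original under _old_forward (assumed name)."""
    original = module.forward

    def hooked_forward(x):
        return original(x) * 10

    module._old_forward = original
    module.forward = hooked_forward


@contextmanager
def toy_disable_hooked_forward(module):
    """Use the original forward inside the block, re-install the hook after."""
    hooked = module.forward
    has_hook = hasattr(module, "_old_forward")
    if has_hook:
        module.forward = module._old_forward
    try:
        yield module
    finally:
        if has_hook:
            module.forward = hooked


m = ToyModule()
add_toy_hook(m)
before = m.forward(1)
with toy_disable_hooked_forward(m):
    inside = m.forward(1)
after = m.forward(1)
```

Restoring the hook in a `finally` clause keeps the model usable for normal inference even when tracing fails inside the block.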
- dmx.compressor.fx.tracer.hf_symbolic_trace(model: PreTrainedModel, input_names: List[str] | None = None, concrete_args: Dict[str, Any] | None = None, tracer_cls: Type[DmxHFTracer] = DmxHFTracer, dummy_inputs: Dict[str, Any] | None = None) GraphModule
Performs symbolic tracing on a Hugging Face model.
- Parameters:
model (PreTrainedModel) – The model to trace.
input_names (List[str], optional) – The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
disable_check (bool, optional, defaults to False) – If True, no check is done before trying to trace the model. This is mostly useful for debugging purposes.
tracer_cls (Type[DmxHFTracer], optional, defaults to DmxHFTracer) – The tracer class to use for instantiating the tracer. If unset, DmxHFTracer is used instead.
- Returns:
A GraphModule constructed by recording operations seen while tracing the model.
torch.fx.Tracer: The tracer used for tracing the model.
- Return type:
torch.fx.GraphModule
Example
```python
from dmx.compressor.fx.tracer import hf_symbolic_trace

traced_model, tracer = hf_symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```
- dmx.compressor.fx.tracer.symbolic_trace(root: Module | Callable[[...], Any], concrete_args: Dict[str, Any] | None = None) GraphModule
Symbolic tracing API
Given an nn.Module or function instance root, this function will return a GraphModule constructed by recording operations seen while tracing through root, and the tracer used to trace the model.
concrete_args allows you to partially specialize your function, whether it’s to remove control flow or data structures.
For example:
    def f(a, b):
        if b == True:
            return a
        else:
            return a*2
FX typically cannot trace through this due to the presence of control flow. However, we can use concrete_args to specialize on the value of b to trace through this:
    f = fx.symbolic_trace(f, concrete_args={'b': False})
    assert f(3, False) == 6
Note that although you can still pass in different values of b, they will be ignored.
We can also use concrete_args to eliminate data-structure handling from our function. This will use pytrees to flatten your input. To avoid overspecializing, pass in fx.PH for values that shouldn’t be specialized. For example:
    def f(x):
        out = 0
        for v in x.values():
            out += v
        return out

    f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
    assert f({'a': 1, 'b': 2, 'c': 4}) == 7
- Parameters:
root (Union[torch.nn.Module, Callable]) – Module or function to be traced and converted into a Graph representation.
concrete_args (Optional[Dict[str, Any]]) – Inputs to be partially specialized
- Returns:
a Module created from the recorded operations from root.
Tracer: the tracer used for tracing the model
- Return type:
GraphModule