dmx.compressor.fx.tracer.symbolic_trace

dmx.compressor.fx.tracer.symbolic_trace(root: Module | Callable[..., Any], concrete_args: Dict[str, Any] | None = None) → GraphModule

Symbolic tracing API

Given an nn.Module or function instance root, this function returns a GraphModule constructed by recording the operations seen while tracing through root, along with the tracer used to trace the model.
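To make "recording operations seen while tracing" concrete, here is a toy, pure-Python model of the idea. It is not the dmx or torch.fx implementation: Proxy and toy_symbolic_trace are hypothetical names, and the "graph" is just a list of (name, operator, lhs, rhs) tuples. Each input is replaced by a proxy, the function runs once so that every arithmetic operation is recorded instead of computed, and the recorded list is replayed on real values later.

```python
import operator
from typing import Any, Callable, List, Tuple

# Toy model of symbolic tracing; NOT the dmx or torch.fx implementation.
# A "graph" here is just a list of (output_name, operator, lhs, rhs) nodes.
Node = Tuple[str, Callable[[Any, Any], Any], Any, Any]


class Proxy:
    """Hypothetical stand-in for a value; records operations instead of computing them."""

    def __init__(self, graph: List[Node], name: str) -> None:
        self.graph = graph
        self.name = name

    def _record(self, op: Callable[[Any, Any], Any], other: Any) -> "Proxy":
        out = Proxy(self.graph, f"v{len(self.graph)}")
        self.graph.append((out.name, op, self, other))
        return out

    def __add__(self, other: Any) -> "Proxy":
        return self._record(operator.add, other)

    def __mul__(self, other: Any) -> "Proxy":
        return self._record(operator.mul, other)


def toy_symbolic_trace(fn: Callable[..., Any], arg_names: List[str]) -> Callable[..., Any]:
    graph: List[Node] = []
    # Run fn once on proxies: every + and * applied to a proxy lands in `graph`.
    result = fn(*[Proxy(graph, name) for name in arg_names])

    def traced(*args: Any) -> Any:
        # Replay the recorded nodes on concrete inputs.
        env = dict(zip(arg_names, args))

        def load(x: Any) -> Any:
            return env[x.name] if isinstance(x, Proxy) else x

        for name, op, lhs, rhs in graph:
            env[name] = op(load(lhs), load(rhs))
        return load(result)

    traced.graph = graph  # expose the recorded nodes for inspection
    return traced


g = toy_symbolic_trace(lambda a, b: a * 2 + b, ["a", "b"])
print([(name, op.__name__) for name, op, _, _ in g.graph])  # [('v0', 'mul'), ('v1', 'add')]
print(g(3, 1))  # 7
```

The traced callable never re-runs the original Python function; it only replays the recorded operations, which is what makes the recorded graph something that can be inspected and transformed.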

concrete_args lets you partially specialize your function, for example to remove control flow or to eliminate data-structure handling.

For example:

def f(a, b):
    if b == True:
        return a
    else:
        return a * 2

FX typically cannot trace through this, because the branch taken depends on the value of b, which is unknown during symbolic tracing. However, we can use concrete_args to specialize on the value of b and trace through it:

f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False) == 6

Note that although you can still pass in different values of b, they will be ignored.
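The following toy sketch (again, not the real tracer) shows why the control flow above breaks tracing and how specializing b helps. It assumes a hypothetical Proxy whose __bool__ raises, since a symbolic value has no truth value. Passing a concrete value for b means the if is evaluated once at trace time and only the taken branch is recorded. In this toy model, later values of b are simply ignored, mirroring the note above.

```python
import operator
from typing import Any, Callable


class TraceError(Exception):
    """Raised when tracing hits control flow that depends on a symbolic value."""


class Proxy:
    """Hypothetical symbolic value (toy model, not the real fx Proxy)."""

    def __init__(self, graph: list, name: str) -> None:
        self.graph = graph
        self.name = name

    def _record(self, op: Callable[[Any, Any], Any], other: Any) -> "Proxy":
        out = Proxy(self.graph, f"v{len(self.graph)}")
        self.graph.append((out.name, op, self, other))
        return out

    def __mul__(self, other: Any) -> "Proxy":
        return self._record(operator.mul, other)

    def __eq__(self, other: Any) -> "Proxy":  # type: ignore[override]
        # Comparison is recorded too, so `b == True` yields another Proxy.
        return self._record(operator.eq, other)

    def __bool__(self) -> bool:
        # An `if` needs a real True/False, which a symbolic value cannot provide.
        raise TraceError(f"cannot branch on symbolic value {self.name!r}")


def toy_trace(fn, arg_names, concrete_args=None):
    concrete_args = concrete_args or {}
    graph: list = []
    # Specialized arguments are passed as plain values; the rest become proxies.
    inputs = [concrete_args[n] if n in concrete_args else Proxy(graph, n) for n in arg_names]
    result = fn(*inputs)

    def traced(*args):
        env = dict(zip(arg_names, args))

        def load(x):
            return env[x.name] if isinstance(x, Proxy) else x

        for name, op, lhs, rhs in graph:
            env[name] = op(load(lhs), load(rhs))
        return load(result)

    return traced


def f(a, b):
    if b == True:
        return a
    else:
        return a * 2


try:
    toy_trace(f, ["a", "b"])
except TraceError as err:
    print("plain tracing fails:", err)

g = toy_trace(f, ["a", "b"], concrete_args={"b": False})
print(g(3, False))  # 6
print(g(3, True))   # 6: b was specialized to False, so this value is ignored
```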

concrete_args can also eliminate data-structure handling from your function: the input is flattened using pytrees. To avoid overspecializing, pass fx.PH for the values that should not be specialized. For example:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
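Below is a toy model of the pytree-style flattening, restricted to a one-level dict. PH, Proxy, and toy_trace_with_ph are hypothetical stand-ins for fx.PH and the real tracer, assumed here only for illustration. Every PH leaf becomes a symbolic input, any other leaf is baked into the recorded graph as a constant, and at call time the input dict is flattened using the key order captured during tracing.

```python
import operator
from typing import Any, Callable, Dict


class _Placeholder:
    """Marker meaning "keep this leaf symbolic" (toy stand-in for fx.PH)."""


PH = _Placeholder()


class Proxy:
    """Hypothetical symbolic value that records additions (toy model only)."""

    def __init__(self, graph: list, name: str) -> None:
        self.graph = graph
        self.name = name

    def _record(self, op: Callable[[Any, Any], Any], lhs: Any, rhs: Any) -> "Proxy":
        out = Proxy(self.graph, f"v{len(self.graph)}")
        self.graph.append((out.name, op, lhs, rhs))
        return out

    def __add__(self, other: Any) -> "Proxy":
        return self._record(operator.add, self, other)

    def __radd__(self, other: Any) -> "Proxy":  # handles `0 + proxy`
        return self._record(operator.add, other, self)


def toy_trace_with_ph(fn: Callable[..., Any], concrete_args: Dict[str, Dict[str, Any]]):
    graph: list = []
    # One-level "tree spec": the key order of each dict argument at trace time.
    specs = {arg: list(tree) for arg, tree in concrete_args.items()}
    # PH leaves become symbolic inputs; any other leaf is baked in as a constant.
    inputs = {
        arg: {k: Proxy(graph, f"{arg}_{k}") if v is PH else v for k, v in tree.items()}
        for arg, tree in concrete_args.items()
    }
    result = fn(**inputs)

    def traced(*args: Dict[str, Any]) -> Any:
        env: Dict[str, Any] = {}
        # Flatten each runtime dict using the key order captured during tracing.
        for arg, value in zip(specs, args):
            for k in specs[arg]:
                env[f"{arg}_{k}"] = value[k]

        def load(x: Any) -> Any:
            return env[x.name] if isinstance(x, Proxy) else x

        for name, op, lhs, rhs in graph:
            env[name] = op(load(lhs), load(rhs))
        return load(result)

    return traced


def f(x):
    out = 0
    for v in x.values():
        out += v
    return out


g = toy_trace_with_ph(f, {"x": {"a": PH, "b": PH, "c": PH}})
print(g({"a": 1, "b": 2, "c": 4}))  # 7

# Leaving "b" concrete specializes it: the runtime value for "b" is ignored.
h = toy_trace_with_ph(f, {"x": {"a": PH, "b": 10}})
print(h({"a": 1, "b": 999}))  # 11
```

The second call illustrates why PH matters: any leaf not marked as a placeholder is frozen at its trace-time value.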

Parameters:
  • root (Union[torch.nn.Module, Callable]) – Module or function to be traced and converted into a Graph representation.

  • concrete_args (Optional[Dict[str, Any]]) – Inputs to be partially specialized.

Returns:

  • GraphModule – a Module created from the operations recorded while tracing root.

  • Tracer – the tracer used for tracing the model.

Return type:

GraphModule