dmx.compressor.fx.tracer.symbolic_trace
- dmx.compressor.fx.tracer.symbolic_trace(root: Module | Callable[[...], Any], concrete_args: Dict[str, Any] | None = None) → GraphModule
Symbolic tracing API
Given an `nn.Module` or function instance `root`, this function returns a `GraphModule` constructed by recording the operations seen while tracing through `root`, and the tracer used to trace the model.

`concrete_args` allows you to partially specialize your function, whether to remove control flow or data structures. For example:
```python
def f(a, b):
    if b == True:
        return a
    else:
        return a * 2
```
FX typically cannot trace through this because of the control flow on `b`. However, we can use `concrete_args` to specialize on the value of `b` and trace through it:
```python
f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False) == 6
```
Note that although you can still pass in different values of b, they will be ignored.
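A toy proxy tracer in plain Python can show why control flow blocks tracing and why fixing an argument helps. This is a sketch of the general idea under simplifying assumptions, not the dmx or FX implementation. `Proxy`, `toy_trace`, and the string node format are all hypothetical.

Each argument that is not fixed becomes a proxy, and the proxy's operators append a node to a list. An `if` on a proxy fails because its truth value is unknown at trace time. An argument supplied through `concrete_args`, by contrast, is passed as an ordinary Python value, so the branch is resolved during tracing and only the path actually taken is recorded.

```python
from typing import Any, Callable, Dict, List, Optional


class Proxy:
    """Hypothetical traced value: operators append nodes to a graph instead of computing."""

    def __init__(self, name: str, graph: List[str]) -> None:
        self.name = name
        self.graph = graph

    def _record(self, op: str, other: Any) -> "Proxy":
        # Append "vN = op(self, other)" and return a proxy standing for its result.
        out = Proxy(f"v{len(self.graph)}", self.graph)
        rhs = other.name if isinstance(other, Proxy) else repr(other)
        self.graph.append(f"{out.name} = {op}({self.name}, {rhs})")
        return out

    def __mul__(self, other: Any) -> "Proxy":
        return self._record("mul", other)

    def __eq__(self, other: Any) -> "Proxy":  # type: ignore[override]
        # Without this override, `b == True` would fall back to an identity
        # check, silently evaluate to False, and take the else branch.
        return self._record("eq", other)

    def __bool__(self) -> bool:
        # An `if` needs a real bool, but a traced value is unknown at trace time.
        raise TypeError(f"cannot branch on traced value {self.name}")


def toy_trace(fn: Callable[..., Any], arg_names: List[str],
              concrete_args: Optional[Dict[str, Any]] = None) -> List[str]:
    """Hypothetical helper: run fn once, proxying every argument not in concrete_args."""
    concrete_args = concrete_args or {}
    graph: List[str] = []
    args = [concrete_args[n] if n in concrete_args else Proxy(n, graph)
            for n in arg_names]
    result = fn(*args)
    ret = result.name if isinstance(result, Proxy) else repr(result)
    graph.append(f"return {ret}")
    return graph


def f(a, b):
    if b == True:
        return a
    else:
        return a * 2


try:
    toy_trace(f, ["a", "b"])  # b is a proxy, so `b == True` records a node
except TypeError as err:
    print(err)  # cannot branch on traced value v0

print(toy_trace(f, ["a", "b"], {"b": False}))  # ['v0 = mul(a, 2)', 'return v0']
print(toy_trace(f, ["a", "b"], {"b": True}))   # ['return a']
```

In this toy model, the graph traced with `b=False` never reads `b`. This mirrors the note that later values of `b` are ignored.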
We can also use `concrete_args` to eliminate data-structure handling from our function. This uses pytrees to flatten the input. To avoid overspecializing, pass `fx.PH` for values that shouldn’t be specialized. For example:
```python
def f(x):
    out = 0
    for v in x.values():
        out += v
    return out

f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
```
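A hand-written pytree in plain Python can illustrate the flattening step. FX relies on its own pytree utilities and `fx.PH`. The `PH` sentinel, `tree_flatten`, `tree_unflatten`, and `flatten_fn` below are hypothetical stand-ins that only handle nested dicts. The sketch splits the input into leaves plus a structure spec. Leaves marked `PH` become separate positional inputs, and any other leaf is fixed as a constant, which is what specializing on it means here.

```python
from typing import Any, Callable, Dict, List, Tuple

PH = object()  # hypothetical sentinel playing the role of fx.PH


def tree_flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Split nested dicts into a flat list of leaves plus a structure spec."""
    if isinstance(tree, dict):
        leaves: List[Any] = []
        spec: Dict[Any, Any] = {}
        for key, child in tree.items():
            child_leaves, child_spec = tree_flatten(child)
            leaves.extend(child_leaves)
            spec[key] = child_spec
        return leaves, spec
    return [tree], None  # None marks a leaf


def tree_unflatten(leaves: List[Any], spec: Any) -> Any:
    """Rebuild the nested dicts described by spec, consuming leaves in order."""
    it = iter(leaves)

    def build(s: Any) -> Any:
        if s is None:
            return next(it)
        return {key: build(child) for key, child in s.items()}

    return build(spec)


def flatten_fn(fn: Callable[[Any], Any], template: Any) -> Callable[..., Any]:
    """Hypothetical helper: turn fn(x) into a function of the PH leaves of template.

    PH leaves become positional inputs; every other leaf is baked in as a constant.
    """
    leaves, spec = tree_flatten(template)
    ph_positions = [i for i, leaf in enumerate(leaves) if leaf is PH]

    def flat_fn(*inputs: Any) -> Any:
        if len(inputs) != len(ph_positions):
            raise TypeError(f"expected {len(ph_positions)} inputs, got {len(inputs)}")
        filled = list(leaves)
        for pos, value in zip(ph_positions, inputs):
            filled[pos] = value
        return fn(tree_unflatten(filled, spec))

    return flat_fn


def f(x):
    out = 0
    for v in x.values():
        out += v
    return out


print(flatten_fn(f, {"a": PH, "b": PH, "c": PH})(1, 2, 4))  # 7
print(flatten_fn(f, {"a": PH, "b": 10, "c": PH})(1, 4))     # 15
```

The dict is rebuilt from the spec before `f` runs, so the loop over `x.values()` always sees the same keys. During real tracing, that loop would be expected to unroll into one addition per leaf.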
- Parameters:
root (Union[torch.nn.Module, Callable]) – Module or function to be traced and converted into a Graph representation.
concrete_args (Optional[Dict[str, Any]]) – Inputs to be partially specialized, keyed by argument name.
- Returns:
a Module created from the recorded operations from `root`.
Tracer: the tracer used for tracing the model.
- Return type:
GraphModule