dmx.compressor.fx.tracer.hf_symbolic_trace
- dmx.compressor.fx.tracer.hf_symbolic_trace(model: ~transformers.modeling_utils.PreTrainedModel, input_names: ~typing.List[str] | None = None, concrete_args: ~typing.Dict[str, ~typing.Any] | None = None, tracer_cls: ~typing.Type[~dmx.compressor.fx.tracer.DmxHFTracer] = <class 'dmx.compressor.fx.tracer.DmxHFTracer'>, dummy_inputs: ~typing.Dict[str, ~typing.Any] | None = None) GraphModule
Performs symbolic tracing on a Hugging Face model.
- Parameters:
model (PreTrainedModel) – The model to trace.
input_names (List[str], optional) – The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
disable_check (bool, optional, defaults to False) – If True, no check is done before trying to trace the model, which is mostly useful for debugging purposes.
tracer_cls (Type[DmxHFTracer], optional, defaults to DmxHFTracer) – The tracer class to instantiate for tracing the model.
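How input_names interacts with the other arguments of the model's forward() is not spelled out above. In the Hugging Face transformers.utils.fx.symbolic_trace function, whose signature this one closely resembles, every forward() parameter not listed in input_names is frozen to its default value and handed to the tracer as concrete arguments. The sketch below assumes dmx follows the same convention; split_trace_inputs and ToyModel are hypothetical names used only for illustration and are not part of dmx.

```python
import inspect
from typing import Any, Dict, List, Optional


def split_trace_inputs(
    forward: Any,
    dummy_inputs: Dict[str, Any],
    input_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: compute the concrete args for a trace.

    Parameters named in input_names stay symbolic; every other
    forward() parameter is frozen to its default value.
    """
    # Documented default: fall back to the keys of the dummy inputs.
    if input_names is None:
        input_names = list(dummy_inputs.keys())
    params = inspect.signature(forward).parameters
    unknown = set(input_names) - set(params)
    if unknown:
        raise ValueError(f"forward() has no parameter(s): {sorted(unknown)}")
    # Everything not traced symbolically is passed as a concrete value.
    return {name: p.default for name, p in params.items() if name not in input_names}


class ToyModel:
    # Stand-in for a PreTrainedModel; only forward()'s signature matters here.
    dummy_inputs = {"input_ids": [[0, 1]], "attention_mask": [[1, 1]]}

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, return_dict=True):
        return input_ids


model = ToyModel()
print(split_trace_inputs(model.forward, model.dummy_inputs))
# {'token_type_ids': None, 'return_dict': True}
```

Freezing the unlisted arguments keeps data-dependent branches (for example on return_dict) out of the traced graph, which is why choosing input_names carefully matters.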
- Returns:
A GraphModule constructed by recording operations seen while tracing the model.
torch.fx.Tracer: The tracer used for tracing the model.
- Return type:
torch.fx.GraphModule
Example
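```python
from transformers import AutoModel
from dmx.compressor.fx.tracer import hf_symbolic_trace
model = AutoModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint; any PreTrainedModel works
traced_model, tracer = hf_symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```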