dmx.compressor.fx.tracer.hf_symbolic_trace

dmx.compressor.fx.tracer.hf_symbolic_trace(model: PreTrainedModel, input_names: List[str] | None = None, concrete_args: Dict[str, Any] | None = None, tracer_cls: Type[DmxHFTracer] = DmxHFTracer, dummy_inputs: Dict[str, Any] | None = None) → GraphModule

Performs torch.fx symbolic tracing on a Hugging Face model.

Parameters:
  • model (PreTrainedModel) – The model to trace.

  • input_names (List[str], optional) – The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.

  • concrete_args (Dict[str, Any], optional) – Input values to fix at trace time instead of treating them as symbolic placeholders, following the torch.fx convention.

  • tracer_cls (Type[DmxHFTracer], optional, defaults to DmxHFTracer) – The tracer class to instantiate for tracing.

  • dummy_inputs (Dict[str, Any], optional) – Example inputs, keyed by input name, to use while tracing the model.
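`concrete_args` follows the torch.fx convention of fixing some inputs to concrete values so that Python control flow depending on them can be resolved during tracing. The sketch below illustrates that convention with plain `torch.fx.symbolic_trace` on a hypothetical toy module (`Scale` is not part of dmx); it assumes `hf_symbolic_trace` gives `concrete_args` the same meaning and does not show its actual implementation.

```python
import operator

import torch
import torch.fx


class Scale(torch.nn.Module):
    """Hypothetical toy module whose control flow depends on an argument."""

    def forward(self, x, double=False):
        # Branching on a symbolic Proxy raises a TraceError, so `double`
        # has to be given a concrete value before tracing.
        if double:
            return x * 2
        return x + 1


# Plain torch.fx is used here; hf_symbolic_trace is assumed to treat
# concrete_args the same way. Fixing double=True records only `x * 2`.
traced = torch.fx.symbolic_trace(Scale(), concrete_args={"double": True})
targets = [node.target for node in traced.graph.nodes]
print(operator.mul in targets, operator.add in targets)  # True False
```

Only the branch selected by the concrete value ends up in the recorded graph.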

Returns:

A tuple (traced_model, tracer): the GraphModule constructed by recording the operations seen while tracing the model, and the torch.fx.Tracer used for tracing it.

Return type:

Tuple[torch.fx.GraphModule, torch.fx.Tracer]
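Since the traced model is a `torch.fx.GraphModule`, the standard torch.fx tools can be used to inspect the recorded operations. The following sketch stays self-contained by tracing a hypothetical `TinyMLP` module with plain `torch.fx.symbolic_trace` instead of a Hugging Face model; the same node loop is expected to work on a GraphModule returned by `hf_symbolic_trace`.

```python
import torch
import torch.fx


class TinyMLP(torch.nn.Module):
    """Hypothetical toy model standing in for a Hugging Face model."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyMLP()
gm = torch.fx.symbolic_trace(model)

# Each node records one operation seen while tracing:
# the input, the Linear submodule call, the relu call, and the output.
ops = [node.op for node in gm.graph.nodes]
print(ops)  # ['placeholder', 'call_module', 'call_function', 'output']

# The GraphModule is callable and reproduces the original model's output.
x = torch.randn(3, 4)
print(torch.allclose(gm(x), model(x)))  # True
```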

Example

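```python
from transformers import AutoModel
from dmx.compressor.fx.tracer import hf_symbolic_trace

model = AutoModel.from_pretrained("bert-base-uncased")  # any PreTrainedModel works
# Returns the traced GraphModule together with the tracer that produced it
traced_model, tracer = hf_symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
```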