dmx.compressor.numerical.observer

Classes

DMXObserverBase(dtype[, qscheme, ...])

Taken from torch.ao.quantization.observer.UniformQuantizationObserverBase

DummyObserver(dtype[, ch_axis])

A placeholder observer that does nothing.

HistogramObserver([bins, upsample_rate, ...])

Adapted from torch.ao.quantization.observer.HistogramObserver; per-channel quantization is not yet supported.

MinMaxObserver([dtype, qscheme, ch_axis, ...])

Adapted from torch.ao.quantization.observer.MinMaxObserver; supports per-channel quantization.

PercentileObserver([percentile])

Extends HistogramObserver to compute a percentile from the histogram. Taken from https://github.com/NVIDIA/TensorRT/blob/master/tools/pytorch-quantization/pytorch_quantization/calib/histogram.py#L287

class dmx.compressor.numerical.observer.DMXObserverBase(dtype: Format, qscheme: qscheme = torch.per_tensor_affine, factory_kwargs: dict | None = None, eps: float = 1.1920928955078125e-07, **kwargs)

Bases: ObserverBase

Taken from torch.ao.quantization.observer.UniformQuantizationObserverBase

eps: Tensor

extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

class dmx.compressor.numerical.observer.DummyObserver(dtype: Format, ch_axis: int = -1, **kwargs)

Bases: DMXObserverBase

A placeholder observer that does nothing.

calculate_qparams()

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class dmx.compressor.numerical.observer.HistogramObserver(bins: int = 2048, upsample_rate: int = 128, dtype: Format = XP[8, 0](CSN), qscheme: qscheme | None = torch.per_tensor_affine, ch_axis: int = -1, factory_kwargs: dict | None = None, eps=1.1920928955078125e-07, **kwargs)

Bases: DMXObserverBase

Adapted from torch.ao.quantization.observer.HistogramObserver; per-channel quantization is not yet supported.

calculate_qparams()

extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x_orig: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

histogram: Tensor

max_val: Tensor

min_val: Tensor

class dmx.compressor.numerical.observer.MinMaxObserver(dtype: Format = XP[8, 0](CSN), qscheme: qscheme | None = torch.per_tensor_affine, ch_axis: int = -1, factory_kwargs: Dict | None = None, eps: float | None = 1.1920928955078125e-07, **kwargs)

Bases: DMXObserverBase

Adapted from torch.ao.quantization.observer.MinMaxObserver; supports per-channel quantization.

calculate_qparams()

Calculates the quantization parameters.
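
As a rough illustration of what a min/max observer typically computes here, the sketch below follows the affine formula used by the upstream torch.ao.quantization.observer.UniformQuantizationObserverBase. It is not the DMX implementation: the function name affine_qparams is hypothetical, and the signed 8-bit range [-128, 127] is an assumption that may not match how the default XP[8, 0](CSN) format is handled.

```python
def affine_qparams(min_val, max_val, quant_min=-128, quant_max=127,
                   eps=1.1920928955078125e-07):
    """Map an observed [min_val, max_val] range to (scale, zero_point).

    Hypothetical helper; the integer range is an assumed signed 8-bit range.
    """
    # Extend the range so that 0.0 is always exactly representable.
    min_neg = min(min_val, 0.0)
    max_pos = max(max_val, 0.0)
    # Size of one quantization step; eps guards against a zero-width range.
    scale = max((max_pos - min_neg) / (quant_max - quant_min), eps)
    # Integer code that real 0.0 maps to, clamped into the valid range.
    zero_point = quant_min - round(min_neg / scale)
    zero_point = max(quant_min, min(quant_max, zero_point))
    return scale, zero_point


scale, zero_point = affine_qparams(-1.0, 3.0)
```

The eps default matches the value shown in the class signature, which is the machine epsilon of float32.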

extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x_orig)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

max_val: Tensor

min_val: Tensor

reset_min_max_vals()

Resets the min/max values.
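
To make the role of min_val, max_val and reset_min_max_vals() concrete, here is a minimal pure-Python sketch of per-channel running min/max tracking. The class RunningMinMax and its observe method are hypothetical stand-ins, not DMX code; the real observer works on tensors along ch_axis, and its exact reset values are assumed here to be +inf and -inf as in the upstream torch observer.

```python
class RunningMinMax:
    """Track running per-channel minima and maxima over batches (illustrative)."""

    def __init__(self, num_channels):
        self.num_channels = num_channels
        self.reset_min_max_vals()

    def reset_min_max_vals(self):
        # +inf / -inf are the identities for min / max, so the first
        # observed batch always overwrites them (assumed reset values).
        self.min_val = [float("inf")] * self.num_channels
        self.max_val = [float("-inf")] * self.num_channels

    def observe(self, batch):
        # `batch` is a list of rows, one row of samples per channel.
        for ch, row in enumerate(batch):
            self.min_val[ch] = min(self.min_val[ch], min(row))
            self.max_val[ch] = max(self.max_val[ch], max(row))
        # The upstream torch observer returns its input unchanged.
        return batch


tracker = RunningMinMax(num_channels=2)
tracker.observe([[0.5, -1.0], [2.0, 3.0]])
tracker.observe([[1.5, 0.0], [-4.0, 1.0]])
```

After the two calls, each channel holds the extremes seen across both batches; calling reset_min_max_vals() discards them so that calibration can start over.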

class dmx.compressor.numerical.observer.PercentileObserver(percentile: float = 99.99, **kwargs)

Bases: HistogramObserver

Extends HistogramObserver to compute a percentile from the histogram. Taken from https://github.com/NVIDIA/TensorRT/blob/master/tools/pytorch-quantization/pytorch_quantization/calib/histogram.py#L287

calculate_qparams()
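
The percentile calculation can be sketched in plain Python as follows. This mirrors the general approach of the referenced NVIDIA calibrator (build a cumulative distribution over the histogram and look up the first bin that reaches the target), but it is only an illustration: the function name percentile_amax and its arguments are hypothetical, and how DMX turns the resulting bound into quantization parameters is not shown.

```python
from bisect import bisect_left
from itertools import accumulate


def percentile_amax(hist, bin_edges, percentile=99.99):
    """Return the bin edge at which the histogram CDF reaches `percentile`.

    `hist` holds per-bin counts; `bin_edges` has len(hist) + 1 entries.
    """
    if not 0.0 <= percentile <= 100.0:
        raise ValueError("percentile must be in [0, 100]")
    total = sum(hist)
    if total == 0:
        raise ValueError("histogram is empty")
    # Cumulative fraction of samples up to and including each bin.
    cdf = [count / total for count in accumulate(hist)]
    # Index of the first bin whose cumulative fraction reaches the target.
    idx = bisect_left(cdf, percentile / 100.0)
    return bin_edges[idx]


hist = [50, 30, 15, 4, 1]
bin_edges = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
bound_90 = percentile_amax(hist, bin_edges, 90.0)
bound_default = percentile_amax(hist, bin_edges)
```

Clipping at a high percentile rather than at the absolute maximum discards rare outliers, which usually leaves finer resolution for the bulk of the values.
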
extra_repr()

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.