dmx.compressor.modeling.hf.balanced_device_map

dmx.compressor.modeling.hf.balanced_device_map(model: str, revision: str | None = 'main') → Dict[str, int]

Compute a custom device map for the given model that distributes the model weights evenly across all devices. To use it, pass device_map="balanced" when calling pipeline.

Parameters:
  • model (str) – name of the model on the Hugging Face Hub

  • revision (str | None) – revision of the model on the Hugging Face Hub (default: 'main')

Returns:

the device map, as a dictionary mapping module names to device indices

Return type:

Dict[str, int]
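To illustrate what "distributes model weights evenly across all devices" can mean, here is a minimal sketch. It is not the library's implementation. The helper balanced_split and its module_sizes input are hypothetical: the sketch assumes per-module parameter counts are already known and listed in execution order, whereas the real balanced_device_map takes a Hugging Face model name and revision. The result has the same shape as the documented return type, Dict[str, int].

```python
from typing import Dict


def balanced_split(module_sizes: Dict[str, int], num_devices: int) -> Dict[str, int]:
    """Hypothetical sketch: split modules into contiguous, roughly equal-sized groups.

    module_sizes maps module names (in execution order) to parameter counts.
    The result maps each module name to a device index.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    total = sum(module_sizes.values())
    target = total / num_devices  # ideal share of weights per device
    device_map: Dict[str, int] = {}
    cumulative = 0
    for name, size in module_sizes.items():
        # Place each module by the midpoint of its weight range, so group
        # boundaries land as close as possible to multiples of `target`.
        midpoint = cumulative + size / 2
        device = int(midpoint // target) if target > 0 else 0
        device_map[name] = min(device, num_devices - 1)  # clamp to last device
        cumulative += size
    return device_map


# Hypothetical parameter counts for a tiny 4-layer model.
sizes = {
    "embed": 100,
    "layer.0": 200,
    "layer.1": 200,
    "layer.2": 200,
    "layer.3": 200,
    "lm_head": 100,
}
print(balanced_split(sizes, 2))
# {'embed': 0, 'layer.0': 0, 'layer.1': 0, 'layer.2': 1, 'layer.3': 1, 'lm_head': 1}
```

Keeping each device's modules contiguous means consecutive layers share a device, so activations cross devices only at group boundaries. The actual balanced_device_map may use a different strategy.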