dmx.compressor.modeling.hf.get_modules

dmx.compressor.modeling.hf.get_modules(root: Module, prefix: str) → Dict[str, Module]

Recursively traverses the model, starting from the given root module, and returns a dictionary of submodules for device mapping. Following the format used by device_map="auto", the dictionary includes only submodules that are leaf nodes or hidden layers. A hidden layer is kept as a single entry, and its own submodules are not listed separately.
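The traversal rule above can be illustrated with a minimal sketch. This is not the dmx.compressor implementation. The function name `get_modules_sketch`, the caller-supplied `is_hidden_layer` predicate, and the dot-joined naming of keys are assumptions made for illustration, because this entry does not say how hidden layers are recognized or how names are built. The sketch relies only on `named_children()`, which every `torch.nn.Module` provides.

```python
from typing import Callable, Dict, Iterator, Protocol, Tuple


class HasChildren(Protocol):
    # Any torch.nn.Module satisfies this protocol: it exposes named_children().
    def named_children(self) -> Iterator[Tuple[str, "HasChildren"]]: ...


def get_modules_sketch(
    root: HasChildren,
    prefix: str,
    is_hidden_layer: Callable[[HasChildren], bool],
) -> Dict[str, HasChildren]:
    """Hypothetical illustration of the traversal described above.

    ASSUMPTION: `is_hidden_layer` is supplied by the caller; how
    dmx.compressor actually decides what counts as a hidden layer is
    not documented in this entry.
    """
    modules: Dict[str, HasChildren] = {}
    for name, child in root.named_children():
        # ASSUMPTION: names are joined with "." as in torch's named_modules().
        full_name = f"{prefix}.{name}" if prefix else name
        has_children = any(True for _ in child.named_children())
        if is_hidden_layer(child) or not has_children:
            # Hidden layers and leaves are kept whole; recursion stops here,
            # so the submodules of a hidden layer never appear in the result.
            modules[full_name] = child
        else:
            # Intermediate container (e.g. a layer list): descend into it.
            modules.update(get_modules_sketch(child, full_name, is_hidden_layer))
    return modules
```

With a PyTorch model, the predicate might be something like `lambda m: isinstance(m, DecoderLayer)`, where `DecoderLayer` is a hypothetical stand-in for the model's hidden-layer class. Stopping the recursion at hidden layers keeps each layer on a single device, which matches the granularity that device_map="auto" works at.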

Parameters:
  • root (torch.nn.Module) – model or module to traverse

  • prefix (str) – prefix prepended to the submodule names

Returns:

dictionary mapping submodule names to submodules

Return type:

Dict[str, torch.nn.Module]