dmx.compressor.modeling.hf.get_modules
- dmx.compressor.modeling.hf.get_modules(root: Module, prefix: str) → Dict[str, Module]
Recursively traverses the model starting from the given root module and returns a dictionary of submodules for device mapping. Following the granularity used by device_map="auto", the dictionary includes only submodules that are leaf nodes or hidden layers; the submodules of a hidden layer are not included individually.
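The traversal described above can be sketched as follows. This is a hypothetical re-creation for illustration, not the dmx implementation: the name `get_modules_sketch`, the `hidden_layer_types` parameter used to decide what counts as a hidden layer, and the "." separator used to join the prefix and child names are all assumptions. The sketch relies only on `named_children()`, which `torch.nn.Module` provides, so it runs with or without PyTorch installed.

```python
from typing import Dict, Iterator, Protocol, Tuple


class HasChildren(Protocol):
    """Anything exposing named_children(), e.g. torch.nn.Module."""

    def named_children(self) -> Iterator[Tuple[str, "HasChildren"]]: ...


def get_modules_sketch(
    root: HasChildren,
    prefix: str = "",
    hidden_layer_types: Tuple[type, ...] = (),  # hypothetical: which classes count as hidden layers
) -> Dict[str, HasChildren]:
    """Hypothetical sketch of the documented behavior, not the dmx source.

    Collects leaf modules and hidden-layer modules under `root`, and does not
    descend into hidden layers. Note: if `root` itself has no children, this
    sketch returns an empty dict.
    """
    modules: Dict[str, HasChildren] = {}
    for name, child in root.named_children():
        # Assumption: names are joined with "." as in torch's named_modules().
        full_name = f"{prefix}.{name}" if prefix else name
        is_leaf = next(iter(child.named_children()), None) is None
        if is_leaf or isinstance(child, hidden_layer_types):
            # Keep the whole submodule; a hidden layer's children are skipped.
            modules[full_name] = child
        else:
            # Intermediate container: recurse to reach leaves / hidden layers.
            modules.update(get_modules_sketch(child, full_name, hidden_layer_types))
    return modules
```

With a Hugging Face model, one might pass the decoder-layer class (for example, the class whose name appears in the model's `_no_split_modules`) as `hidden_layer_types`; this pairing is an assumption about how hidden layers are identified, not something the dmx documentation states.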
- Parameters:
root (torch.nn.Module) – model or module to traverse
prefix (str) – prefix prepended to the submodule names
- Returns:
dictionary mapping submodule names to submodules
- Return type:
Dict[str, torch.nn.Module]
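Because the returned dictionary is intended for device mapping, one way to use it is to turn its names into a name-to-device mapping. The helper below is hypothetical (it is not part of dmx) and assumes the dictionary preserves model order, so it assigns each device a contiguous run of submodules.

```python
from typing import Dict, List, Mapping, Sequence, TypeVar

M = TypeVar("M")


def contiguous_device_map(modules: Mapping[str, M], devices: Sequence[str]) -> Dict[str, str]:
    """Hypothetical helper: split module names, in order, into contiguous chunks per device."""
    if not devices:
        raise ValueError("need at least one device")
    names: List[str] = list(modules)
    per_device = -(-len(names) // len(devices))  # ceiling division
    # Consecutive names share a device, so neighboring layers stay together.
    return {name: devices[i // per_device] for i, name in enumerate(names)}
```

A dict of module names to devices is the form Hugging Face Transformers accepts for an explicit `device_map`, assuming the keys match the model's own module names.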