dmx.compressor.utils.benchmark.gather_tensors

dmx.compressor.utils.benchmark.gather_tensors(tensor_collection: Tensor | List[Any] | Tuple[Any] | Dict[str, Any]) → List[Tensor]

Gathers all torch tensors from arbitrarily nested structures of lists, tuples, and dicts.

Parameters:

tensor_collection (Union[torch.Tensor, List[Any], Tuple[Any], Dict[str, Any]]) – A torch tensor, or an arbitrarily nested collection of tensors such as the output typically returned by a Hugging Face model.

Return type:

List[torch.Tensor]
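The description implies a recursive traversal of the input. The following is a minimal sketch of such a traversal, not the library's actual implementation. The helper name `gather_leaves` and its `leaf_type` parameter are hypothetical, and descending into list/tuple elements and dict values (but not dict keys) in depth-first order is an assumption. The leaf type is a parameter so the sketch runs without PyTorch installed.

```python
from typing import Any, List, Type, TypeVar

T = TypeVar("T")


def gather_leaves(collection: Any, leaf_type: Type[T]) -> List[T]:
    """Hypothetical sketch (not the dmx implementation): collect every
    instance of `leaf_type` found in arbitrarily nested lists, tuples,
    and dict values, in depth-first order."""
    found: List[T] = []

    def _visit(obj: Any) -> None:
        if isinstance(obj, leaf_type):
            # A leaf: keep it and do not descend further.
            found.append(obj)
        elif isinstance(obj, (list, tuple)):
            # Sequences: visit elements in their stored order.
            for item in obj:
                _visit(item)
        elif isinstance(obj, dict):
            # Mappings: assumed that only values can hold leaves; keys are ignored.
            for value in obj.values():
                _visit(value)
        # Any other object is skipped.

    _visit(collection)
    return found
```

With PyTorch available, `gather_leaves(outputs, torch.Tensor)` would collect the tensors from a model output in the same spirit as `gather_tensors`. Because Hugging Face `ModelOutput` objects subclass `OrderedDict`, the `dict` branch would cover them as well.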