dmx.compressor.utils.benchmark.measure_model_runtime

dmx.compressor.utils.benchmark.measure_model_runtime(model_maker: Callable[[], Tuple[Module, Callable, Callable, device]], modes: List[EVALUATION_MODE])

Entry function for measuring various runtime statistics of a model.

Parameters:
  • model_maker (Callable[[], Tuple[torch.nn.Module, Callable, Callable, torch.device]]) – A zero-argument callable that returns a tuple containing the model to be measured, two callables (used to run a sample input through the model and to evaluate the model’s accuracy), and a torch.device.

  • modes (List[EVALUATION_MODE]) – List of evaluation modes in which to measure the model’s runtime.
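Because model_maker is a factory rather than a model instance, the caller writes a small function that builds everything the tuple needs. The following is a minimal sketch of such a factory using plain PyTorch. The names make_model, run_sample and evaluate_accuracy, the toy network, and the synthetic data are all hypothetical. The sketch also assumes that the two helper callables appear in the order given in the parameter description (sample run first, accuracy second) and that each receives the model as its only argument. This documentation does not specify how measure_model_runtime invokes them, so check the dmx source before relying on this.

```python
from typing import Callable, Tuple

import torch
from torch import nn


def make_model() -> Tuple[nn.Module, Callable, Callable, torch.device]:
    """Hypothetical model_maker: builds a tiny classifier plus its helpers."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).to(device)
    model.eval()

    # Fixed synthetic batch (seeded) so every call sees the same inputs.
    generator = torch.Generator().manual_seed(0)
    inputs = torch.randn(8, 16, generator=generator).to(device)
    labels = torch.randint(0, 4, (8,), generator=generator).to(device)

    # ASSUMPTION: dmx passes the model as the helpers' only argument.
    def run_sample(m: nn.Module) -> torch.Tensor:
        """Run one forward pass on the sample batch."""
        with torch.no_grad():
            return m(inputs)

    # ASSUMPTION: accuracy is reported as a float in [0, 1].
    def evaluate_accuracy(m: nn.Module) -> float:
        """Fraction of the synthetic batch that is classified correctly."""
        with torch.no_grad():
            predictions = m(inputs).argmax(dim=1)
        return (predictions == labels).float().mean().item()

    return model, run_sample, evaluate_accuracy, device
```

When calling measure_model_runtime, pass the function itself (make_model, not make_model()), together with a list of EVALUATION_MODE members. The available members are not listed in this section.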