Source code for construe.benchmark.runner

"""
Benchmark ABC and global benchmark runner.
"""

import time
import dataclasses

from .base import Benchmark
from ..utils import humanize_duration
from ..metrics import Metric, Measurement, dump
from ..exceptions import ConstrueError, BenchmarkError

from tqdm import tqdm
from datetime import datetime, timezone
from typing import Iterable, List, Dict, Optional, Type


DATEFMT = "%Y-%m-%dT%H:%M:%S.%fZ"


class BenchmarkRunner(object):
    """
    Executes one or more benchmarks, configuring them with a top-level config and
    collecting all measurements, merging as necessary and outputting them to a file.
    """

    def __init__(
        self,
        benchmarks: List[Type[Benchmark]],
        device: Optional[str] = None,
        env: Optional[str] = None,
        n_runs: int = 1,
        limit: Optional[int] = None,
        data_home: Optional[str] = None,
        model_home: Optional[str] = None,
        use_sample: bool = True,
        cleanup: bool = True,
        verbose: bool = True,
    ):
        self.env = env
        self.device = device
        self.n_runs = n_runs
        self.limit = limit
        self.benchmark_kwargs = {
            "data_home": data_home,
            "model_home": model_home,
            "use_sample": use_sample,
            "progress": verbose,
        }
        self.cleanup = cleanup
        self.verbose = verbose
        self.benchmarks = benchmarks

        for b in self.benchmarks:
            if not issubclass(b, Benchmark):
                raise BenchmarkError(f"{b.__name__} is not a Benchmark")

    @property
    def is_complete(self):
        return getattr(self, "run_complete_", False)
    def run(self):
        self.results_ = Results(
            n_runs=self.n_runs,
            limit=self.limit,
            benchmarks=[b.__name__ for b in self.benchmarks],
            started=datetime.now(timezone.utc).strftime(DATEFMT),
            env=self.env,
            device=self.device,
            options=self.benchmark_kwargs,
            errors=[],
        )

        self.run_complete_ = False
        self.measurements_ = []
        started = time.time()

        for cls in self.benchmarks:
            total = self.limit or cls.total(**self.benchmark_kwargs)
            for i in range(self.n_runs):
                self.run_benchmark(i, total, cls)

        self.results_.duration = time.time() - started
        self.results_.measurements = Measurement.merge(self.measurements_)
        self.run_complete_ = True

        if self.verbose:
            print(
                f"{len(self.benchmarks)} benchmark(s) complete in "
                f"{humanize_duration(self.results_.duration)}"
            )
            if self.cleanup:
                print("cleaned up data and model caches: all downloaded data removed")
    def run_benchmark(self, idx: int, total: int, Runner: Type):
        # TODO: do we need to pass separate metadata to the kwargs?
        progress = tqdm(
            total=total, desc=f"Running {Runner.__name__} Benchmark {idx+1}", leave=False
        )
        benchmark = Runner(**self.benchmark_kwargs)

        try:
            for measurement in self.execute(idx, benchmark, progress):
                self.measurements_.append(measurement)
            self.results_.successes += 1
        except ConstrueError as e:
            self.results_.failures += 1
            self.results_.errors.append(str(e))
    def execute(
        self, idx: int, benchmark: Benchmark, progress: tqdm
    ) -> Iterable[Measurement]:
        # Setup the benchmark
        benchmark.before()

        ptimes = []  # preprocess times
        itimes = []  # inference times

        try:
            # Time each inference
            # TODO: measure memory usage during inferencing
            for instance in benchmark.instances(limit=self.limit):
                t1 = time.time()
                features = benchmark.preprocess(instance)
                t2 = time.time()
                benchmark.inference(features)
                t3 = time.time()

                ptimes.append(t2 - t1)
                itimes.append(t3 - t2)
                progress.update(1)
        finally:
            # Ensure benchmark is cleaned up despite any errors if this is the last
            # run of the benchmark and cleanup is specified (otherwise leave cache).
            cleanup = self.cleanup and idx == self.n_runs - 1
            benchmark.after(cleanup=cleanup)

        # Create the preprocessing times measurement
        yield Measurement(
            per_run=1,
            raw_metrics=ptimes,
            units="s",
            metric=Metric(
                label=benchmark.__class__.__name__,
                sub_label="preprocessing",
                description=benchmark.description,
                device=self.device,
                env=self.env,
            ),
        )

        # Create the inference times measurement
        yield Measurement(
            per_run=1,
            raw_metrics=itimes,
            units="s",
            metric=Metric(
                label=benchmark.__class__.__name__,
                sub_label="inferencing",
                description=benchmark.description,
                device=self.device,
                env=self.env,
            ),
        )
    def save(self, path):
        if not self.is_complete:
            raise BenchmarkError("cannot save benchmarks that haven't been run")

        with open(path, "w") as o:
            dump(self.results_, o)

        if self.verbose:
            print("benchmark results saved to", path)
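
The execute method above drives the full Benchmark lifecycle: before(), instances(), preprocess(), inference(), and after(). The sketch below shows a minimal subclass shaped to fit that lifecycle; it is illustrative only, the class name and its trivial bodies are hypothetical, and the actual interface is defined by the Benchmark ABC in construe.benchmark.base.

# Illustrative only: a minimal subclass exercising the hooks that
# BenchmarkRunner.execute() and run() call. NoopBenchmark and its internals
# are hypothetical; consult construe.benchmark.base for the real interface.
class NoopBenchmark(Benchmark):

    description = "times a no-op preprocess/inference loop"

    @staticmethod
    def total(**kwargs):
        # Number of instances the runner should expect (sizes the progress bar).
        return 100

    def before(self):
        # Load models or download datasets before timing begins.
        pass

    def instances(self, limit=None):
        # Yield the raw instances to benchmark, honoring the optional limit.
        for i in range(limit or 100):
            yield i

    def preprocess(self, instance):
        # Convert a raw instance into model features (timed as "preprocessing").
        return instance

    def inference(self, features):
        # Run the model forward pass (timed as "inferencing").
        return features

    def after(self, cleanup=False):
        # Release resources; remove downloaded caches when cleanup is True.
        pass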
@dataclasses.dataclass(init=True, repr=False, eq=True)
class Results:
    """
    A result of all runs of a Benchmark including benchmarking information.
    """

    n_runs: int
    benchmarks: List[str]
    started: str
    errors: List[str] = dataclasses.field(default_factory=list)
    limit: Optional[int] = None
    duration: Optional[float] = None
    env: Optional[str] = None
    device: Optional[str] = None
    options: Optional[Dict] = None
    successes: Optional[int] = 0
    failures: Optional[int] = 0
    measurements: Optional[List[Measurement]] = None
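
Putting it together, a typical invocation constructs the runner with one or more Benchmark classes, calls run(), and then save(). This is a minimal usage sketch: NoopBenchmark refers to the hypothetical subclass above and "results.json" is an arbitrary output path, not a convention of this module.

# Illustrative usage of BenchmarkRunner with the hypothetical NoopBenchmark.
runner = BenchmarkRunner(
    benchmarks=[NoopBenchmark],   # classes, not instances; validated with issubclass
    device="cpu",
    env="local",
    n_runs=3,                     # each benchmark is executed three times
    limit=50,                     # cap the number of instances per run
    use_sample=True,
    cleanup=True,
    verbose=True,
)

runner.run()                      # collects and merges Measurement objects
runner.save("results.json")       # writes the Results dataclass via metrics.dump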