Source code for construe.metrics.metrics

"""
Measurements are the result of a benchmark run.

This module is influenced by https://rtnl.link/hMNNI90KBGj
"""

import dataclasses

from collections import defaultdict
from typing import cast, Optional, Any, Iterable, Dict, List, Tuple

import numpy as np


# A Measurement will include a warning if the distribution of results is
# suspect. All runs are expected to have some variation; these parameters set
# the warning thresholds.
_IQR_WARN_THRESHOLD = 0.1
_IQR_GROSS_WARN_THRESHOLD = 0.25
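
# For example (illustrative numbers, not from a real run): a measurement with
# median 10.0 and IQR 1.5 has a relative IQR of 0.15, which crosses
# _IQR_WARN_THRESHOLD but not _IQR_GROSS_WARN_THRESHOLD, so the milder
# "system fluctuation" warning is attached to the Measurement.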


@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True)
class Metric:
    """
    Container information used to define a benchmark measurement. This class
    is similar to a pytorch TaskSpec.
    """

    label: Optional[str] = None
    sub_label: Optional[str] = None
    description: Optional[str] = None
    device: Optional[str] = None
    env: Optional[str] = None

    @property
    def title(self) -> str:
        """
        Best effort attempt at a string label for the metric.
        """
        if self.label is not None:
            return self.label + (f": {self.sub_label}" if self.sub_label else "")
        elif self.env is not None:
            return f"Metric for {self.env}" + (f" on {self.device}" if self.device else "")  # noqa
        return "Metric"

    def summarize(self) -> str:
        """
        Builds a summary string for printing the metric.
        """
        parts = [self.title, self.description or ""]
        return "\n".join([f"{i}\n" if "\n" in i else i for i in parts if i])
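
# A minimal usage sketch (illustrative values):
#
#   metric = Metric(label="inference", sub_label="resnet18", device="cpu")
#   metric.title               # -> "inference: resnet18"
#   print(metric.summarize())  # prints the title (and description, if set)
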
_TASKSPEC_FIELDS = tuple(i.name for i in dataclasses.fields(Metric))
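# NOTE: Measurement.__getattr__ below uses _TASKSPEC_FIELDS to forward Metric
# attribute access, e.g. measurement.label resolves to measurement.metric.label.
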
@dataclasses.dataclass(init=True, repr=False)
class Measurement:
    """
    The result of a benchmark measurement. This class stores one or more
    measurements of a given statement. It is similar to the pytorch
    measurement and provides convenience methods and serialization.
    """

    metric: Metric
    raw_metrics: List[float]
    per_run: int = 1
    units: Optional[str] = None
    metadata: Optional[Dict[Any, Any]] = None

    def __post_init__(self) -> None:
        self._sorted_metrics: Tuple[float, ...] = ()
        self._warnings: Tuple[str, ...] = ()
        self._median: float = -1.0
        self._mean: float = -1.0
        self._p25: float = -1.0
        self._p75: float = -1.0

    def __getattr__(self, name: str) -> Any:
        # Forward Metric fields for convenience.
        if name in _TASKSPEC_FIELDS:
            return getattr(self.metric, name)
        return super().__getattribute__(name)

    def _compute_stats(self) -> None:
        """
        Computes the internal stats for the measurements if not already
        computed.
        """
        if self.raw_metrics and not self._sorted_metrics:
            self._sorted_metrics = tuple(sorted(self.metrics))
            _metrics = np.array(self._sorted_metrics, dtype=np.float64)
            self._median = np.quantile(_metrics, 0.5).item()
            self._mean = _metrics.mean().item()
            self._p25 = np.quantile(_metrics, 0.25).item()
            self._p75 = np.quantile(_metrics, 0.75).item()

            if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD):
                self.__add_warning("This suggests significant environmental influence.")
            elif not self.meets_confidence(_IQR_WARN_THRESHOLD):
                self.__add_warning("This could indicate system fluctuation.")

    def __add_warning(self, msg: str) -> None:
        riqr = self.iqr / self.median * 100
        self._warnings += (
            f" WARNING: Interquartile range is {riqr:.1f}% "
            f"of the median measurement.\n {msg}",
        )

    @property
    def metrics(self) -> List[float]:
        return [m / self.per_run for m in self.raw_metrics]

    @property
    def median(self) -> float:
        self._compute_stats()
        return self._median

    @property
    def mean(self) -> float:
        self._compute_stats()
        return self._mean

    @property
    def iqr(self) -> float:
        self._compute_stats()
        return self._p75 - self._p25

    @property
    def has_warnings(self) -> bool:
        self._compute_stats()
        return bool(self._warnings)

    @property
    def title(self) -> str:
        return self.metric.title

    @property
    def env(self) -> str:
        return "Unspecified env" if self.metric.env is None else cast(str, self.metric.env)  # noqa

    @property
    def row_name(self) -> str:
        return self.sub_label or "[Unknown]"

    def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool:
        return self.iqr / self.median < threshold

    def to_array(self) -> np.ndarray:
        return np.array(self.metrics, dtype=np.float64)

    @staticmethod
    def merge(measurements: Iterable["Measurement"]) -> List["Measurement"]:
        """
        Merge measurement replicas into a single measurement. This method
        will extrapolate per_run=1 and will not transfer metadata.
        """
        groups = defaultdict(list)
        for m in measurements:
            groups[m.metric].append(m)

        def merge_group(metric: Metric, group: List["Measurement"]) -> "Measurement":
            metrics: List[float] = []
            for m in group:
                metrics.extend(m.metrics)

            return Measurement(
                per_run=1,
                raw_metrics=metrics,
                metric=metric,
                metadata=None,
            )

        return [merge_group(t, g) for t, g in groups.items()]
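
# A minimal usage sketch (illustrative timings, assumed to be in seconds):
#
#   m = Measurement(
#       metric=Metric(label="inference"),
#       raw_metrics=[1.02, 0.98, 1.01, 0.99],
#   )
#   m.median              # -> 1.0 (stats are computed lazily, then cached)
#   m.meets_confidence()  # -> True, since iqr / median is well under 0.1
#
# Replicas that share a Metric can be combined with Measurement.merge, which
# concatenates their per-run metrics into one Measurement with per_run=1.
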
def select_duration_unit(t: float) -> Tuple[str, float]:
    """
    Determine how to scale a duration to format for human readability.
    """
    unit = {-3: "ns", -2: "us", -1: "ms"}.get(
        int(np.log10(np.array(t)).item() // 3), "s"
    )
    scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[unit]
    return unit, scale
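
# For example: a duration of 3.2e-5 seconds scales into microseconds.
#
#   unit, scale = select_duration_unit(3.2e-5)  # -> ("us", 1e-6)
#   f"{3.2e-5 / scale:.1f} {unit}"              # -> "32.0 us"
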
def humanize_duration(u: str) -> str:
    return {
        "ns": "nanosecond",
        "us": "microsecond",
        "ms": "millisecond",
        "s": "second",
    }[u]
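
# For example: humanize_duration("us") returns "microsecond", which pairs with
# the unit abbreviation produced by select_duration_unit above.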