|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from collections import defaultdict |
| 4 | +from copy import deepcopy |
| 5 | +from typing import TYPE_CHECKING, Callable, Dict, List, Literal |
| 6 | + |
| 7 | +from ablate.core.types import GroupedRun, Run |
| 8 | + |
| 9 | + |
| 10 | +if TYPE_CHECKING: # pragma: no cover |
| 11 | + from .query import Query # noqa: TC004 |
| 12 | + from .selectors import AbstractMetric |
| 13 | + |
| 14 | + |
| 15 | +class GroupedQuery: |
| 16 | + def __init__(self, groups: List[GroupedRun]) -> None: |
| 17 | + """Query interface for manipulating grouped runs in a functional way. |
| 18 | +
|
| 19 | + All methods operate on a shallow copy of the runs in the query, so the original |
| 20 | + runs are not modified and assumed to be immutable. |
| 21 | +
|
| 22 | + Args: |
| 23 | + groups: A list of grouped runs to be queried. |
| 24 | + """ |
| 25 | + self._grouped = groups |
| 26 | + |
| 27 | + def filter(self, fn: Callable[[GroupedRun], bool]) -> GroupedQuery: |
| 28 | + """Filter the grouped runs in the grouped query based on a predicate function. |
| 29 | +
|
| 30 | + Args: |
| 31 | + fn: Predicate function that takes in a grouped run and returns a boolean |
| 32 | + value. |
| 33 | +
|
| 34 | + Returns: |
| 35 | + A new grouped query with the grouped runs that satisfy the predicate |
| 36 | + function. |
| 37 | + """ |
| 38 | + return GroupedQuery([g for g in self._grouped[:] if fn(g)]) |
| 39 | + |
| 40 | + def map(self, fn: Callable[[GroupedRun], GroupedRun]) -> GroupedQuery: |
| 41 | + """Apply a function to each grouped run in the grouped query. |
| 42 | +
|
| 43 | + This function is intended to be used for modifying the grouped runs in the |
| 44 | + grouped query. The function should return a new grouped run object as the |
| 45 | + original grouped run is not modified. |
| 46 | +
|
| 47 | + Args: |
| 48 | + fn: Function that takes in a grouped run and returns a new grouped run |
| 49 | + object. |
| 50 | +
|
| 51 | + Returns: |
| 52 | + A new grouped query with the modified grouped runs. |
| 53 | + """ |
| 54 | + return GroupedQuery([fn(deepcopy(g)) for g in self._grouped]) |
| 55 | + |
| 56 | + def sort(self, key: AbstractMetric, ascending: bool = False) -> GroupedQuery: |
| 57 | + """Sort the runs inside each grouped run in the grouped query based on a metric. |
| 58 | +
|
| 59 | + Args: |
| 60 | + key: Metric to sort the grouped runs by. |
| 61 | + ascending: Whether to sort in ascending order. |
| 62 | + Defaults to False (descending order). |
| 63 | +
|
| 64 | + Returns: |
| 65 | + A new grouped query with the grouped runs sorted by the specified metric. |
| 66 | + """ |
| 67 | + return GroupedQuery( |
| 68 | + [ |
| 69 | + GroupedRun( |
| 70 | + key=g.key, |
| 71 | + value=g.value, |
| 72 | + runs=sorted(g.runs, key=key, reverse=not ascending), |
| 73 | + ) |
| 74 | + for g in self._grouped |
| 75 | + ] |
| 76 | + ) |
| 77 | + |
| 78 | + def head(self, n: int) -> Query: |
| 79 | + """Get the first n runs inside each grouped run. |
| 80 | +
|
| 81 | + Args: |
| 82 | + n: Number of runs to return per group. |
| 83 | +
|
| 84 | + Returns: |
| 85 | + A new query with the first n runs from each grouped run. |
| 86 | + """ |
| 87 | + return GroupedQuery( |
| 88 | + [ |
| 89 | + GroupedRun(key=g.key, value=g.value, runs=g.runs[:n]) |
| 90 | + for g in self._grouped |
| 91 | + ] |
| 92 | + )._to_query() |
| 93 | + |
| 94 | + def tail(self, n: int) -> Query: |
| 95 | + """Get the last n runs inside each grouped run. |
| 96 | +
|
| 97 | + Args: |
| 98 | + n: Number of runs to return per group. |
| 99 | +
|
| 100 | + Returns: |
| 101 | + A new query with the last n runs from each grouped run. |
| 102 | + """ |
| 103 | + return GroupedQuery( |
| 104 | + [ |
| 105 | + GroupedRun(key=g.key, value=g.value, runs=g.runs[-n:]) |
| 106 | + for g in self._grouped |
| 107 | + ] |
| 108 | + )._to_query() |
| 109 | + |
| 110 | + def topk(self, metric: AbstractMetric, k: int) -> Query: |
| 111 | + """Get the top k runs inside each grouped run based on a metric. |
| 112 | +
|
| 113 | + Args: |
| 114 | + metric: Metric to sort the runs by. |
| 115 | + k: Number of top runs to return per group. |
| 116 | +
|
| 117 | + Returns: |
| 118 | + A new query with the top k runs from each grouped run based on the |
| 119 | + specified metric. |
| 120 | + """ |
| 121 | + return GroupedQuery( |
| 122 | + [ |
| 123 | + GroupedRun( |
| 124 | + key=g.key, |
| 125 | + value=g.value, |
| 126 | + runs=sorted(g.runs, key=metric, reverse=metric.direction == "min")[ |
| 127 | + :k |
| 128 | + ], |
| 129 | + ) |
| 130 | + for g in self._grouped |
| 131 | + ] |
| 132 | + )._to_query() |
| 133 | + |
| 134 | + def bottomk(self, metric: AbstractMetric, k: int) -> Query: |
| 135 | + """Get the bottom k runs inside each grouped run based on a metric. |
| 136 | +
|
| 137 | + Args: |
| 138 | + metric: Metric to sort the runs by. |
| 139 | + k: Number of bottom runs to return per group. |
| 140 | +
|
| 141 | + Returns: |
| 142 | + A new query with the bottom k runs from each grouped run based on the |
| 143 | + specified metric. |
| 144 | + """ |
| 145 | + return GroupedQuery( |
| 146 | + [ |
| 147 | + GroupedRun( |
| 148 | + key=g.key, |
| 149 | + value=g.value, |
| 150 | + runs=sorted(g.runs, key=metric, reverse=metric.direction == "max")[ |
| 151 | + :k |
| 152 | + ], |
| 153 | + ) |
| 154 | + for g in self._grouped |
| 155 | + ] |
| 156 | + )._to_query() |
| 157 | + |
| 158 | + def aggregate( |
| 159 | + self, |
| 160 | + method: Literal["first", "last", "best", "worst", "mean"], |
| 161 | + over: AbstractMetric, |
| 162 | + ) -> Query: |
| 163 | + """Aggregate each group of runs using a specified method. |
| 164 | +
|
| 165 | + Supported methods include: |
| 166 | + - "first": Selects the first run from each group. |
| 167 | + - "last": Selects the last run from each group. |
| 168 | + - "best": Selects the run with the best value based on the given metric. |
| 169 | + - "worst": Selects the run with the worst value based on the given metric. |
| 170 | + - "mean": Computes the mean run across all runs in each group, including |
| 171 | + averaged metrics and temporal data, and collapsed metadata. |
| 172 | + Args: |
| 173 | + method: Aggregation strategy to apply per group. |
| 174 | + over: The metric used for comparison when using "best" or "worst" methods. |
| 175 | +
|
| 176 | + Raises: |
| 177 | + ValueError: If an unsupported aggregation method is provided. |
| 178 | +
|
| 179 | + Returns: |
| 180 | + A new query with the aggregated runs from each group. |
| 181 | + """ |
| 182 | + from .query import Query |
| 183 | + |
| 184 | + match method: |
| 185 | + case "first": |
| 186 | + return self.head(1) |
| 187 | + case "last": |
| 188 | + return self.tail(1) |
| 189 | + case "best": |
| 190 | + return self.topk(over, 1) |
| 191 | + case "worst": |
| 192 | + return self.bottomk(over, 1) |
| 193 | + case "mean": |
| 194 | + return Query([self._mean_run(g) for g in self._grouped]) |
| 195 | + case _: |
| 196 | + raise ValueError( |
| 197 | + f"Unsupported aggregation method: '{method}'. Must be " |
| 198 | + "'first', 'last', 'best', 'worst', or 'mean'." |
| 199 | + ) |
| 200 | + |
| 201 | + @staticmethod |
| 202 | + def _mean_run(group: GroupedRun) -> Run: |
| 203 | + def _mean(values: List[float]) -> float: |
| 204 | + return sum(values) / len(values) if values else float("nan") |
| 205 | + |
| 206 | + def _mean_temporal(runs: List[Run]) -> Dict[str, List[tuple[int, float]]]: |
| 207 | + all_keys = set().union(*(r.temporal.keys() for r in runs)) |
| 208 | + step_accumulator: Dict[str, Dict[int, List[float]]] = {} |
| 209 | + |
| 210 | + for key in all_keys: |
| 211 | + step_values = defaultdict(list) |
| 212 | + for run in runs: |
| 213 | + for step, val in run.temporal.get(key, []): |
| 214 | + step_values[step].append(val) |
| 215 | + step_accumulator[key] = step_values |
| 216 | + |
| 217 | + return { |
| 218 | + key: sorted( |
| 219 | + (step, sum(vals) / len(vals)) for step, vals in step_values.items() |
| 220 | + ) |
| 221 | + for key, step_values in step_accumulator.items() |
| 222 | + } |
| 223 | + |
| 224 | + def _common_metadata(attr: str) -> Dict[str, str]: |
| 225 | + key_sets = [getattr(r, attr).keys() for r in group.runs] |
| 226 | + common_keys = set.intersection(*map(set, key_sets)) |
| 227 | + result = {} |
| 228 | + for k in common_keys: |
| 229 | + values = {str(getattr(r, attr)[k]) for r in group.runs} |
| 230 | + result[k] = next(iter(values)) if len(values) == 1 else "#" |
| 231 | + return result |
| 232 | + |
| 233 | + all_metrics = [r.metrics for r in group.runs] |
| 234 | + all_keys = set().union(*all_metrics) |
| 235 | + mean_metrics = { |
| 236 | + k: _mean([m[k] for m in all_metrics if k in m]) for k in all_keys |
| 237 | + } |
| 238 | + |
| 239 | + return Run( |
| 240 | + id=f"grouped:{group.key}:{group.value}", |
| 241 | + params=_common_metadata("params"), |
| 242 | + metrics=mean_metrics, |
| 243 | + temporal=_mean_temporal(group.runs), |
| 244 | + ) |
| 245 | + |
| 246 | + def _to_query(self) -> Query: |
| 247 | + from .query import Query |
| 248 | + |
| 249 | + return Query([run for group in self._grouped for run in group.runs]) |
| 250 | + |
| 251 | + def all(self) -> List[Run]: |
| 252 | + """Collect all runs in the grouped query by flattening the grouped runs. |
| 253 | +
|
| 254 | + Returns: |
| 255 | + A list of all runs in the grouped query. |
| 256 | + """ |
| 257 | + return deepcopy(self._to_query()._runs) |
| 258 | + |
| 259 | + def copy(self) -> GroupedQuery: |
| 260 | + """Obtain a shallow copy of the grouped query. |
| 261 | +
|
| 262 | + Returns: |
| 263 | + A new grouped query with the same grouped runs as the original grouped |
| 264 | + query. |
| 265 | + """ |
| 266 | + return GroupedQuery(self._grouped[:]) |
| 267 | + |
| 268 | + def deepcopy(self) -> GroupedQuery: |
| 269 | + """Obtain a deep copy of the grouped query. |
| 270 | +
|
| 271 | + Returns: |
| 272 | + A new grouped query with deep copies of the grouped runs in the original |
| 273 | + grouped query. |
| 274 | + """ |
| 275 | + return GroupedQuery(deepcopy(self._grouped)) |
| 276 | + |
| 277 | + def __len__(self) -> int: |
| 278 | + """Get the number of grouped runs in the grouped query. |
| 279 | +
|
| 280 | + Returns: |
| 281 | + The number of grouped runs in the grouped query. |
| 282 | + """ |
| 283 | + return len(self._grouped) |
0 commit comments